Configs

Training a model in trlX requires setting several configs: ModelConfig, which contains general information about the model being trained; TrainConfig, which contains settings such as training hyperparameters; and MethodConfig, which contains hyperparameters and settings for the specific method being used (e.g. ILQL or PPO).

General

class trlx.data.configs.TRLConfig(method, model, optimizer, scheduler, tokenizer, train)

Top-level config for trlX. Loads configs and can be converted to a dictionary.

Parameters:
  • method (MethodConfig) –

  • model (ModelConfig) –

  • optimizer (OptimizerConfig) –

  • scheduler (SchedulerConfig) –

  • tokenizer (TokenizerConfig) –

  • train (TrainConfig) –

evolve(**kwargs)

Evolve TRLConfig with new parameters. Nested parameters can also be updated:

    >>> config = trlx.data.default_configs.default_ilql_config()
    >>> config = config.evolve(method=dict(gamma=0.99, gen_kwargs=dict(max_new_tokens=100)))
    >>> config.method.gamma
    0.99

Return type:

TRLConfig
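Because evolve accepts partial nested dictionaries, it helps to see what a recursive merge of that kind does. The following is a minimal sketch, assuming evolve overlays its keyword arguments onto the nested dictionary produced by to_dict() before rebuilding the config; merge_nested is a hypothetical helper, not trlX's implementation, and the values are illustrative rather than real ILQL defaults.

```python
import copy
from typing import Any, Dict


def merge_nested(base: Dict[str, Any], update: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of `base` with `update` merged in recursively (hypothetical helper).

    Nested dicts are merged key by key; any other value in `update`
    replaces the corresponding value in `base`.
    """
    merged = copy.deepcopy(base)
    for key, value in update.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_nested(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Illustrative values only; real defaults come from trlx.data.default_configs.
base = {
    "method": {"gamma": 0.9, "gen_kwargs": {"max_new_tokens": 56, "top_k": 20}},
    "train": {"batch_size": 128},
}
evolved = merge_nested(base, {"method": {"gamma": 0.99, "gen_kwargs": {"max_new_tokens": 100}}})
print(evolved["method"])
# {'gamma': 0.99, 'gen_kwargs': {'max_new_tokens': 100, 'top_k': 20}}
```

Only the keys named in the update are replaced, so sibling entries such as top_k (an illustrative key here) and the untouched train section survive the merge.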

classmethod from_dict(config)

Convert a dictionary to a TRLConfig.

Parameters:

config (Dict) –

classmethod load_yaml(yml_fp)

Load a YAML file as a TRLConfig.

Parameters:

yml_fp (str) – Path to the YAML file

to_dict()

Convert the TRLConfig to a dictionary.
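from_dict and to_dict convert between a TRLConfig and a nested dictionary whose top-level keys presumably mirror the constructor parameters (method, model, optimizer, scheduler, tokenizer, train). The toy sketch below shows that shape with only two sections; ToyModelSection, ToyTrainSection and ToyTopLevelConfig are hypothetical dataclasses, and the assumption that each section is built from its sub-dictionary's keyword arguments is not confirmed by these docs.

```python
from dataclasses import asdict, dataclass
from typing import Any, Dict


# Hypothetical stand-ins for two TRLConfig sections; the field names are
# borrowed from ModelConfig and TrainConfig, but the classes are toys.
@dataclass
class ToyModelSection:
    model_path: str
    num_layers_unfrozen: int = -1


@dataclass
class ToyTrainSection:
    total_steps: int
    batch_size: int


@dataclass
class ToyTopLevelConfig:
    model: ToyModelSection
    train: ToyTrainSection

    @classmethod
    def from_dict(cls, config: Dict[str, Any]) -> "ToyTopLevelConfig":
        # Each top-level key holds the keyword arguments for one section.
        return cls(
            model=ToyModelSection(**config["model"]),
            train=ToyTrainSection(**config["train"]),
        )

    def to_dict(self) -> Dict[str, Any]:
        # asdict recurses into nested dataclasses, producing plain nested dicts.
        return asdict(self)


raw = {
    "model": {"model_path": "gpt2", "num_layers_unfrozen": 2},
    "train": {"total_steps": 1000, "batch_size": 32},
}
cfg = ToyTopLevelConfig.from_dict(raw)
print(cfg.train.batch_size)  # 32
print(cfg.to_dict() == raw)  # True
```

A single asdict call is enough to turn the whole tree back into plain dictionaries, which is what makes the round trip cheap in this sketch.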

class trlx.data.configs.ModelConfig(model_path, model_arch_type='causal', num_layers_unfrozen=-1, delta_kwargs=None)

Config for a model.

Parameters:
  • model_path (str) – Path or name of the model (local or on the Hugging Face Hub)

  • model_arch_type (str) – Type of model architecture. Either “causal” or “seq2seq”

  • num_layers_unfrozen (int) – Number of layers to unfreeze for fine-tuning. -1 means all layers are unfrozen.

  • delta_kwargs (Optional[Dict[str, Any]]) –

    Keyword arguments for instantiating OpenDelta models for delta-tuning. These follow the OpenDelta.AutoDeltaConfig specification; for example, for LoRA-style tuning, set delta_type to lora and include the model-specific hyperparameters (e.g. lora_r):

    {"delta_type": "lora", "modified_modules": "all", "lora_r": 8, "lora_alpha": 16, "lora_dropout": 0.0}

    or in YAML format:

    delta_kwargs:
      delta_type: lora
      modified_modules: "all"
      lora_r: 8
      lora_alpha: 16
      lora_dropout: 0.0

    See: https://opendelta.readthedocs.io/en/latest/modules/auto_delta.html#opendelta.auto_delta.AutoDeltaConfig
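To make num_layers_unfrozen concrete, the sketch below lists which of a model's transformer blocks would remain trainable. It assumes, without the docs above confirming it, that the unfrozen layers are the topmost ones (closest to the output); unfrozen_layer_indices is a hypothetical helper and does not touch a real model.

```python
from typing import List


def unfrozen_layer_indices(num_layers: int, num_layers_unfrozen: int) -> List[int]:
    """Indices of trainable transformer blocks (hypothetical helper).

    -1 keeps every layer trainable, matching the ModelConfig docs.
    Otherwise the top `num_layers_unfrozen` layers are assumed trainable.
    """
    if num_layers_unfrozen == -1:
        return list(range(num_layers))
    if num_layers_unfrozen < 0:
        raise ValueError("num_layers_unfrozen must be -1 or a non-negative int")
    k = min(num_layers_unfrozen, num_layers)
    return list(range(num_layers - k, num_layers))


print(len(unfrozen_layer_indices(12, -1)))  # 12
print(unfrozen_layer_indices(12, 2))        # [10, 11]
print(unfrozen_layer_indices(12, 0))        # []
```

Treating negative values other than -1 as an error is a choice made for this sketch only.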

class trlx.data.configs.TrainConfig(total_steps, seq_length, epochs, batch_size, checkpoint_interval, eval_interval, pipeline, trainer, trainer_kwargs=<factory>, project_name='trlx', entity_name=None, group_name=None, checkpoint_dir='ckpts', rollout_logging_dir=None, save_best=True, tracker='wandb', logging_dir=None, seed=1000)

Config for train job on model.

Parameters:
  • total_steps (int) – Total number of training steps

  • seq_length (int) – Number of tokens to use as context (max length for tokenizer)

  • epochs (int) – Total number of passes through data

  • batch_size (int) – Batch size for training

  • tracker (str) – Tracker to use for logging. Default: “wandb”

  • checkpoint_interval (int) – Save model every checkpoint_interval steps. Each checkpoint is stored in a sub-directory of the TrainConfig.checkpoint_dir directory in the format checkpoint_dir/checkpoint_{step}.

  • eval_interval (int) – Evaluate model every eval_interval steps

  • pipeline (str) – Pipeline to use for training. One of the registered pipelines present in trlx.pipeline

  • trainer (str) – Trainer to use for training. One of the registered trainers present in trlx.trainer

  • trainer_kwargs (Dict[str, Any]) – Extra keyword arguments for the trainer

  • project_name (str) – Project name for wandb

  • entity_name (str) – Entity name for wandb

  • group_name (str) – Group name for wandb (used for grouping runs)

  • checkpoint_dir (str) – Directory to save checkpoints

  • rollout_logging_dir (Optional[str]) – Directory to store generated rollouts for use in Algorithm Distillation. Only used by AcceleratePPOTrainer.

  • save_best (bool) – Save best model based on mean reward

  • seed (int) – Random seed

  • logging_dir (Optional[str]) –
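The `<factory>` default for trainer_kwargs is how a dataclass field declared with default_factory appears in a generated signature, which suggests TrainConfig is a dataclass. The sketch below mirrors a hypothetical subset of its fields in ToyTrainConfig and derives the checkpoint directories and evaluation steps a run would produce, assuming both happen at every multiple of their interval and that paths follow the checkpoint_dir/checkpoint_{step} format described above.

```python
import os
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class ToyTrainConfig:
    """Hypothetical subset of TrainConfig's fields, for illustration only."""

    total_steps: int
    checkpoint_interval: int
    eval_interval: int
    # A default_factory shows up as "<factory>" in the generated signature;
    # each instance gets its own fresh dict.
    trainer_kwargs: Dict[str, Any] = field(default_factory=dict)
    checkpoint_dir: str = "ckpts"

    def checkpoint_paths(self) -> List[str]:
        # Assumes a checkpoint at every multiple of checkpoint_interval.
        steps = range(self.checkpoint_interval, self.total_steps + 1, self.checkpoint_interval)
        return [os.path.join(self.checkpoint_dir, f"checkpoint_{step}") for step in steps]

    def eval_steps(self) -> List[int]:
        # Same assumption for evaluation at multiples of eval_interval.
        return list(range(self.eval_interval, self.total_steps + 1, self.eval_interval))


cfg = ToyTrainConfig(total_steps=1000, checkpoint_interval=400, eval_interval=250)
print(cfg.checkpoint_paths())  # ['ckpts/checkpoint_400', 'ckpts/checkpoint_800'] on POSIX
print(cfg.eval_steps())        # [250, 500, 750, 1000]
```

Using default_factory means each config gets its own trainer_kwargs dictionary, so mutating one instance's kwargs cannot leak into another.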

class trlx.data.method_configs.MethodConfig(name)

Config for a certain RL method.

Parameters:

name (str) – Name of the method
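MethodConfig itself only holds a name; configs for specific methods such as PPO and ILQL presumably extend it with their own hyperparameters. The sketch below shows that pattern with toy dataclasses and a hypothetical name-based registry; ToyMethodConfig, ToyDiscountedConfig, register and method_from_dict are all illustrative, and the idea that a method's config class is looked up by its name is an assumption, not documented trlX behavior.

```python
from dataclasses import dataclass
from typing import Any, Dict, Type


@dataclass
class ToyMethodConfig:
    """Stand-in for MethodConfig: every method config carries a name."""

    name: str


# Hypothetical registry mapping a lowercased class name to its config class.
METHOD_CONFIGS: Dict[str, Type[ToyMethodConfig]] = {}


def register(cls: Type[ToyMethodConfig]) -> Type[ToyMethodConfig]:
    METHOD_CONFIGS[cls.__name__.lower()] = cls
    return cls


@register
@dataclass
class ToyDiscountedConfig(ToyMethodConfig):
    """Hypothetical method config that adds its own hyperparameters."""

    gamma: float = 0.99
    max_new_tokens: int = 40


def method_from_dict(config: Dict[str, Any]) -> ToyMethodConfig:
    # Look up the subclass by name, then pass every key (including "name") through.
    cls = METHOD_CONFIGS[config["name"].lower()]
    return cls(**config)


method = method_from_dict({"name": "ToyDiscountedConfig", "gamma": 0.95})
print(type(method).__name__, method.gamma, method.max_new_tokens)
# ToyDiscountedConfig 0.95 40
```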

PPO

ILQL