Configs
Training a model in trlX requires you to set several configs: ModelConfig, which contains general information about the model being trained; TrainConfig, which contains settings such as the training hyperparameters; and finally MethodConfig, which contains hyperparameters and settings for the specific method being used (e.g. ILQL or PPO).
General
- class trlx.data.configs.TRLConfig(method, model, optimizer, scheduler, tokenizer, train)
Top-level config for trlX. Loads configs and can be converted to a dictionary.
- Parameters:
method (MethodConfig) – Settings for the RL method (e.g. PPO or ILQL)
model (ModelConfig) – Model config
optimizer (OptimizerConfig) – Optimizer config
scheduler (SchedulerConfig) – Learning-rate scheduler config
tokenizer (TokenizerConfig) – Tokenizer config
train (TrainConfig) – Training job config
- evolve(**kwargs)
Evolve TRLConfig with new parameters. Nested parameters can also be updated:
>>> config = trlx.data.default_configs.default_ilql_config()
>>> config = config.evolve(method=dict(gamma=0.99, gen_kwargs=dict(max_new_tokens=100)))
>>> config.method.gamma
0.99
- Return type:
TRLConfig
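Updating nested parameters, as evolve does with method=dict(...), amounts to a recursive dictionary merge in which only the keys you pass are replaced and their siblings are kept. The following is a minimal sketch of that idea on plain dicts, assuming evolve merges the new values into the config's dictionary form; merge is a hypothetical helper, not trlX's own code, and the values in base are illustrative rather than trlX's real defaults.

```python
import copy
from typing import Any, Dict


def merge(base: Dict[str, Any], update: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of `base` with `update` merged in, descending into nested dicts."""
    result = copy.deepcopy(base)
    for key, value in update.items():
        if isinstance(result.get(key), dict) and isinstance(value, dict):
            # Nested section: recurse so that sibling keys are kept.
            result[key] = merge(result[key], value)
        else:
            # Leaf value (or a key not present yet): overwrite it.
            result[key] = value
    return result


# Illustrative values only, not trlX's real defaults.
base = {
    "method": {"gamma": 0.9, "gen_kwargs": {"max_new_tokens": 56, "top_k": 20}},
    "train": {"batch_size": 32},
}
evolved = merge(base, {"method": {"gamma": 0.99, "gen_kwargs": {"max_new_tokens": 100}}})
print(evolved["method"]["gen_kwargs"])  # {'max_new_tokens': 100, 'top_k': 20}
print(base["method"]["gamma"])          # 0.9 (the original dict is untouched)
```

Working on a copy keeps the starting config reusable, which matches the `config = config.evolve(...)` reassignment pattern in the doctest above.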
- class trlx.data.configs.ModelConfig(model_path, model_arch_type='causal', num_layers_unfrozen=-1, delta_kwargs=None)
Config for a model.
- Parameters:
model_path (str) – Path or name of the model (local or on huggingface hub)
model_arch_type (str) – Type of model architecture. Either “causal” or “seq2seq”
num_layers_unfrozen (int) – Number of layers to unfreeze for fine-tuning. -1 means all layers are unfrozen.
delta_kwargs (Optional[Dict[str, Any]]) –
Keyword arguments for instantiating OpenDelta models for delta-tuning. Follow the OpenDelta.AutoDeltaConfig specification. For example, for LoRA-style tuning, set delta_type to lora and include the model-specific hyperparameters (e.g. lora_r):
{"delta_type": "lora", "modified_modules": "all", "lora_r": 8, "lora_alpha": 16, "lora_dropout": 0.0}
- or, in YAML format:
delta_kwargs:
  delta_type: lora
  modified_modules: "all"
  lora_r: 8
  lora_alpha: 16
  lora_dropout: 0.0
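num_layers_unfrozen only states how many layers are trained, not which ones. The sketch below assumes the unfrozen layers are the last ones in the stack (nearest the output), which is a common choice for partial fine-tuning but is an assumption here, not something this page confirms. trainable_layers is a hypothetical helper, not part of trlX.

```python
from typing import List


def trainable_layers(num_layers: int, num_layers_unfrozen: int) -> List[int]:
    """Indices of transformer blocks left trainable (hypothetical helper)."""
    if num_layers_unfrozen == -1:
        # -1: every layer is unfrozen, as documented above.
        return list(range(num_layers))
    if not 0 <= num_layers_unfrozen <= num_layers:
        raise ValueError("num_layers_unfrozen must be -1 or in [0, num_layers]")
    # Assumption: the unfrozen layers are the top ones, closest to the output head.
    return list(range(num_layers - num_layers_unfrozen, num_layers))


print(trainable_layers(12, 2))        # [10, 11]
print(len(trainable_layers(12, -1)))  # 12
```

Unfreezing only a few top layers reduces the number of trainable parameters and the optimizer memory, at the cost of less flexibility than full fine-tuning.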
- class trlx.data.configs.TrainConfig(total_steps, seq_length, epochs, batch_size, checkpoint_interval, eval_interval, pipeline, trainer, trainer_kwargs=<factory>, project_name='trlx', entity_name=None, group_name=None, checkpoint_dir='ckpts', rollout_logging_dir=None, save_best=True, tracker='wandb', logging_dir=None, seed=1000)
Config for a training job.
- Parameters:
total_steps (int) – Total number of training steps
seq_length (int) – Number of tokens to use as context (max length for tokenizer)
epochs (int) – Total number of passes through data
batch_size (int) – Batch size for training
tracker (str) – Tracker to use for logging. Default: “wandb”
checkpoint_interval (int) – Save model every checkpoint_interval steps. Each checkpoint is stored in a sub-directory of the TrainConfig.checkpoint_dir directory in the format checkpoint_dir/checkpoint_{step}.
eval_interval (int) – Evaluate model every eval_interval steps
pipeline (str) – Pipeline to use for training. One of the registered pipelines present in trlx.pipeline
trainer (str) – Trainer to use for training. One of the registered trainers present in trlx.trainer
trainer_kwargs (Dict[str, Any]) – Extra keyword arguments for the trainer
project_name (str) – Project name for wandb
entity_name (str) – Entity name for wandb
group_name (str) – Group name for wandb (used for grouping runs)
checkpoint_dir (str) – Directory to save checkpoints
rollout_logging_dir (Optional[str]) – Directory to store generated rollouts for use in Algorithm Distillation. Only used by AcceleratePPOTrainer.
save_best (bool) – Save best model based on mean reward
seed (int) – Random seed
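To see how checkpoint_interval, total_steps and checkpoint_dir combine, the sketch below lists the sub-directories a run would produce using the checkpoint_dir/checkpoint_{step} format described above. It assumes a checkpoint is written exactly when the step is a multiple of checkpoint_interval; the actual trainer may differ (for example by also saving at the final step or padding the step number). checkpoint_dirs is a hypothetical helper.

```python
from typing import List


def checkpoint_dirs(checkpoint_dir: str, total_steps: int, checkpoint_interval: int) -> List[str]:
    """Hypothetical helper listing the checkpoint sub-directories of a run."""
    if checkpoint_interval <= 0:
        raise ValueError("checkpoint_interval must be positive")
    paths = []
    # Assumption: a checkpoint is written whenever step is a multiple of the interval.
    for step in range(checkpoint_interval, total_steps + 1, checkpoint_interval):
        # Format taken from the docs above: checkpoint_dir/checkpoint_{step}
        paths.append(f"{checkpoint_dir}/checkpoint_{step}")
    return paths


print(checkpoint_dirs("ckpts", total_steps=1000, checkpoint_interval=250))
# ['ckpts/checkpoint_250', 'ckpts/checkpoint_500', 'ckpts/checkpoint_750', 'ckpts/checkpoint_1000']
```

When total_steps is not a multiple of checkpoint_interval, this sketch writes no checkpoint for the trailing steps, which is one reason a real trainer might add a final save.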
- class trlx.data.method_configs.MethodConfig(name)
Config for a specific RL method.
- Parameters:
name (str) – Name of the method
PPO
ILQL