RL Trainers
RL trainers implement the reinforcement learning algorithms that trlX trains with. Currently, trlX supports PPO and ILQL.
Note that new trainers must be registered with trlx.trainer.register_trainer.
General
- class trlx.trainer.BaseRLTrainer(config, reward_fn=None, metric_fn=None, logit_mask=None, stop_sequences=None, train_mode=False)
- Parameters:
config (TRLConfig)
- abstract learn(log_fn=None, save_fn=None, eval_fn=None)
Uses the experiences in the RolloutStore to update the model.
- Parameters:
log_fn (Callable[Dict[str, any]]) – Optional function called whenever logging occurs; it is passed a dict of logging-relevant values.
save_fn (Callable[Dict[str, any]]) – Optional function called after saving; it is passed the saved components.
eval_fn (Callable[BaseRLTrainer]) – Optional function called during evaluation. Without it, evaluation does nothing.
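The parameter descriptions above imply a simple callback contract: log_fn receives a dict of values, save_fn receives the saved components, and eval_fn receives the trainer. The following sketch shows one way a learn implementation could honor that contract. `ToyTrainer`, its step loop, `eval_interval` and the `components` dict are hypothetical and are not trlX code.

```python
from typing import Any, Callable, Dict, List, Optional


class ToyTrainer:
    """Hypothetical trainer illustrating how learn() might invoke optional callbacks."""

    def __init__(self, total_steps: int = 4, eval_interval: int = 2):
        self.total_steps = total_steps
        self.eval_interval = eval_interval

    def learn(
        self,
        log_fn: Optional[Callable[[Dict[str, Any]], None]] = None,
        save_fn: Optional[Callable[[Dict[str, Any]], None]] = None,
        eval_fn: Optional[Callable[["ToyTrainer"], None]] = None,
    ) -> None:
        for step in range(1, self.total_steps + 1):
            loss = 1.0 / step  # placeholder for a real optimization step
            if log_fn is not None:
                log_fn({"step": step, "loss": loss})
            # Without eval_fn, the evaluation hook does nothing.
            if eval_fn is not None and step % self.eval_interval == 0:
                eval_fn(self)
        # After "saving", hand the saved components to save_fn.
        components = {"steps_completed": self.total_steps}
        if save_fn is not None:
            save_fn(components)


logs: List[Dict[str, Any]] = []
evals: List[ToyTrainer] = []
saved: List[Dict[str, Any]] = []
ToyTrainer().learn(log_fn=logs.append, save_fn=saved.append, eval_fn=evals.append)
print(len(logs), len(evals), len(saved))  # 4 2 1
```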
- class trlx.trainer.accelerate_base_trainer.AccelerateRLTrainer(config, **kwargs)
RL model trainer with a Hugging Face Accelerate-based backend.
- decode(prompts, samples, prompt_sizes=None)
Decodes tensor generations into lists of strings, returned as (samples: List[str], prompts: List[str], outputs: List[str]).
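The sketch below illustrates the kind of splitting decode performs. It assumes a decoder-only model whose generated samples begin with the prompt tokens. A toy vocabulary stands in for a Hugging Face tokenizer, plain lists stand in for tensors, and every name here is hypothetical. The real method may also handle details such as stop sequences.

```python
from typing import List, Optional, Sequence, Tuple

# Hypothetical toy vocabulary standing in for a real tokenizer; id 0 is padding.
VOCAB = {0: "<pad>", 1: "Hello", 2: " world", 3: "!", 4: " How", 5: " are", 6: " you"}


def toy_decode(ids: Sequence[int]) -> str:
    # Drop padding, roughly like tokenizer.decode(..., skip_special_tokens=True).
    return "".join(VOCAB[i] for i in ids if i != 0)


def decode(
    prompts: List[List[int]],
    samples: List[List[int]],
    prompt_sizes: Optional[List[int]] = None,
) -> Tuple[List[str], List[str], List[str]]:
    if prompt_sizes is None:
        # Assume each prompt occupies its full length.
        prompt_sizes = [len(p) for p in prompts]
    str_samples, str_prompts, str_outputs = [], [], []
    for prompt, sample, size in zip(prompts, samples, prompt_sizes):
        # The first `size` tokens of a sample are the prompt; the rest is the output.
        str_prompt = toy_decode(prompt[:size])
        str_output = toy_decode(sample[size:])
        str_prompts.append(str_prompt)
        str_outputs.append(str_output)
        str_samples.append(str_prompt + str_output)
    return str_samples, str_prompts, str_outputs


samples, prompts, outputs = decode(
    prompts=[[1], [4, 5]],
    samples=[[1, 2, 3], [4, 5, 6, 0]],
    prompt_sizes=[1, 2],
)
print(samples)  # ['Hello world!', ' How are you']
```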
- evaluate()
Samples the model on eval_prompts and logs statistics computed with reward_fn and/or metric_fn, if provided.
- generate(input_ids, attention_mask=None, **kwargs)
Wraps the Hugging Face generate method, adding default generation arguments specific to the RL method.
- generate_eval(input_ids, attention_mask=None, **kwargs)
Evaluation counterpart of generate: wraps the Hugging Face generate method, adding default generation arguments specific to the RL method.
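A wrapper that "adds method-specific defaults" can be sketched as a dictionary merge in which keyword arguments passed at call time override the trainer's defaults. `DEFAULT_GEN_KWARGS` and `build_generate_kwargs` are hypothetical, and the default values are arbitrary placeholders. In a real trainer, the merged dict would then be forwarded to the model's Hugging Face `generate` call together with `input_ids` and `attention_mask`.

```python
from typing import Any, Dict

# Hypothetical method-specific defaults; a real trainer would take these from its config.
DEFAULT_GEN_KWARGS: Dict[str, Any] = {"max_new_tokens": 40, "do_sample": True, "top_k": 0}


def build_generate_kwargs(defaults: Dict[str, Any], **overrides: Any) -> Dict[str, Any]:
    """Merge defaults with per-call overrides; explicitly passed arguments win."""
    return {**defaults, **overrides}


kwargs = build_generate_kwargs(DEFAULT_GEN_KWARGS, max_new_tokens=16, temperature=0.7)
print(kwargs)
# {'max_new_tokens': 16, 'do_sample': True, 'top_k': 0, 'temperature': 0.7}
```

Copying into a new dict leaves the shared defaults untouched, so one call's overrides never leak into the next.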
- abstract get_arch(config)
Returns the method-specific wrapper around the decoder architecture.
- Parameters:
config (TRLConfig)
- learn()
Samples batches from self.store, updates the model, and periodically evaluates it on self.eval_dataloader.
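A minimal sketch of the loop described above, assuming a toy in-memory store. `ToyRolloutStore`, `create_loader`, `n_epochs`, `batch_size` and `eval_interval` are hypothetical names. The update and evaluation steps are placeholders rather than trlX's implementation.

```python
import random
from typing import List, Sequence, Tuple


class ToyRolloutStore:
    """Hypothetical in-memory experience store (not trlX's store class)."""

    def __init__(self, elements: Sequence[float]):
        self.elements = list(elements)

    def create_loader(self, batch_size: int, seed: int) -> List[List[float]]:
        # Shuffle indices reproducibly, then cut them into minibatches.
        order = list(range(len(self.elements)))
        random.Random(seed).shuffle(order)
        return [
            [self.elements[i] for i in order[start:start + batch_size]]
            for start in range(0, len(order), batch_size)
        ]


def learn(
    store: ToyRolloutStore, n_epochs: int, batch_size: int, eval_interval: int
) -> Tuple[int, List[int]]:
    """Run minibatch updates over the store; evaluate every eval_interval steps.

    Returns (number of update steps, steps at which evaluation ran).
    """
    step = 0
    eval_steps: List[int] = []
    for epoch in range(n_epochs):
        for batch in store.create_loader(batch_size, seed=epoch):
            step += 1
            _ = sum(batch) / len(batch)  # placeholder for forward pass, loss, optimizer step
            if step % eval_interval == 0:
                eval_steps.append(step)  # placeholder for evaluation on the eval dataloader
    return step, eval_steps


store = ToyRolloutStore([float(x) for x in range(10)])
steps, evals = learn(store, n_epochs=2, batch_size=4, eval_interval=2)
print(steps, evals)  # 6 [2, 4, 6]
```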
- save_pretrained(directory=None, **kwargs)
Save the underlying Hugging Face model, tokenizer, and configuration files to a directory for later use.
- Parameters:
directory (str, optional) – The directory to save the trainer files to. If not specified, the model is saved to a directory named hf_model inside the checkpoint directory specified in the trainer’s config.
**kwargs – Additional keyword arguments passed to the underlying Hugging Face model’s save_pretrained method.
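The default-directory rule above can be sketched as a small path-resolution helper. `resolve_save_dir` is hypothetical, and its `checkpoint_dir` argument stands in for the checkpoint directory that the trainer reads from its config.

```python
import os
from typing import Optional


def resolve_save_dir(directory: Optional[str], checkpoint_dir: str) -> str:
    """Return the explicit directory, or <checkpoint_dir>/hf_model when none is given."""
    if directory is None:
        return os.path.join(checkpoint_dir, "hf_model")
    return directory


print(resolve_save_dir(None, "ckpts"))        # ckpts/hf_model on POSIX systems
print(resolve_save_dir("my_model", "ckpts"))  # my_model
```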
PPO
- class trlx.trainer.accelerate_ppo_trainer.AcceleratePPOTrainer(config, **kwargs)
PPO trainer with an Accelerate-based backend.
- Parameters:
config (TRLConfig)
- add_prompt_pipeline(pipeline)
Adds a prompt pipeline dataloader to the trainer instance for use in the make_experience stage.
- Parameters:
pipeline (PromptPipeline)
- loss(batch)
Runs the forward pass and computes the loss.
- Parameters:
batch (PPORLBatch) – Previous batch of episodes.
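For orientation, the core of a PPO loss is the clipped policy surrogate. The sketch below shows only that term over plain Python lists. A full PPO loss, including whatever trlX computes here, typically also adds a value-function term and masks padding tokens, and those are omitted. `ppo_clipped_policy_loss` and `clip_range` are hypothetical names.

```python
import math
from typing import Sequence


def ppo_clipped_policy_loss(
    logprobs: Sequence[float],
    old_logprobs: Sequence[float],
    advantages: Sequence[float],
    clip_range: float = 0.2,
) -> float:
    """Mean PPO clipped surrogate loss over tokens (lower is better)."""
    total = 0.0
    for lp, old_lp, adv in zip(logprobs, old_logprobs, advantages):
        ratio = math.exp(lp - old_lp)  # pi_new(a|s) / pi_old(a|s)
        unclipped = ratio * adv
        clipped = min(max(ratio, 1.0 - clip_range), 1.0 + clip_range) * adv
        # PPO maximizes the pessimistic (smaller) surrogate, so minimize its negation.
        total += -min(unclipped, clipped)
    return total / len(logprobs)


# Unchanged policy: ratio == 1, so the loss is the negated mean advantage.
print(ppo_clipped_policy_loss([0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [1.0, -1.0, 2.0]))
```

Taking the minimum of the clipped and unclipped terms removes the incentive to push the probability ratio outside [1 - clip_range, 1 + clip_range] when that would increase the objective.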
ILQL
- class trlx.trainer.accelerate_ilql_trainer.AccelerateILQLTrainer(config, **kwargs)
- Parameters:
config (TRLConfig)
- make_experience(samples, rewards, max_length=2048)
Tokenizes samples, shapes rewards into the appropriate tensors, and then inserts the resulting dataset into the trainer.
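One common way to shape a single scalar reward per sample for token-level offline RL such as ILQL is to place it on the sample's final action token and give all earlier actions zero reward. The sketch below shows that shaping on plain lists. It is an assumption about the general technique, not a description of trlX's exact code, which may, for example, also normalize rewards. `shape_terminal_rewards` and `num_actions` are hypothetical names.

```python
from typing import List


def shape_terminal_rewards(rewards: List[float], num_actions: List[int]) -> List[List[float]]:
    """Place each sample's scalar reward on its final action; earlier actions get 0."""
    shaped = []
    for reward, n in zip(rewards, num_actions):
        if n < 1:
            raise ValueError("each sample needs at least one action token")
        per_token = [0.0] * n
        per_token[-1] = reward
        shaped.append(per_token)
    return shaped


print(shape_terminal_rewards([1.0, -0.5], [3, 2]))  # [[0.0, 0.0, 1.0], [0.0, -0.5]]
```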