RL Trainers

RL trainers implement the training algorithms you run with trlX. Currently, PPO and ILQL are supported. New trainers must be registered with trlx.trainer.register_trainer.
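This page does not show how trlx.trainer.register_trainer works internally. The following is a self-contained sketch of the usual decorator-based registry pattern. It assumes register_trainer can be applied either directly to a class or with an explicit name. The dictionary _TRAINERS and the names get_trainer, MyCustomTrainer, and AnotherTrainer are hypothetical. In real code you would import register_trainer from trlx.trainer rather than define it yourself.

```python
from typing import Dict, Union

# Hypothetical stand-in for trlX's trainer registry; not the library's actual code.
_TRAINERS: Dict[str, type] = {}


def register_trainer(name: Union[str, type]):
    """Register a class under a lowercase key: @register_trainer or @register_trainer("key")."""

    def register(cls: type, key: str) -> type:
        _TRAINERS[key.lower()] = cls
        return cls

    if isinstance(name, str):
        # Called with an explicit name: return the actual decorator.
        return lambda cls: register(cls, name)
    # Applied directly to a class: register it under its class name.
    return register(name, name.__name__)


def get_trainer(name: str) -> type:
    """Look up a registered trainer by case-insensitive name (hypothetical helper)."""
    try:
        return _TRAINERS[name.lower()]
    except KeyError:
        raise ValueError(f"unknown trainer {name!r}; registered: {sorted(_TRAINERS)}") from None


@register_trainer
class MyCustomTrainer:
    """Placeholder; a real trainer would subclass BaseRLTrainer."""


@register_trainer("my_other_trainer")
class AnotherTrainer:
    """Placeholder registered under an explicit name."""
```

Registering at class-definition time lets a config refer to a trainer by a string name without a hand-maintained lookup table.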

General

class trlx.trainer.BaseRLTrainer(config, reward_fn=None, metric_fn=None, logit_mask=None, stop_sequences=None, train_mode=False)
Parameters:

config (TRLConfig) – The trainer configuration.

add_eval_pipeline(eval_pipeline)

Adds a pipeline of validation prompts.

abstract learn(log_fn=None, save_fn=None, eval_fn=None)

Uses the experiences in the RolloutStore to train the model.

Parameters:
  • log_fn (Callable[Dict[str, Any]]) – Optional function called when logging, with a dict of relevant logging values

  • save_fn (Callable[Dict[str, Any]]) – Optional function called after saving; it is passed the saved components

  • eval_fn (Callable[BaseRLTrainer]) – Optional function called during evaluation. Evaluation does nothing unless this is provided.

abstract load(directory=None)

Loads a checkpoint created by save.

abstract sample(prompts, length, n_samples)

Samples from the language model, given prompts and the maximum number of new tokens to generate.

Parameters:
  • prompts (Iterable[str]) – List of prompts to tokenize and use as context

  • length (int) – How many new tokens to generate for each prompt

  • n_samples (int) – Number of samples to generate; by default, this is the number of prompts

Return type:

Iterable[str]

abstract save(directory=None)

Creates a checkpoint of the training state.

Parameters:

directory (Optional[str]) – Directory to write the checkpoint to.
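The sketch below makes the BaseRLTrainer contract concrete. It mirrors the abstract methods with Python’s abc module and implements them in a toy subclass. SketchRLTrainer and EchoTrainer are hypothetical names. The method bodies are placeholders, not real training logic. Treating n_samples=None as “one sample per prompt” is our reading of the sample description, and the dict passed to save_fn is an invented payload.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Iterable, List, Optional


class SketchRLTrainer(ABC):
    """Hypothetical, simplified mirror of the BaseRLTrainer interface (not trlX code)."""

    def __init__(self, config: Any, reward_fn: Optional[Callable] = None):
        self.config = config
        self.reward_fn = reward_fn
        self.eval_pipeline: Any = None

    def add_eval_pipeline(self, eval_pipeline: Any) -> None:
        # Keep the validation prompts around for later evaluation.
        self.eval_pipeline = eval_pipeline

    @abstractmethod
    def sample(self, prompts: Iterable[str], length: int, n_samples: Optional[int] = None) -> List[str]:
        """Assumed convention: n_samples=None means one sample per prompt."""

    @abstractmethod
    def learn(self, log_fn=None, save_fn=None, eval_fn=None) -> None: ...

    @abstractmethod
    def save(self, directory: Optional[str] = None) -> None: ...

    @abstractmethod
    def load(self, directory: Optional[str] = None) -> None: ...


class EchoTrainer(SketchRLTrainer):
    """Toy subclass: 'samples' by appending placeholder tokens and 'learns' by logging."""

    def sample(self, prompts, length, n_samples=None):
        prompts = list(prompts)
        if not prompts:
            return []
        if n_samples is None:
            n_samples = len(prompts)
        # Cycle through the prompts until n_samples completions exist.
        return [prompts[i % len(prompts)] + " <tok>" * length for i in range(n_samples)]

    def learn(self, log_fn=None, save_fn=None, eval_fn=None):
        stats: Dict[str, Any] = {"loss": 0.0}
        if log_fn is not None:
            log_fn(stats)
        self.save()
        if save_fn is not None:
            save_fn({"config": self.config})  # hypothetical "components" payload
        if eval_fn is not None:
            eval_fn(self)  # without eval_fn, this step is a no-op

    def save(self, directory=None):
        pass  # a real trainer would write its training state here

    def load(self, directory=None):
        pass  # and restore it here
```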

class trlx.trainer.accelerate_base_trainer.AccelerateRLTrainer(config, **kwargs)

RL model trainer with an Accelerate-based backend.

add_eval_pipeline(eval_pipeline)

Adds a pipeline of validation prompts.

decode(prompts, samples, prompt_sizes=None)

Decodes tensor generations into lists of strings, returned as (samples, prompts, outputs), each a List[str].

Parameters:
  • prompts (List[LongTensor]) – Tokenized prompts

  • samples (List[LongTensor]) – Generated token sequences

  • prompt_sizes (Optional[LongTensor]) – Size of each prompt in tokens

Return type:

Tuple[List[str], List[str], List[str]]
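For a causal (decoder-only) model, each generated sequence typically begins with its prompt, so decoding has to split it at the prompt boundary. This minimal sketch uses plain lists of token IDs and a hypothetical toy vocabulary in place of a Hugging Face tokenizer. It assumes that samples include the prompt tokens and that prompt_sizes marks where each prompt ends. It illustrates the (samples, prompts, outputs) return shape only. A real implementation would also need to consider padding and encoder-decoder models, whose outputs do not contain the prompt.

```python
from typing import Dict, List, Optional, Sequence, Tuple

# Hypothetical toy vocabulary standing in for a Hugging Face tokenizer.
VOCAB: Dict[int, str] = {0: "Hello", 1: " world", 2: "!", 3: " Bye"}


def toy_decode(ids: Sequence[int]) -> str:
    return "".join(VOCAB[i] for i in ids)


def decode_sketch(
    prompts: List[List[int]],
    samples: List[List[int]],
    prompt_sizes: Optional[List[int]] = None,
) -> Tuple[List[str], List[str], List[str]]:
    """Split each causal-LM sample (prompt + continuation) into prompt and output text."""
    if prompt_sizes is None:
        # Assume unpadded prompts, so each prompt's size is simply its length.
        prompt_sizes = [len(p) for p in prompts]
    str_samples, str_prompts, str_outputs = [], [], []
    for prompt, sample, size in zip(prompts, samples, prompt_sizes):
        str_prompt = toy_decode(prompt[:size])
        str_output = toy_decode(sample[size:])  # tokens generated after the prompt
        str_prompts.append(str_prompt)
        str_outputs.append(str_output)
        str_samples.append(str_prompt + str_output)
    return str_samples, str_prompts, str_outputs
```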

evaluate()

Samples the model on eval_prompts and logs statistics computed with reward_fn or metric_fn, if provided.
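The aggregation step of evaluation can be sketched as follows. This sketch assumes that reward_fn returns one score per sample, that metric_fn returns a dict of per-sample metric lists, and that both accept samples, prompts, and outputs as keyword arguments. evaluate_sketch and the stats key names are hypothetical. The real method also does the sampling and logging, which are omitted here.

```python
from statistics import mean
from typing import Callable, Dict, List, Optional


def evaluate_sketch(
    samples: List[str],
    prompts: List[str],
    outputs: List[str],
    reward_fn: Optional[Callable[..., List[float]]] = None,
    metric_fn: Optional[Callable[..., Dict[str, List[float]]]] = None,
) -> Dict[str, float]:
    """Average reward_fn scores and metric_fn values over decoded eval samples."""
    stats: Dict[str, float] = {}
    if reward_fn is not None:
        rewards = reward_fn(samples=samples, prompts=prompts, outputs=outputs)
        stats["reward/mean"] = mean(rewards)
    if metric_fn is not None:
        metrics = metric_fn(samples=samples, prompts=prompts, outputs=outputs)
        for name, values in metrics.items():
            stats[f"metrics/{name}"] = mean(values)
    return stats
```

For example, a reward_fn that scores each output by its length would yield the mean output length under "reward/mean".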

generate(input_ids, attention_mask=None, **kwargs)

Wraps Hugging Face’s generate, adding method-specific default arguments.

generate_eval(input_ids, attention_mask=None, **kwargs)

Wraps Hugging Face’s generate for evaluation, adding method-specific default arguments.
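The method-specific defaults added by these wrappers can be pictured as a dict merge in which per-call keyword arguments override configured defaults. The default values and the helper name below are hypothetical. They are chosen only to illustrate the precedence.

```python
from typing import Any, Dict

# Hypothetical defaults; in a real trainer these would come from the TRLConfig.
DEFAULT_GEN_KWARGS: Dict[str, Any] = {"max_new_tokens": 40, "do_sample": True, "top_k": 0}


def merge_generate_kwargs(defaults: Dict[str, Any], **overrides: Any) -> Dict[str, Any]:
    """Return per-call generation kwargs, letting explicit overrides win over defaults."""
    # dict(a, **b) copies `a`, then applies `b`; neither input is mutated.
    return dict(defaults, **overrides)
```

The merged dict would then be forwarded as keyword arguments to a Hugging Face model’s generate call.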

abstract get_arch(config)

Returns a specific wrapper of the decoder architecture.

Parameters:

config (TRLConfig) – The trainer configuration.

learn()

Samples batches from self.store, updates the model, and periodically evaluates it on self.eval_dataloader.

load(directory=None, **kwargs)

Loads a checkpoint of the optimizer, scheduler, and model.

Parameters:

directory (Optional[str]) – Directory to load the checkpoint from.

abstract loss(batch)

Computes the loss on a batch from the store and returns it together with a dict of statistics.

Return type:

Tuple[float, Dict]

abstract post_backward_callback()

Hook called after each model update.

abstract post_epoch_callback()

Hook called after a full pass over self.store.
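The order in which learn drives these abstract hooks can be sketched with a toy trainer that records each call. LoopSketch, eval_interval, and the event strings are hypothetical. The backward pass and optimizer step are elided. A real trainer would iterate a dataloader built from the store rather than the store itself.

```python
from typing import Any, Dict, List, Tuple


class LoopSketch:
    """Hypothetical toy trainer showing the order of hook calls in a learn() loop."""

    def __init__(self, store: List[Any], epochs: int, eval_interval: int):
        self.store = store
        self.epochs = epochs
        self.eval_interval = eval_interval
        self.iter_count = 0
        self.events: List[str] = []
        self.last_stats: Dict[str, float] = {}

    def loss(self, batch: Any) -> Tuple[float, Dict[str, float]]:
        value = float(batch)  # stand-in for a real forward pass
        return value, {"loss": value}

    def post_backward_callback(self) -> None:
        self.events.append("post_backward")

    def post_epoch_callback(self) -> None:
        self.events.append("post_epoch")

    def evaluate(self) -> None:
        self.events.append(f"eval@{self.iter_count}")

    def learn(self) -> None:
        for _ in range(self.epochs):
            for batch in self.store:  # one full pass over the store
                _, self.last_stats = self.loss(batch)
                # A real trainer would call backward() and step the optimizer here.
                self.post_backward_callback()
                self.iter_count += 1
                if self.iter_count % self.eval_interval == 0:
                    self.evaluate()
            self.post_epoch_callback()
```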

save(directory=None, **kwargs)

Creates a checkpoint of the optimizer, scheduler, and model.

Parameters:

directory (Optional[str]) – Directory to write the checkpoint to.

save_pretrained(directory=None, **kwargs)

Save the underlying Hugging Face model, tokenizer, and configuration files to a directory for later use.

Parameters:
  • directory (str, optional) – The directory to save the trainer files to. NOTE: If not specified, the model will be saved to a directory named hf_model in the checkpoint directory as specified by the Trainer’s config.

  • **kwargs – Additional keyword arguments passed to the underlying Hugging Face model’s save_pretrained method.
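The default-directory rule for save_pretrained amounts to a small path resolution. In this sketch, checkpoint_dir stands in for the checkpoint directory from the trainer’s config. The exact config field is not documented on this page. resolve_hf_save_dir is a hypothetical helper.

```python
import os
from typing import Optional


def resolve_hf_save_dir(directory: Optional[str], checkpoint_dir: str) -> str:
    """Pick the save_pretrained target: the given directory, else <checkpoint_dir>/hf_model."""
    if directory is None:
        directory = os.path.join(checkpoint_dir, "hf_model")
    return directory
```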

setup_model()

Returns a model derived from an instance’s TRLConfig

setup_optimizer()

Returns an optimizer derived from an instance’s TRLConfig

setup_scheduler()

Returns a learning rate scheduler derived from an instance’s TRLConfig

PPO

class trlx.trainer.accelerate_ppo_trainer.AcceleratePPOTrainer(config, **kwargs)

PPO Accelerate Trainer

Parameters:

config (TRLConfig) – The trainer configuration.

add_prompt_pipeline(pipeline)

Adds a prompt pipeline dataloader to the trainer for use in the make_experience stage.

Parameters:

pipeline (PromptPipeline) –

get_arch(config)

Returns the model to train, built from config.

Parameters:

config (TRLConfig) – The trainer configuration.

loss(batch)

Runs the forward pass and computes the loss.

Parameters:

batch (PPORLBatch) – Previous batch of episodes
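PPO’s loss is built around the clipped surrogate objective of Schulman et al. (2017). The sketch below shows only that policy term, computed over per-token values in pure Python. Value-function loss and advantage estimation are omitted. The function and parameter names (ppo_policy_loss, cliprange) are illustrative and should not be read as trlX’s internals.

```python
import math
from typing import List


def ppo_policy_loss(
    logprobs: List[float],
    old_logprobs: List[float],
    advantages: List[float],
    cliprange: float = 0.2,
) -> float:
    """Mean clipped surrogate loss (to be minimized) over per-token values."""
    losses = []
    for lp, old_lp, adv in zip(logprobs, old_logprobs, advantages):
        ratio = math.exp(lp - old_lp)  # pi_new / pi_old for this token
        clipped = min(max(ratio, 1.0 - cliprange), 1.0 + cliprange)
        # Pessimistic bound: the larger of the two negated objectives.
        losses.append(max(-adv * ratio, -adv * clipped))
    return sum(losses) / len(losses)
```

Taking the maximum of the negated terms means the update gains nothing from pushing the probability ratio outside [1 - cliprange, 1 + cliprange] in the direction the advantage favors.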

make_experience(num_rollouts=1024, iter_count=0)

Makes experiences.

Takes chunk_size prompts from prompt_iterator, samples from the model, and computes the KL divergence against a reference model. It then appends the resulting PPOElements to the trainer’s store.

Parameters:
  • num_rollouts (int) – Number of rollouts to generate

  • iter_count (int) – Total number of updates run so far (i.e., the number of updates across all batches and epochs)
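In RLHF-style PPO, a common way to use the KL term computed here is to turn it into a per-token penalty and add the scalar score from the reward function at the final token. The following is a minimal sketch of that scheme under those assumptions. kl_penalized_rewards and kl_coef are hypothetical names. The inputs are log-probabilities of the generated tokens under the policy and the reference model.

```python
from typing import List


def kl_penalized_rewards(
    logprobs: List[float],
    ref_logprobs: List[float],
    score: float,
    kl_coef: float,
) -> List[float]:
    """Per-token rewards: -kl_coef * (log pi - log pi_ref), plus the score on the last token."""
    rewards = [-kl_coef * (lp - ref_lp) for lp, ref_lp in zip(logprobs, ref_logprobs)]
    if not rewards:
        raise ValueError("need at least one generated token")
    rewards[-1] += score  # the sequence-level score is credited to the final token
    return rewards
```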

post_backward_callback()

Hook called after each model update.
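For PPO, our understanding is that this hook is where trlX adjusts the KL coefficient used in make_experience. Treat that as an assumption to verify against the source. The adaptive rule sketched here follows Ziegler et al. (2019): the coefficient grows when the observed KL exceeds a target and shrinks otherwise, with the relative error clipped to ±20%. AdaptiveKLSketch and its parameters are illustrative names.

```python
class AdaptiveKLSketch:
    """Adaptive KL coefficient: nudge `value` so the observed KL tracks `target`."""

    def __init__(self, init_kl_coef: float, target: float, horizon: int):
        self.value = init_kl_coef
        self.target = target
        self.horizon = horizon

    def update(self, current_kl: float, n_steps: int) -> None:
        # Relative error, clipped to +/-20% so one noisy batch cannot swing the coefficient far.
        error = min(max(current_kl / self.target - 1.0, -0.2), 0.2)
        self.value *= 1.0 + error * n_steps / self.horizon
```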

post_epoch_callback()

Post-epoch callback.

Clears the store and creates num_rollouts new episodes.

ILQL

class trlx.trainer.accelerate_ilql_trainer.AccelerateILQLTrainer(config, **kwargs)
Parameters:

config (TRLConfig) – The trainer configuration.

get_arch(config)

Returns a specific wrapper of the decoder architecture

loss(batch)

Computes the loss on a batch from the store and returns it together with a dict of statistics.

Parameters:

batch (Union[ILQLBatch, ILQLSeq2SeqBatch]) –
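ILQL fits its value function with expectile regression, borrowed from implicit Q-learning. This is an asymmetric squared error that, for tau above 0.5, pushes V toward an upper expectile of Q. The sketch shows only this term, in pure Python. The full ILQL loss has additional terms. expectile_loss and tau are illustrative names here and say nothing about trlX’s internals.

```python
from typing import List


def expectile_loss(q_values: List[float], v_values: List[float], tau: float = 0.7) -> float:
    """Mean asymmetric squared error |tau - 1[u < 0]| * u**2 with u = Q - V."""
    total = 0.0
    for q, v in zip(q_values, v_values):
        u = q - v
        # V below Q (u >= 0) is penalized with weight tau; V above Q with 1 - tau.
        weight = tau if u >= 0 else 1.0 - tau
        total += weight * u * u
    return total / len(q_values)
```

With tau = 0.5 both sides are weighted equally and the loss reduces to half the mean squared error.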

make_experience(samples, rewards, max_length=2048)

Tokenizes samples, shapes rewards into tensors, and inserts the resulting dataset into the trainer.

make_experience_seq2seq(samples, rewards, max_length=2048)

Sequence-to-sequence counterpart of make_experience: tokenizes samples, shapes rewards into tensors, and inserts the resulting dataset into the trainer.
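One simple way to “shape” a scalar per-sample reward into per-token values for offline RL is to truncate the sample to max_length first. Every token then gets zero reward except the last, which carries the sample’s reward. This sketch assumes that sparse layout and works on plain lists. shape_ilql_sample is a hypothetical helper. trlX’s actual tokenization and reward shaping (for example, how prompt and action tokens are separated) may differ.

```python
from typing import List, Tuple


def shape_ilql_sample(
    token_ids: List[int], reward: float, max_length: int = 2048
) -> Tuple[List[int], List[float]]:
    """Truncate a tokenized sample and place its scalar reward on the last token."""
    token_ids = token_ids[:max_length]
    # Sparse layout: zero everywhere except the final token.
    rewards = [0.0] * len(token_ids)
    if rewards:
        rewards[-1] = reward
    return token_ids, rewards
```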

post_backward_callback()

Hook called after each model update.