import os
from typing import Union, cast
import numpy as np
import torch
import transformers
from rich.console import Console
from rich.table import Table
import trlx.utils.logging as logging
from trlx.data.configs import TRLConfig
from trlx.data.ilql_types import ILQLBatch, ILQLSeq2SeqBatch
from trlx.models.modeling_ilql import (
AutoModelForCausalLMWithILQLHeads,
AutoModelForSeq2SeqLMWithILQLHeads,
ILQLConfig,
)
from trlx.pipeline.offline_pipeline import (
ILQLRolloutStorage,
ILQLSeq2SeqRolloutStorage,
tokenize_dialogue,
)
from trlx.trainer import register_trainer
from trlx.trainer.accelerate_base_trainer import AccelerateRLTrainer
from trlx.utils import to_device
# Module-level logger from trlx's logging utilities (``trlx.utils.logging``),
# named after this module; used for the rollout-collection progress messages below.
logger = logging.get_logger(__name__)
@register_trainer
class AccelerateILQLTrainer(AccelerateRLTrainer):
    """Accelerate-based trainer for ILQL (Implicit Language Q-Learning).

    Both causal and seq2seq models are supported; the variant is selected by
    ``config.model.model_arch_type``. Offline data is supplied through
    :meth:`make_experience`, which fills ``self.store`` with an ILQL rollout
    storage that :meth:`prepare_learning` turns into a training dataloader.
    """

    def __init__(self, config: TRLConfig, **kwargs):
        # Validate before any base-class setup so a misconfigured method fails fast.
        # ``get_arch`` reads ILQL-only fields (``two_qs``, ``alpha``) from
        # ``config.method``; NOTE(review): presumably the base constructor builds
        # the model through it — confirm.
        if not isinstance(config.method, ILQLConfig):
            raise ValueError("config.method must be ILQLConfig")
        super().__init__(config, **kwargs)
        self.ilql: ILQLConfig = cast(ILQLConfig, config.method)
        # User-supplied ``gen_kwargs`` overridden by the trainer's length limit,
        # logit mask and special-token ids (0 when no tokenizer is configured).
        self.generate_kwargs = dict(
            config.method.gen_kwargs,
            max_length=self.max_length,
            logit_mask=self.logit_mask,
            eos_token_id=self.tokenizer.eos_token_id if self.tokenizer else 0,
            pad_token_id=self.tokenizer.pad_token_id if self.tokenizer else 0,
        )

    def get_arch(self, config):
        """Instantiate the ILQL-headed model described by ``config``.

        When ``config.model.model_path`` is a ``transformers.PretrainedConfig`` the
        model is built with ``from_config``; otherwise the value is passed to
        ``from_pretrained`` (presumably a hub id or local checkpoint path).
        ``two_qs`` and ``alpha`` are forwarded from the ILQL method config.
        """
        if config.model.model_arch_type == "seq2seq":
            model_cls = AutoModelForSeq2SeqLMWithILQLHeads
        else:
            model_cls = AutoModelForCausalLMWithILQLHeads

        if isinstance(config.model.model_path, transformers.PretrainedConfig):
            from_fn = model_cls.from_config
        else:
            from_fn = model_cls.from_pretrained

        return from_fn(
            config.model.model_path,
            two_qs=config.method.two_qs,
            alpha=config.method.alpha,
        )

    def post_backward_callback(self):
        """Sync the target Q heads every ``steps_for_target_q_sync`` iterations."""
        if self.iter_count % self.config.method.steps_for_target_q_sync == 0:
            # Unwrap to reach the underlying model's own method behind any
            # accelerate wrapper.
            self.accelerator.unwrap_model(self.model).sync_target_q_heads()

    def loss(self, batch: Union[ILQLBatch, ILQLSeq2SeqBatch]):
        """Run the model on ``batch`` and compute the ILQL loss.

        Returns whatever ``self.ilql.loss`` returns for
        ``(logits, (qs, target_qs, vs))`` and the batch.
        """
        batch = to_device(batch, self.accelerator.device)

        if self.config.model.model_arch_type == "seq2seq":
            # The seq2seq model returns one extra output, which is unused here.
            logits, qs, target_qs, vs, _, _ = self.model(
                input_ids=batch.input_ids,
                attention_mask=batch.attention_mask,
                actions_ixs=batch.actions_ixs,
                states_ixs=batch.states_ixs,
                decoder_input_ids=batch.decoder_input_ids,
            )
        else:
            logits, qs, target_qs, vs, _ = self.model(
                input_ids=batch.input_ids,
                attention_mask=batch.attention_mask,
                actions_ixs=batch.actions_ixs,
                states_ixs=batch.states_ixs,
            )

        return self.ilql.loss((logits, (qs, target_qs, vs)), batch)

    def prepare_learning(self):
        """Build train/eval dataloaders, wrap everything with accelerate, and size the run.

        Requires ``self.store`` to have been populated by :meth:`make_experience`.
        """
        train_dataloader = self.store.create_loader(self.config.train.batch_size)
        eval_dataloader = self.eval_pipeline.create_loader(self.config.train.batch_size)

        (
            self.model,
            self.opt,
            self.train_dataloader,
            self.eval_dataloader,
        ) = self.accelerator.prepare(self.model, self.opt, train_dataloader, eval_dataloader)

        self.n_updates_per_batch = 1
        # Train for the configured epochs, capped at ``train.total_steps``.
        self.total_steps = min(
            self.config.train.epochs * len(self.train_dataloader),
            self.config.train.total_steps,
        )

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _is_rank_zero() -> bool:
        """True on the process whose ``RANK`` env var is "0" (or unset)."""
        return os.environ.get("RANK", "0") == "0"

    @staticmethod
    def _print_sample_example(prompt, response, reward) -> None:
        """Pretty-print one decoded prompt/response pair with its raw reward."""
        columns = ["Prompt", "Response", "Reward"]
        table = Table(*columns, title="Sample Example", show_lines=True)
        table.add_row(prompt, response, str(reward))
        Console().print(table)

    @staticmethod
    def _print_length_stats(prompt_lengths, output_lengths, sample_lengths) -> None:
        """Pretty-print mean/min/max of the prompt, output and total lengths."""
        logger.info("Logging experience string statistics")
        columns = ["Prompt Length", "Output Length", "Sample Length"]
        # Raw string: the backslash must reach rich intact (it escapes "[" in
        # markup), and a plain "\[" literal is an invalid Python escape sequence.
        table = Table(*columns, title=r"Experience String Stats (mean ∈ \[min, max])", show_lines=True)
        row = [
            f"{lengths.mean():.2f} ∈ [{min(lengths)}, {max(lengths)}]"
            for lengths in (prompt_lengths, output_lengths, sample_lengths)
        ]
        table.add_row(*row)
        Console().print(table)

    @staticmethod
    def _normalize_returns(rewards) -> torch.Tensor:
        """Standardize per-sample scalar rewards across the batch (float64 tensor)."""
        returns = torch.tensor(rewards, dtype=float)
        centered = returns - returns.mean()
        if returns.numel() < 2:
            # ``Tensor.std`` is unbiased by default and therefore NaN for a single
            # element, which would poison every reward; the centered value is final.
            return centered
        return centered / (returns.std() + torch.finfo(returns.dtype).eps)

    @staticmethod
    def _terminal_rewards(all_actions_ixs, returns):
        """Per-action reward tensors: zero everywhere except the last action of each sample."""
        token_rewards = [torch.zeros(len(ixs)) for ixs in all_actions_ixs]
        for sample_rewards, ret in zip(token_rewards, returns):
            sample_rewards[-1] = ret
        return token_rewards

    # --------------------------------------------------------- experience builders

    def make_experience_seq2seq(self, samples, rewards, max_length=2048):
        """
        Tokenizes samples and shapes rewards into proper tensors and then inserts the resulting dataset into the trainer

        Args:
            samples: Dialogues whose phrase 0 is used as the encoder input and
                phrase 1 as the decoder output. They are passed through
                ``tokenize_dialogue`` when a tokenizer is configured; otherwise
                phrases are presumably already token-id sequences.
            rewards: One scalar reward per sample. Rewards are standardized across
                the batch and assigned to the last action of each sample.
            max_length: Forwarded to ``tokenize_dialogue``.

        Side effects:
            Sets ``self.store`` to an ``ILQLSeq2SeqRolloutStorage``.
        """
        logger.info("Collecting rollouts")
        if self.tokenizer:
            samples = [tokenize_dialogue(s, self.tokenizer, max_length) for s in samples]

        all_input_ids = []
        all_output_ids = []
        all_actions_ixs = []
        all_states_ixs = []
        all_dones = []
        for sample in samples:
            all_input_ids.append(torch.tensor(sample[0]))
            all_output_ids.append(torch.tensor(sample[1]))

            actions_ixs = []
            length = 0
            for turn, phrase in enumerate(sample):
                # Odd-indexed phrases are outputs; action indices are positions in
                # that output phrase (all but its final position).
                if turn % 2 == 1:
                    length = len(phrase)
                    actions_ixs.append(torch.arange(0, length - 1))

            # States are the action positions plus one trailing final state.
            states_ixs = torch.hstack((*actions_ixs, torch.tensor(length - 1)))
            # Every state is "1" except the last, which is "0".
            all_dones.append(torch.tensor([1] * (len(states_ixs) - 1) + [0], dtype=int))
            all_actions_ixs.append(torch.hstack(actions_ixs))
            all_states_ixs.append(states_ixs)

        if self.tokenizer and self._is_rank_zero():
            logger.info("Logging sample example")
            prompt = self.tokenizer.decode(all_input_ids[0])
            response = self.tokenizer.decode(all_output_ids[0])
            self._print_sample_example(prompt, response, rewards[0])

        prompt_lengths = np.array([len(ids) for ids in all_input_ids])
        output_lengths = np.array([len(ids) for ids in all_output_ids])
        sample_lengths = prompt_lengths + output_lengths
        if self._is_rank_zero():
            self._print_length_stats(prompt_lengths, output_lengths, sample_lengths)

        returns = self._normalize_returns(rewards)
        token_rewards = self._terminal_rewards(all_actions_ixs, returns)
        attention_mask = [torch.ones(len(x), dtype=int) for x in all_input_ids]

        self.store = ILQLSeq2SeqRolloutStorage(
            all_input_ids,
            attention_mask,
            all_output_ids,
            token_rewards,
            all_states_ixs,
            all_actions_ixs,
            all_dones,
        )

    def make_experience(self, samples, rewards, max_length=2048):
        """
        Tokenizes samples and shapes rewards into proper tensors and then inserts the resulting dataset into the trainer

        Delegates to :meth:`make_experience_seq2seq` for seq2seq models.

        Args:
            samples: Dialogues of alternating phrases where odd-indexed phrases
                are model outputs. They are passed through ``tokenize_dialogue``
                when a tokenizer is configured; otherwise each phrase is presumably
                already a list of token ids.
            rewards: One scalar reward per sample. Rewards are standardized across
                the batch and assigned to the last action of each sample.
            max_length: Forwarded to ``tokenize_dialogue``.

        Side effects:
            Sets ``self.store`` to an ``ILQLRolloutStorage``.
        """
        if self.config.model.model_arch_type == "seq2seq":
            return self.make_experience_seq2seq(samples, rewards, max_length)

        logger.info("Collecting rollouts")
        if self.tokenizer:
            samples = [tokenize_dialogue(s, self.tokenizer, max_length) for s in samples]

        all_input_ids = []
        all_actions_ixs = []
        all_states_ixs = []
        all_dones = []
        for sample in samples:
            # Concatenate all phrases into one sequence (linear, unlike sum(sample, [])).
            all_input_ids.append(torch.tensor([token for phrase in sample for token in phrase]))

            actions_ixs = []
            length = 0
            for turn, phrase in enumerate(sample):
                # For each output phrase, the actions are the positions that
                # precede each of its tokens in the flattened sequence.
                if turn % 2 == 1:
                    actions_ixs.append(torch.arange(length - 1, length + len(phrase) - 1))
                length += len(phrase)

            # States are the action positions plus the final position of the sequence.
            states_ixs = torch.hstack((*actions_ixs, torch.tensor(length - 1)))
            # Every state is "1" except the last, which is "0".
            all_dones.append(torch.tensor([1] * (len(states_ixs) - 1) + [0], dtype=int))
            all_actions_ixs.append(torch.hstack(actions_ixs))
            all_states_ixs.append(states_ixs)

        if self.tokenizer and self._is_rank_zero():
            logger.info("Logging sample example")
            # states_ixs[0][1] marks where the first output phrase begins.
            split_at = all_states_ixs[0][1]
            prompt = self.tokenizer.decode(all_input_ids[0][:split_at])
            response = self.tokenizer.decode(all_input_ids[0][split_at:])
            self._print_sample_example(prompt, response, rewards[0])

        sample_lengths = np.array([len(ids) for ids in all_input_ids])
        output_lengths = np.array([len(ixs) for ixs in all_actions_ixs])
        prompt_lengths = sample_lengths - output_lengths
        if self._is_rank_zero():
            self._print_length_stats(prompt_lengths, output_lengths, sample_lengths)

        returns = self._normalize_returns(rewards)
        token_rewards = self._terminal_rewards(all_actions_ixs, returns)
        attention_mask = [torch.ones(len(x), dtype=int) for x in all_input_ids]

        self.store = ILQLRolloutStorage(
            all_input_ids,
            attention_mask,
            token_rewards,
            all_states_ixs,
            all_actions_ixs,
            all_dones,
        )