from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Set
import yaml
from trlx.data.method_configs import MethodConfig, get_method
def merge(base: Dict, update: Dict, updated: Set) -> Dict:
    """
    Recursively overwrite the values of `base` with the matching values of `update`.

    Only keys that already exist in `base` are considered; keys that appear solely in
    `update` are ignored. `base` is modified in place and is also returned.

    :param base: Nested dictionary to update (mutated in place)
    :type base: Dict
    :param update: Nested dictionary holding the new values
    :type update: Dict
    :param updated: Set that receives the name of every key that was overwritten, at any
        nesting depth. Names are recorded as-is, without their parent keys.
    :type updated: Set
    :return: `base`, after the update
    :rtype: Dict
    """
    for key, current in base.items():
        if key not in update:
            continue
        new_value = update[key]
        if isinstance(current, dict) and isinstance(new_value, dict):
            base[key] = merge(current, new_value, updated)
        else:
            # Either a leaf value, or a dict being replaced by a non-dict (e.g. reset to
            # None). Recursing into a non-dict `new_value` would raise a TypeError.
            base[key] = new_value
        updated.add(key)
    return base
def _merge_dicts(base: Dict, update: Dict) -> Dict:
"Merge two dictionaries recursively, returning a new dictionary."
base = deepcopy(base)
for k, v in update.items():
if isinstance(v, dict):
base[k] = _merge_dicts(base.get(k, {}), v)
else:
base[k] = v
return base
@dataclass
class ModelConfig:
    """
    Config for a model.

    :param model_path: Path or name of the model (local or on huggingface hub)
    :type model_path: str

    :param model_arch_type: Type of model architecture. Either "causal" or "seq2seq"
    :type model_arch_type: str

    :param num_layers_unfrozen: Number of layers to unfreeze for fine-tuning.
        -1 means all layers are unfrozen.
    :type num_layers_unfrozen: int

    :param delta_kwargs: Keyword arguments for instantiating OpenDelta models for delta-tuning.
        Follow the `OpenDelta.AutoDeltaConfig` specification, e.g. for LoRA style tuning, set
        the `delta_type` to `lora` and include the model specific hyper-parameters (e.g. `lora_r`)::

            {"delta_type": "lora", "modified_modules": "all", "lora_r": 8, "lora_alpha": 16, "lora_dropout": 0.0}

        or in YAML format::

            delta_kwargs:
                delta_type: lora
                modified_modules: "all"
                lora_r: 8
                lora_alpha: 16
                lora_dropout: 0.0

        See: https://opendelta.readthedocs.io/en/latest/modules/auto_delta.html#opendelta.auto_delta.AutoDeltaConfig
    :type delta_kwargs: Optional[Dict[str, Any]]
    """

    model_path: str
    model_arch_type: str = "causal"
    num_layers_unfrozen: int = -1
    delta_kwargs: Optional[Dict[str, Any]] = None

    @classmethod
    def from_dict(cls, config: Dict[str, Any]):
        """
        Build a ModelConfig from a dictionary.

        :param config: Mapping of field names to values; passed to the constructor as keyword arguments
        :type config: Dict[str, Any]
        """
        return cls(**config)
@dataclass
class TokenizerConfig:
    """
    Config for a tokenizer.

    :param tokenizer_path: Path or name of the tokenizer (local or on huggingface hub)
    :type tokenizer_path: str

    :param padding_side: Padding side
    :type padding_side: str

    :param truncation_side: Truncation side
    :type truncation_side: str
    """

    tokenizer_path: str
    padding_side: str = "left"
    truncation_side: str = "right"

    @classmethod
    def from_dict(cls, config: Dict[str, Any]):
        """Build a TokenizerConfig, passing the entries of `config` as keyword arguments."""
        instance = cls(**config)
        return instance
@dataclass
class OptimizerConfig:
    """
    Config for an optimizer.

    :param name: Name of the optimizer
    :type name: str

    :param kwargs: Keyword arguments for the optimizer (e.g. lr, betas, eps, weight_decay).
        Defaults to a fresh empty dictionary for each instance.
    :type kwargs: Dict[str, Any]
    """

    name: str
    kwargs: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, config: Dict[str, Any]):
        """Build an OptimizerConfig, passing the entries of `config` as keyword arguments."""
        instance = cls(**config)
        return instance
@dataclass
class SchedulerConfig:
    """
    Config for a learning rate scheduler.

    :param name: Name of the scheduler
    :type name: str

    :param kwargs: Keyword arguments for the scheduler instance (e.g. warmup_steps, T_max).
        Defaults to a fresh empty dictionary for each instance.
    :type kwargs: Dict[str, Any]
    """

    name: str
    kwargs: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, config: Dict[str, Any]):
        """Build a SchedulerConfig, passing the entries of `config` as keyword arguments."""
        instance = cls(**config)
        return instance
@dataclass
class TrainConfig:
    """
    Config for train job on model.

    :param total_steps: Total number of training steps
    :type total_steps: int

    :param seq_length: Number of tokens to use as context (max length for tokenizer)
    :type seq_length: int

    :param epochs: Total number of passes through data
    :type epochs: int

    :param batch_size: Batch size for training
    :type batch_size: int

    :param tracker: Tracker to use for logging. Default: "wandb"
    :type tracker: Optional[str]

    :param checkpoint_interval: Save model every checkpoint_interval steps.
        Each checkpoint is stored in a sub-directory of the `TrainConfig.checkpoint_dir`
        directory in the format `checkpoint_dir/checkpoint_{step}`.
    :type checkpoint_interval: int

    :param eval_interval: Evaluate model every eval_interval steps
    :type eval_interval: int

    :param pipeline: Pipeline to use for training. One of the registered pipelines present in trlx.pipeline
    :type pipeline: str

    :param trainer: Trainer to use for training. One of the registered trainers present in trlx.trainer
    :type trainer: str

    :param trainer_kwargs: Extra keyword arguments for the trainer
    :type trainer_kwargs: Dict[str, Any]

    :param project_name: Project name for wandb
    :type project_name: str

    :param entity_name: Entity name for wandb
    :type entity_name: Optional[str]

    :param group_name: Group name for wandb (used for grouping runs)
    :type group_name: Optional[str]

    :param checkpoint_dir: Directory to save checkpoints
    :type checkpoint_dir: str

    :param rollout_logging_dir: Directory to store generated rollouts for use in Algorithm Distillation.
        Only used by AcceleratePPOTrainer.
    :type rollout_logging_dir: Optional[str]

    :param save_best: Save best model based on mean reward
    :type save_best: bool

    :param logging_dir: Not described in the original documentation; presumably the directory
        the tracker writes its logs to (TODO: confirm against the trainer implementation)
    :type logging_dir: Optional[str]

    :param seed: Random seed
    :type seed: int
    """

    total_steps: int
    seq_length: int
    epochs: int
    batch_size: int

    checkpoint_interval: int
    eval_interval: int

    pipeline: str  # One of the registered pipelines
    trainer: str  # One of the registered trainers
    trainer_kwargs: Dict[str, Any] = field(default_factory=dict)  # Extra keyword arguments for the trainer

    project_name: str = "trlx"
    entity_name: Optional[str] = None
    group_name: Optional[str] = None

    checkpoint_dir: str = "ckpts"
    rollout_logging_dir: Optional[str] = None
    save_best: bool = True

    tracker: Optional[str] = "wandb"
    logging_dir: Optional[str] = None

    seed: int = 1000

    @classmethod
    def from_dict(cls, config: Dict[str, Any]):
        """
        Build a TrainConfig from a dictionary.

        :param config: Mapping of field names to values; passed to the constructor as keyword arguments
        :type config: Dict[str, Any]
        """
        return cls(**config)
@dataclass
class TRLConfig:
    """
    Top level config for trlX. Loads configs and can be converted to dictionary.

    :param method: Config of the training method, built by the class that `get_method`
        returns for the configured method name
    :type method: MethodConfig

    :param model: Config of the model
    :type model: ModelConfig

    :param optimizer: Config of the optimizer
    :type optimizer: OptimizerConfig

    :param scheduler: Config of the learning rate scheduler
    :type scheduler: SchedulerConfig

    :param tokenizer: Config of the tokenizer
    :type tokenizer: TokenizerConfig

    :param train: Config of the training job
    :type train: TrainConfig
    """

    method: MethodConfig
    model: ModelConfig
    optimizer: OptimizerConfig
    scheduler: SchedulerConfig
    tokenizer: TokenizerConfig
    train: TrainConfig

    @classmethod
    def load_yaml(cls, yml_fp: str):
        """
        Load yaml file as TRLConfig.

        :param yml_fp: Path to yaml file
        :type yml_fp: str
        """
        with open(yml_fp, mode="r") as file:
            config = yaml.safe_load(file)
        return cls.from_dict(config)

    def to_dict(self):
        """
        Convert TRLConfig to dictionary.

        The value stored under each section key is the `__dict__` of the corresponding
        sub-config itself, not a copy, so mutating the returned dictionaries mutates
        this config. Copy the result before modifying it.
        """
        data = {
            "method": self.method.__dict__,
            "model": self.model.__dict__,
            "optimizer": self.optimizer.__dict__,
            "scheduler": self.scheduler.__dict__,
            "tokenizer": self.tokenizer.__dict__,
            "train": self.train.__dict__,
        }

        return data

    def evolve(self, **kwargs) -> "TRLConfig":
        """
        Evolve TRLConfig with new parameters. Can update nested parameters.
        The current config is left unchanged; a new TRLConfig is returned.

        >>> config = trlx.data.default_configs.default_ilql_config()
        >>> config = config.evolve(method=dict(gamma=0.99, gen_kwargs=dict(max_new_tokens=100)))
        >>> config.method.gamma
        0.99
        """
        return TRLConfig.from_dict(_merge_dicts(self.to_dict(), kwargs))

    @classmethod
    def from_dict(cls, config: Dict):
        """
        Convert dictionary to TRLConfig.

        :param config: Dictionary with the keys "method", "model", "tokenizer", "optimizer",
            "scheduler" and "train". The "method" entry must contain a "name" key, which is
            used to look up the method config class through `get_method`.
        :type config: Dict
        """
        return cls(
            method=get_method(config["method"]["name"]).from_dict(config["method"]),
            model=ModelConfig.from_dict(config["model"]),
            tokenizer=TokenizerConfig.from_dict(config["tokenizer"]),
            optimizer=OptimizerConfig.from_dict(config["optimizer"]),
            scheduler=SchedulerConfig.from_dict(config["scheduler"]),
            train=TrainConfig.from_dict(config["train"]),
        )

    @classmethod
    def update(cls, baseconfig: Dict, config: Dict):
        """
        Build a new TRLConfig from `baseconfig` with the values of `config` applied.

        `baseconfig` itself is not modified. Only the top-level keys of `config` are
        validated; nested keys that do not exist in `baseconfig` are ignored by `merge`.

        :param baseconfig: Base configuration, either as a dictionary or as an object
            exposing `to_dict()` (e.g. a TRLConfig)
        :type baseconfig: Dict
        :param config: Nested dictionary of values to override
        :type config: Dict
        :raises ValueError: If a top-level key of `config` was not applied to the base config
        """
        if not isinstance(baseconfig, dict):
            baseconfig = baseconfig.to_dict()

        # `merge` writes into its first argument and `to_dict` hands out the live
        # `__dict__` of every sub-config, so work on a private copy to keep the
        # caller's config (or dictionary) untouched.
        baseconfig = deepcopy(baseconfig)

        updates = set()
        merged = merge(baseconfig, config, updates)

        for param in config:
            if param not in updates:
                raise ValueError(f"parameter {param} is not present in the config (typo or a wrong config)")

        return cls.from_dict(merged)

    def __str__(self):
        """Returns a human-readable string representation of the config."""
        import json

        return json.dumps(self.to_dict(), indent=4)