Source code for trlx.trainer

import sys
from abc import abstractmethod
from typing import Any, Callable, Dict, Iterable, Optional

from trlx.data.configs import TRLConfig
from trlx.pipeline import BaseRolloutStore

# specifies a dictionary of architectures
_TRAINERS: Dict[str, Any] = {}  # registry


def register_trainer(name):
    """Decorator used to register a trainer
    Args:
        name: Name of the trainer type to register
    """

    def register_class(cls, name):
        _TRAINERS[name] = cls
        setattr(sys.modules[__name__], name, cls)
        return cls

    if isinstance(name, str):
        name = name.lower()
        return lambda c: register_class(c, name)

    cls = name
    name = cls.__name__
    register_class(cls, name.lower())

    return cls


[docs]@register_trainer class BaseRLTrainer: def __init__( self, config: TRLConfig, reward_fn=None, metric_fn=None, logit_mask=None, stop_sequences=None, train_mode=False, ): self.store: BaseRolloutStore = None self.config = config self.reward_fn = reward_fn self.metric_fn = metric_fn self.train_mode = train_mode self.logit_mask = logit_mask self.stop_sequences = stop_sequences def push_to_store(self, data): self.store.push(data)
[docs] def add_eval_pipeline(self, eval_pipeline): """Adds pipeline for validation prompts""" self.eval_pipeline = eval_pipeline
[docs] @abstractmethod def sample(self, prompts: Iterable[str], length: int, n_samples: int) -> Iterable[str]: """ Sample from the language. Takes prompts and maximum length to generate. :param prompts: List of prompts to tokenize and use as context :param length: How many new tokens to generate for each prompt :type length: int :param n_samples: Default behavior is to take number of prompts as this """ pass
[docs] @abstractmethod def learn( self, log_fn: Callable = None, save_fn: Callable = None, eval_fn: Callable = None, ): """ Use experiences in RolloutStore to learn :param log_fn: Optional function that is called when logging and passed a dict of logging relevant values :type log_fn: Callable[Dict[str, any]] :param save_fn: Optional function to call after saving. Is passed the components. :type save_fn: Callable[Dict[str, any]] :param eval_fn: Optional function to call during evaluation. Eval doesn't do anything without this. :type eval_fn: Callable[BaseRLTrainer] """ pass
[docs] @abstractmethod def save(self, directory: Optional[str] = None): """Creates a checkpoint of training states""" pass
[docs] @abstractmethod def load(self, directory=None): """Loads a checkpoint created from `save`""" pass