# Source code for trlx.data.ilql_types
from dataclasses import dataclass, fields
from torchtyping import TensorType # type: ignore
def flatten_dataclass(cls: type):
    """Create a converter from instances of dataclass ``cls`` to plain lists.

    The field names are read once, when this factory is called, so passing a
    non-dataclass ``cls`` raises ``TypeError`` (from ``dataclasses.fields``)
    immediately rather than on first use.

    :param cls: The dataclass type whose instances will be flattened.
    :return: A callable mapping an instance to a ``list`` of its field values,
        ordered as ``fields(cls)`` reports them (declaration order).
    """
    names = [spec.name for spec in fields(cls)]

    def _flatten(obj):
        # One entry per dataclass field, in declaration order.
        return [getattr(obj, name) for name in names]

    return _flatten
def unflatten_dataclass(cls: type):
    """Create a converter from value sequences back to instances of ``cls``.

    This is the counterpart of ``flatten_dataclass``: values are paired with
    the dataclass field names in declaration order and passed to ``cls`` as
    keyword arguments. Pairing uses ``zip``, so surplus trailing values are
    silently ignored.

    :param cls: The dataclass type to rebuild.
    :return: A callable taking an iterable of field values and returning a
        new ``cls`` instance.
    """
    names = [spec.name for spec in fields(cls)]

    def _unflatten(values):
        kwargs = {name: value for name, value in zip(names, values)}
        return cls(**kwargs)

    return _unflatten
@dataclass
class ILQLElement:
    """
    Data element for ILQL

    Dimension names below are the torchtyping labels used in the field
    annotations; fields sharing a label share that dimension's size.

    :param input_ids: Input tokens (``query_size``). Should be a long tensor.
    :type input_ids: torch.Tensor
    :param attention_mask: Attention mask for ``input_ids`` (same
        ``query_size`` dimension). Should be a long tensor.
    :type attention_mask: torch.Tensor
    :param rewards: Rewards, one per entry of the ``reward_size`` dimension
        (shared with ``actions_ixs``). Should be a float tensor.
    :type rewards: torch.Tensor
    :param states_ixs: Index tensor over the ``states_size`` dimension
        (shared with ``dones``); presumably positions of states within
        ``input_ids`` -- confirm against the producer of these elements.
    :type states_ixs: torch.Tensor
    :param actions_ixs: Index tensor over the ``reward_size`` dimension
        (shared with ``rewards``); presumably positions of actions within
        ``input_ids`` -- confirm against the producer of these elements.
    :type actions_ixs: torch.Tensor
    :param dones: Values over the ``states_size`` dimension (shared with
        ``states_ixs``); presumably done/terminal indicators per state.
    :type dones: torch.Tensor
    """

    # Declaration order is significant: flatten_dataclass/unflatten_dataclass
    # map fields to list positions in this order.
    input_ids: TensorType["query_size"]
    attention_mask: TensorType["query_size"]
    rewards: TensorType["reward_size"]
    states_ixs: TensorType["states_size"]
    actions_ixs: TensorType["reward_size"]
    dones: TensorType["states_size"]
@dataclass
class ILQLSeq2SeqElement:
    """
    Data element for ILQL with sequence-to-sequence models

    Dimension names below are the torchtyping labels used in the field
    annotations; fields sharing a label share that dimension's size.

    :param input_ids: Input tokens (``query_size``), presumably the encoder
        input. Should be a long tensor.
    :type input_ids: torch.Tensor
    :param attention_mask: Attention mask for ``input_ids`` (same
        ``query_size`` dimension). Should be a long tensor.
    :type attention_mask: torch.Tensor
    :param decoder_input_ids: Decoder-side input tokens, annotated with the
        ``reward_size`` dimension (shared with ``rewards`` and
        ``actions_ixs``).
    :type decoder_input_ids: torch.Tensor
    :param rewards: Rewards, one per entry of the ``reward_size`` dimension.
        Should be a float tensor.
    :type rewards: torch.Tensor
    :param states_ixs: Index tensor over the ``states_size`` dimension
        (shared with ``dones``); presumably positions of states in the
        sequence -- confirm against the producer of these elements.
    :type states_ixs: torch.Tensor
    :param actions_ixs: Index tensor over the ``reward_size`` dimension;
        presumably positions of actions in the sequence -- confirm against
        the producer of these elements.
    :type actions_ixs: torch.Tensor
    :param dones: Values over the ``states_size`` dimension (shared with
        ``states_ixs``); presumably done/terminal indicators per state.
    :type dones: torch.Tensor
    """

    # Declaration order is significant: flatten_dataclass/unflatten_dataclass
    # map fields to list positions in this order.
    input_ids: TensorType["query_size"]
    attention_mask: TensorType["query_size"]
    decoder_input_ids: TensorType["reward_size"]
    rewards: TensorType["reward_size"]
    states_ixs: TensorType["states_size"]
    actions_ixs: TensorType["reward_size"]
    dones: TensorType["states_size"]
@dataclass
class ILQLBatch:
    """
    Batched ILQL data elements

    Every field carries a leading ``batch_size`` dimension; the trailing
    dimension label matches the corresponding per-element field annotation.

    :param input_ids: Batch of input tokens (``batch_size x query_size``).
    :type input_ids: torch.Tensor
    :param attention_mask: Batch of attention masks for ``input_ids``
        (``batch_size x query_size``).
    :type attention_mask: torch.Tensor
    :param rewards: Batch of rewards (``batch_size x reward_size``).
    :type rewards: torch.Tensor
    :param states_ixs: Batch of state index tensors
        (``batch_size x states_size``).
    :type states_ixs: torch.Tensor
    :param actions_ixs: Batch of action index tensors
        (``batch_size x reward_size``, same shape as ``rewards``).
    :type actions_ixs: torch.Tensor
    :param dones: Batch of per-state done values
        (``batch_size x states_size``, same shape as ``states_ixs``).
    :type dones: torch.Tensor
    """

    # Declaration order is significant: flatten_dataclass/unflatten_dataclass
    # map fields to list positions in this order.
    input_ids: TensorType["batch_size", "query_size"]
    attention_mask: TensorType["batch_size", "query_size"]
    rewards: TensorType["batch_size", "reward_size"]
    states_ixs: TensorType["batch_size", "states_size"]
    actions_ixs: TensorType["batch_size", "reward_size"]
    dones: TensorType["batch_size", "states_size"]
@dataclass
class ILQLSeq2SeqBatch:
    """
    Batched ILQL data elements for sequence-to-sequence models

    Every field carries a leading ``batch_size`` dimension; the trailing
    dimension label matches the corresponding per-element field annotation.

    :param input_ids: Batch of input tokens (``batch_size x query_size``),
        presumably the encoder input.
    :type input_ids: torch.Tensor
    :param attention_mask: Batch of attention masks for ``input_ids``
        (``batch_size x query_size``).
    :type attention_mask: torch.Tensor
    :param decoder_input_ids: Batch of decoder-side input tokens
        (``batch_size x reward_size``).
    :type decoder_input_ids: torch.Tensor
    :param rewards: Batch of rewards (``batch_size x reward_size``).
    :type rewards: torch.Tensor
    :param states_ixs: Batch of state index tensors
        (``batch_size x states_size``).
    :type states_ixs: torch.Tensor
    :param actions_ixs: Batch of action index tensors
        (``batch_size x reward_size``, same shape as ``rewards``).
    :type actions_ixs: torch.Tensor
    :param dones: Batch of per-state done values
        (``batch_size x states_size``, same shape as ``states_ixs``).
    :type dones: torch.Tensor
    """

    # Declaration order is significant: flatten_dataclass/unflatten_dataclass
    # map fields to list positions in this order.
    input_ids: TensorType["batch_size", "query_size"]
    attention_mask: TensorType["batch_size", "query_size"]
    decoder_input_ids: TensorType["batch_size", "reward_size"]
    rewards: TensorType["batch_size", "reward_size"]
    states_ixs: TensorType["batch_size", "states_size"]
    actions_ixs: TensorType["batch_size", "reward_size"]
    dones: TensorType["batch_size", "states_size"]