6 changes: 3 additions & 3 deletions README.md
@@ -20,16 +20,16 @@ You can train a model using a reward function or a reward-labeled dataset.

#### Using a reward function
```python
model = trlx.train('gpt2', reward_fn=lambda samples: [sample.count('cats') for sample in samples])
trainer = trlx.train('gpt2', reward_fn=lambda samples: [sample.count('cats') for sample in samples])
```
#### Using a reward-labeled dataset
```python
model = trlx.train('EleutherAI/gpt-j-6B', dataset=[('dolphins', 'geese'), (1.0, 100.0)])
trainer = trlx.train('EleutherAI/gpt-j-6B', dataset=[('dolphins', 'geese'), (1.0, 100.0)])
```

#### Trained model is a wrapper over a given autoregressive model
```python
model.generate(**tokenizer('Q: Who rules the world? A:', return_tensors='pt'), do_sample=True)
trainer.generate(**tokenizer('Q: Who rules the world? A:', return_tensors='pt'), do_sample=True)
```
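
For reference, here is a minimal end-to-end sketch that stitches the three snippets above together under the new `trainer` naming; it assumes the `gpt2` tokenizer from `transformers`, which the README excerpt uses but never defines.

```python
# Hedged sketch combining the README snippets above after the model -> trainer rename.
# The tokenizer setup is an assumption; the README excerpt uses `tokenizer` without defining it.
import trlx
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Train against a toy reward function; trlx.train now returns a trainer object.
trainer = trlx.train(
    "gpt2",
    reward_fn=lambda samples: [sample.count("cats") for sample in samples],
)

# The trainer wraps the underlying autoregressive model, so generation works as before.
output = trainer.generate(
    **tokenizer("Q: Who rules the world? A:", return_tensors="pt"),
    do_sample=True,
)
print(tokenizer.decode(output[0]))
```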

#### Use 🤗 Accelerate to launch distributed training
2 changes: 1 addition & 1 deletion configs/ilql_config.yml
@@ -9,11 +9,11 @@ train:

pipeline: "PromptPipeline"
orchestrator: "OfflineOrchestrator"
trainer: "AccelerateILQLTrainer"

seed: 1000

model:
model_type: "AccelerateILQLModel"
model_path: "gpt2"
tokenizer_path: "gpt2"
num_layers_unfrozen: -1
2 changes: 1 addition & 1 deletion configs/ppo_config.yml
@@ -9,9 +9,9 @@ train:

pipeline: "PromptPipeline"
orchestrator: "PPOOrchestrator"
trainer: "AcceleratePPOTrainer"

model:
model_type: "AcceleratePPOModel"
model_path: "lvwerra/gpt2-imdb"
tokenizer_path: "gpt2"
num_layers_unfrozen: 2
2 changes: 1 addition & 1 deletion configs/ppo_gptj.yml
@@ -9,9 +9,9 @@ train:

pipeline: "PromptPipeline"
orchestrator: "PPOOrchestrator"
trainer: "AcceleratePPOTrainer"

model:
model_type: "AcceleratePPOModel"
model_path: "EleutherAI/gpt-j-6B"
tokenizer_path: "gpt2"
num_layers_unfrozen: 2
2 changes: 1 addition & 1 deletion configs/test_config.yml
@@ -9,9 +9,9 @@ train:

pipeline: "PromptPipeline" # prompt pipeline to load
orchestrator: "PPOOrchestrator" # orchestrator to load
trainer: "AcceleratePPOTrainer" # Name of model trainer to load

model:
model_type: "AcceleratePPOModel" # Name of accelerate model type to load
model_path: "lvwerra/gpt2-imdb" # Name of hf model to load
tokenizer_path: "gpt2" # Name of hf tokenizer to load
num_layers_unfrozen: 2 # Number of bottom layers to freeze during training
40 changes: 0 additions & 40 deletions docs/source/models.rst

This file was deleted.

40 changes: 40 additions & 0 deletions docs/source/trainer.rst
@@ -0,0 +1,40 @@
.. _trainers:

RL Trainers
*******************

RL Trainers are what you train with in trlX. Currently, we support PPO and ILQL.
Note that new trainers must be registered with ``trlx.trainer.register_trainer`` (a minimal registration sketch follows this hunk).

**General**

.. autoclass:: trlx.trainer.BaseRLTrainer
:members:

.. autoclass:: trlx.trainer.accelerate_base_trainer.AccelerateRLTrainer
:members:

**PPO**

.. autoclass:: trlx.trainer.accelerate_ppo_trainer.AcceleratePPOTrainer
:members:

.. autoclass:: trlx.trainer.nn.ppo_models.CausalLMWithValueHead
:members:

.. autoclass:: trlx.trainer.nn.ppo_models.GPTModelBranch
:members:

.. autoclass:: trlx.trainer.nn.ppo_models.OPTModelBranch
:members:

.. autoclass:: trlx.trainer.nn.ppo_models.CausalLMHydraWithValueHead
:members:

**ILQL**

.. autoclass:: trlx.trainer.accelerate_ilql_trainer.AccelerateILQLTrainer
:members:

.. autoclass:: trlx.trainer.nn.ilql_models.CausalLMWithValueHeads
:members:
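
To make the registration note above concrete, here is a small, hypothetical sketch of registering a custom trainer. It assumes ``register_trainer`` can be used as a class decorator, mirroring how ``register_orchestrator`` is used in ``trlx/orchestrator/__init__.py`` later in this diff; ``MyCustomTrainer`` is an invented name, not part of this PR.

```python
# Hypothetical sketch (not part of this PR): registering a custom trainer.
# Assumes register_trainer works as a class decorator, analogous to
# register_orchestrator shown further down in this diff.
from trlx.trainer import BaseRLTrainer, register_trainer


@register_trainer
class MyCustomTrainer(BaseRLTrainer):
    # Invented trainer: the real abstract interface lives in trlx.trainer.BaseRLTrainer;
    # whatever methods it requires (e.g. the training loop) would be implemented here.
    pass
```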
5 changes: 3 additions & 2 deletions examples/architext.py
@@ -1,7 +1,8 @@
# Toy example of optimizing textual interior designs to output the least number of rooms
# Also see https://architext.design/

import yaml
import trlx
from trlx.data.configs import TRLConfig


def reward_fn(samples):
@@ -32,7 +33,7 @@ def reward_fn(samples):
def main(hparams={}):
config = TRLConfig.update(default_config, hparams)

model = trlx.train(
trlx.train(
"architext/gptj-162M", reward_fn=reward_fn, prompts=prompts, config=config
)

@@ -9,9 +9,9 @@ train:

pipeline: "PromptPipeline"
orchestrator: "PPOOrchestrator"
trainer: "AcceleratePPOTrainer"

model:
model_type: "AcceleratePPOModel"
model_path: "reshinthadith/codegen_350M_list_manip_5_len"
tokenizer_path: "reshinthadith/codegen_350M_list_manip_5_len"
num_layers_unfrozen: 2
4 changes: 2 additions & 2 deletions examples/experiments/grounded_program_synthesis/train_trlx.py
@@ -59,12 +59,12 @@ def main(hparams={}):
dataset = DSLDataset()
train_prompts = list(dataset.load_datapoints(split="train"))[:1000]

model = trlx.train(
trainer = trlx.train(
reward_fn=reward_fn,
prompts=train_prompts,
config=config,
)
model.save_pretrained("dataset/trained_model")
trainer.save_pretrained("dataset/trained_model")
Collaborator:

Do the model types we use support save_pretrained?

Contributor:

Yes

Collaborator (@Dahoas, Dec 21, 2022):

Wait, I don't think they do, or at least PPO doesn't. The base PPO model is just an nn.Module (not a PreTrainedModel). It's actually quite awkward to save new model architectures in a Hugging Face format; we'll probably have to write a new config.

Contributor:

Ohhhh ok, wait, that's weird then. Can we just add a save_pretrained function to PPO? haha

Collaborator:

But anyway, saving doesn't have anything to do with this PR, so I think it's fine for now.
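
To make the trade-off in this thread concrete, here is a rough, hypothetical sketch of what a hand-rolled save helper for an nn.Module-based PPO model could look like; it is not what the PR implements, and the function and attribute names are invented.

```python
# Hypothetical sketch only: a minimal save helper for an nn.Module-based PPO model.
# This is NOT HuggingFace's save_pretrained; it just writes the state dict plus the
# wrapped base model's HF config (if any) so the weights can be reloaded by hand later.
import os
import torch


def save_pretrained_like(model: torch.nn.Module, save_dir: str) -> None:
    """Save weights and, if present, the wrapped base model's HF config."""
    os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(save_dir, "pytorch_model.bin"))
    base_config = getattr(model, "config", None)  # assumption: wrapper exposes the HF config
    if base_config is not None and hasattr(base_config, "save_pretrained"):
        base_config.save_pretrained(save_dir)
```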



if __name__ == "__main__":
2 changes: 1 addition & 1 deletion examples/ppo_sentiments.py
@@ -45,7 +45,7 @@ def reward_fn(samples: List[str]) -> List[float]:
imdb = load_dataset("imdb", split="train+test")
prompts = [" ".join(review.split()[:4]) for review in imdb["text"]]

model = trlx.train(
return trlx.train(
reward_fn=reward_fn,
prompts=prompts,
eval_prompts=["I don't know much about Hungarian underground"] * 64,
2 changes: 1 addition & 1 deletion examples/randomwalks/configs/ilql_randomwalks.yml
@@ -9,11 +9,11 @@ train:

pipeline: "PromptPipeline"
orchestrator: "OfflineOrchestrator"
trainer: "AccelerateILQLTrainer"

seed: 1000

model:
model_type: "AccelerateILQLModel"
model_path: "CarperAI/randomwalks"
tokenizer_path: "CarperAI/randomwalks"
num_layers_unfrozen: -1
2 changes: 1 addition & 1 deletion examples/randomwalks/configs/ppo_randomwalks.yml
@@ -9,9 +9,9 @@ train:

pipeline: "PromptPipeline"
orchestrator: "PPOOrchestrator"
trainer: "AcceleratePPOTrainer"

model:
model_type: "AcceleratePPOModel"
model_path: "CarperAI/randomwalks"
tokenizer_path: "CarperAI/randomwalks"
num_layers_unfrozen: -1
2 changes: 1 addition & 1 deletion examples/simulacra.py
@@ -26,7 +26,7 @@
)

prompts, ratings = tuple(map(list, zip(*c.fetchall())))
model = trlx.train(
trlx.train(
"gpt2",
dataset=(prompts, ratings),
eval_prompts=["Hatsune Miku, Red Dress"] * 64,
4 changes: 3 additions & 1 deletion tests/test_configs.py
@@ -31,6 +31,8 @@ def test_repo_trl_configs():
assert os.path.isfile(file), f"Config file {file} does not exist."
assert file.endswith(".yml"), f"Config file {file} is not a yaml file."
try:
TRLConfig.load_yaml(file)
config = TRLConfig.load_yaml(file)
assert config.train.entity_name is None, \
f"Unexpected entity name in config file `{file}`. Remove before pushing to repo."
except Exception as e:
assert False, f"Failed to load config file `{file}` with error `{e}`"
2 changes: 1 addition & 1 deletion tests/test_ppo.py
@@ -1,6 +1,6 @@
import unittest
from trlx.data.configs import TRLConfig
from trlx.model.nn.ppo_models import CausalLMHydraWithValueHead
from trlx.trainer.nn.ppo_models import CausalLMHydraWithValueHead
from trlx.utils.modeling import RunningMoments
from transformers import AutoTokenizer
import torch
9 changes: 4 additions & 5 deletions trlx/data/configs.py
@@ -24,9 +24,6 @@ class ModelConfig:
"""
Config for a model.

:param model_type: One of the registered RL models present in trlx.model
:type model_type: str

:param model_path: Path or name of the model (local or on huggingface hub)
:type model_path: str

@@ -38,7 +35,6 @@ class ModelConfig:
:type num_layers_unfrozen: int
"""

model_type: str
model_path: str
tokenizer_path: str
num_layers_unfrozen: int = -1
@@ -117,6 +113,8 @@ class TrainConfig:
:param orchestrator: Orchestrator to use for training. One of the registered orchestrators present in trlx.orchestrator
:type orchestrator: str

:param trainer: Trainer to use for training. One of the registered trainers present in trlx.trainer
:type trainer: str

:param project_name: Project name for wandb
:type project_name: str

@@ -126,7 +124,7 @@
:param checkpoint_dir: Directory to save checkpoints
:type checkpoint_dir: str

:param rollout_logging_dir: Directory to store generated rollouts for use in Algorithm Distillation. Only used by AcceleratePPOModel.
:param rollout_logging_dir: Directory to store generated rollouts for use in Algorithm Distillation. Only used by AcceleratePPOTrainer.
:type rollout_logging_dir: Optional[str]

:param seed: Random seed
@@ -143,6 +141,7 @@

pipeline: str # One of the pipelines in framework.pipeline
orchestrator: str # One of the orchestrators
trainer: str # One of the trainers

project_name: str = "trlx"
entity_name: Optional[str] = None
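
For readers following the docstring changes above, a small hedged sketch of how the new ``trainer`` field surfaces when loading a config; the printed values are only what the config files earlier in this diff suggest, and the path assumes you run from the repository root.

```python
# Sketch based on configs/ppo_config.yml as shown earlier in this diff.
from trlx.data.configs import TRLConfig

config = TRLConfig.load_yaml("configs/ppo_config.yml")
print(config.train.trainer)              # expected: "AcceleratePPOTrainer"
print(config.model.model_path)           # expected: "lvwerra/gpt2-imdb"
print(config.model.num_layers_unfrozen)  # expected: 2
```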
6 changes: 3 additions & 3 deletions trlx/orchestrator/__init__.py
@@ -2,7 +2,7 @@
from abc import abstractmethod
from typing import Dict

from trlx.model import BaseRLModel
from trlx.trainer import BaseRLTrainer
from trlx.pipeline import BasePipeline

# specifies a dictionary of architectures
@@ -33,9 +33,9 @@ def register_class(cls, name):

@register_orchestrator
class Orchestrator:
def __init__(self, pipeline: BasePipeline, rl_model: BaseRLModel):
def __init__(self, pipeline: BasePipeline, trainer: BaseRLTrainer):
self.pipeline = pipeline
self.rl_model = rl_model
self.trainer = trainer

@abstractmethod
def make_experience(self):
20 changes: 11 additions & 9 deletions trlx/orchestrator/offline_orchestrator.py
@@ -10,16 +10,16 @@ class OfflineOrchestrator(Orchestrator):
Orchestrator that creates a static dataset for offline training
"""

def __init__(self, model, split_token=None):
self.model = model
def __init__(self, trainer, split_token=None):
self.trainer = trainer
self.split_token = split_token

def make_experience(self, samples, rewards):
"""
Tokenizes samples and shapes rewards into proper tensors and then inserts the resulting dataset into the model
"""
if self.model.tokenizer:
input_ids = self.model.tokenize(samples)
if self.trainer.tokenizer:
input_ids = self.trainer.tokenize(samples)
else:
input_ids = samples

@@ -31,7 +31,9 @@ def make_experience(self, samples, rewards):
# split samples on (prompts, continuations) on a given substring `split_token`
if self.split_token:
prompt_str_len = s.index(self.split_token) + len(self.split_token)
prompt_tok_len = len(self.model.tokenizer(s[:prompt_str_len]).input_ids)
prompt_tok_len = len(
self.trainer.tokenizer(s[:prompt_str_len]).input_ids
)
# else assume that the prompt is a bos token
else:
prompt_tok_len = 1
@@ -48,9 +50,9 @@
states_ixs.append(s_ixs)
dones.append(terminals)

if self.model.tokenizer:
prompt = self.model.tokenizer.decode(input_ids[0][: states_ixs[0][1]])
response = self.model.tokenizer.decode(input_ids[0][states_ixs[0][1] :])
if self.trainer.tokenizer:
prompt = self.trainer.tokenizer.decode(input_ids[0][: states_ixs[0][1]])
response = self.trainer.tokenizer.decode(input_ids[0][states_ixs[0][1] :])
print("[Sample example]")
print("Prompt: ", prompt)
print("Response: ", response)
@@ -69,6 +71,6 @@ def make_experience(self, samples, rewards):

attention_mask = [torch.ones(x.shape[0], dtype=int) for x in input_ids]

self.model.store = ILQLRolloutStorage(
self.trainer.store = ILQLRolloutStorage(
input_ids, attention_mask, rewards, states_ixs, actions_ixs, dones
)
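
To close out this hunk, a self-contained illustration of the ``split_token`` handling that ``make_experience`` performs above: everything up to and including ``split_token`` is treated as the prompt, and its token length marks where the continuation (and therefore the reward) begins. The sample string and the choice of the ``gpt2`` tokenizer are assumptions for the sake of the example.

```python
# Mirrors the prompt/continuation split in OfflineOrchestrator.make_experience above.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

sample = "Q: Who rules the world? A: cats"
split_token = "A:"

# Same computation as in the hunk: character length of the prompt, then its token length.
prompt_str_len = sample.index(split_token) + len(split_token)
prompt_tok_len = len(tokenizer(sample[:prompt_str_len]).input_ids)

print(sample[:prompt_str_len])  # "Q: Who rules the world? A:"
print(prompt_tok_len)           # number of prompt tokens preceding the rewarded continuation
```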