
Commit 06cd30f

Simplify api (#24)
* fix(ilql): sampling on variable sized prompts & stage simplified api
* Save strategy (#23)
  * Had to add py_modules=trlx to setup.
  * Added a save strategy.
  * Cleaned up a few things.
  * Added save_steps to ilql_config.yaml and a save-steps strategy to accelerate_ilql_model.py for consistency. The save_steps parameter must be set now because of how TrainConfig.from_dict operates; if no save_steps parameter is given in the configs, it throws an error (see the sketch below the file stats).
  * Added minimal changes to enable the step-based save strategy in configs/ppo_config.yml, trlx/data/configs.py, and trlx/model_accelerate_ppo_model.py
  * Some problems crept in despite the merge check; this fixes them.
  * Realized I am merging into stage-api, not main, so fixed an issue with ilql_config.yml
* fix(ilql): eval on a set of betas & add simple timers
* fix: saving checkpoints
* refactor(ilql): subsume under base_model
* fix(ilql): mask prompts
* merge hydra
* fix(ppo): generalize and stage for api
* feat: add architext examples
* fix(ppo,ilql): ddp + accelerate
* refactor: clean pipelines
* feat: add simulacra example
* fix(ppo): single token prompting
* refactor: fully merge models
* refactor(configs): lower batch_sizes & remove dead entries
* refactor(examples): update for new api
* fix(tests,style): one way to pass tests is to change them
* fix(ppo): the most recent version of transformers (4.23.1) warns if .generate() starts with a single bos token when bos=eos=pad token
* refactor(readme): add api
* chore: add doc strings
* fix: remove dropout
* chore: keep gpt2 small in examples
* chore: revert to previous default configs
* chore(docs): rename classes, remove unused, add examples
* chore(readme): add contributing.md & deepspeed note
* style(readme): US spelling
* chore(examples): add explanations for each task
1 parent 4ff712b commit 06cd30f

30 files changed: +1088 -1253 lines changed
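
The commit message above introduces a step-based save strategy gated on a required save_steps field, constructed through TrainConfig.from_dict. The following is a minimal, self-contained sketch of that behaviour; the class and field names here are hypothetical stand-ins rather than trlx's actual code, and note that the configs in this commit expose the related knob as checkpoint_interval.

```python
# Hypothetical sketch of the step-based save strategy described in the commit
# message; the class and field names are illustrative, not trlx's actual code.
from dataclasses import dataclass

import torch
from torch import nn


@dataclass
class TrainConfig:
    total_steps: int
    save_steps: int  # required, so omitting it from the config dict raises

    @classmethod
    def from_dict(cls, config: dict) -> "TrainConfig":
        # Strict construction: a config without "save_steps" raises TypeError,
        # matching the behaviour the commit message describes.
        return cls(**config)


def train(model: nn.Module, config: TrainConfig) -> None:
    for step in range(1, config.total_steps + 1):
        ...  # forward / backward / optimizer step would go here
        if step % config.save_steps == 0:
            torch.save(model.state_dict(), f"ckpt_step_{step}.pt")


train(nn.Linear(4, 4), TrainConfig.from_dict({"total_steps": 10, "save_steps": 5}))
```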

README.md

Lines changed: 24 additions & 131 deletions
````diff
@@ -1,152 +1,45 @@
 [docs-image]: https://readthedocs.org/projects/trlX/badge/?version=latest
 [docs-url]: https://trlX.readthedocs.io/en/latest/?badge=latest

-# Welcome to Transformer Reinforcement Learning X (`trlX`)
-> A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF)
+# Transformer Reinforcement Learning X

-[![Docs Status][docs-image]][docs-url]
+`trlx` allows you to fine-tune 🤗 Hugging Face supported language models (`gpt2`, `gpt-j`, `gpt-neo` and `gpt-neox` based) up to 20B parameters using reinforcement learning via either a provided reward function or reward-labeled dataset. Proximal Policy Optimization ([PPO](https://arxiv.org/pdf/1909.08593.pdf)) and Implicit Language Q-Learning ([ILQL](https://sea-snell.github.io/ILQL_site/)) are implemented.

-**[Documentation](https://trlX.readthedocs.io)**
+## Train

-## Overview
-Inspired by the popular `trl` library, the `trlX` repo allows you to fine-tune Huggingface supported language models up to 20B parameters via either reinforcement learning using a provided scoring function or reward-labeled dataset. We aim to support a range of both online and offline RL algorithms including Proximal Policy Optimization (PPO), Natural Language Policy Optimization (NLPO), Actor Critic (A2C), and Implicit Q Learning (ILQL).
-
-The library supports `gpt2` and `gptj` with plans to include `GPT-NeoX`, `T5` and more. PPO and ILQL algorithms are implemented. Disibtributed training has been implemented via HF Accelerate and tested up to two nodes, each with 8 gpus.
-
-## Structure
-
-The training pipeline is broken into four pieces:
+```python
+import trlx

-- Prompt pipeline: Handles loading of prompts/text used to prompt model for exploration in online methods
-- Rollout pipeline: Handles loading and storage of reward labeled data used
-- Orchestrator: Handles exploration/rollout collection of online methods. Pushes collected rollouts to the rollout pipeline.
-- Model: Wraps the supplied base model (ex: `gpt2`) and implements the desired training method loss (ex: PPO).
+# optimize some reward function
+model = trlx.train('gpt2', reward_fn=lambda samples: [sample.count('cats') for sample in samples])

-Adding a task for RLHF training depends on the desired training method and pre-existing data. If we are online and have no reward labeled data this is as simple as writing a new prompt pipeline, which supplies prompts for exploration, and a new reward function to be passed into the `PPOOrchestrator` class.
+# or steer a model with a collection of rated samples
+model = trlx.train('EleutherAI/gpt-j-6B', dataset=[('dolphins', 'geese'), (1.0, 100.0)])

-## Installation
-```bash
-git clone https://github.com/CarperAI/trlx.git
-cd trlx
-pip install -e ".[dev]"
-pre-commit install # see .pre-commit-config.yaml
+# model is a wrapper with some logit preprocessing
+model.generate(**tokenizer('Q: Who rules the world? A:', return_tensors='pt'), do_sample=True)
 ```

-## Example: How to add a task
-
-In the below we implement a sentiment learning task.
-
-### Configure `accelerate`
+Launch distributed training with 🤗 Accelerate (only DeepSpeed integration is tested)

 ```bash
 accelerate config
+accelerate launch examples/simulacra.py
 ```

-### Implement a prompt pipeline
-
-```python
-@register_datapipeline
-class PPOPipeline(BasePipeline):
-    def __init__(self, tokenizer, config, prompt_dataset_path=None):
-        super().__init__()
-
-        ds = load_dataset("imdb", split="test")
-        ds = ds.rename_columns({"text": "review", "label": "sentiment"})
-        ds = ds.filter(lambda x: len(x["review"]) < 500, batched=False)
-
-        self.tokens = [
-            tokenizer(
-                text,
-                truncation=True,
-                padding="max_length",
-                max_length=config.train.input_size,
-                return_tensors="pt",
-            )["input_ids"]
-            .long()
-            .flatten()
-            for text in ds["review"]
-        ]
-        self.text = [tokenizer.decode(tokens.tolist()) for tokens in self.tokens]
-
-    def __getitem__(self, index: int) -> PromptElement:
-        return PromptElement(self.text[index], self.tokens[index])
-
-    def __len__(self) -> int:
-        return len(self.text)
-
-    def create_loader(
-        self,
-        batch_size: int,
-        shuffle: bool,
-        prep_fn: Callable = None,
-        num_workers: int = 0,
-    ) -> DataLoader:
-        # TODO(dahoas): Decide how to support varying sizes of prompts without having to tokenize on fly
-        def collate_fn(elems: Iterable[PromptElement]) -> PromptElement:
-            return PromptBatch(
-                [elem.text for elem in elems],
-                torch.stack(
-                    [elem.tokens for elem in elems]
-                ),  # Assumes token tensors all same size
-            )
+For more usage see [examples](./examples)

-        return DataLoader(
-            self, batch_size, shuffle, collate_fn=collate_fn, num_workers=num_workers
-        )
-```
-
-### Launch training
-
-```python
-from typing import List
-
-import torch
-from transformers import pipeline
-
-import wandb
-from trlx.data.configs import TRLConfig
-from trlx.model.accelerate_ppo_model import AcceleratePPOModel
-from trlx.orchestrator.ppo_orchestrator import PPOOrchestrator
-from trlx.pipeline.ppo_pipeline import PPOPipeline
-from trlx.utils.loading import get_model, get_orchestrator, get_pipeline
-
-if __name__ == "__main__":
-    cfg = TRLConfig.load_yaml("configs/ppo_config.yml")
-
-    sentiment_pipe = pipeline(
-        "sentiment-analysis", "lvwerra/distilbert-imdb", device=-1
-    )
-
-    def reward_fn(samples: List[str]):
-        sent_kwargs = {
-            "return_all_scores": True,
-            "function_to_apply": None,
-            "batch_size": cfg.method.chunk_size,
-        }
-        pipe_outputs = sentiment_pipe(samples, **sent_kwargs)
-        scores = torch.tensor([output[1]["score"] for output in pipe_outputs])
-        return scores
-
-    model: AcceleratePPOModel = get_model(cfg.model.model_type)(cfg)
-    if model.accelerator.is_main_process:
-        wandb.watch(model.model)
-
-    pipeline: PPOPipeline = get_pipeline(cfg.train.pipeline)(model.tokenizer, cfg)
-    orch: PPOOrchestrator = get_orchestrator(cfg.train.orchestrator)(
-        model, pipeline, reward_fn=reward_fn, chunk_size=cfg.method.chunk_size
-    )
-    orch.make_experience(cfg.method.num_rollouts)
-    model.learn()
-
-    print("DONE!")
+## Install
+```bash
+git clone https://github.com/CarperAI/trlx.git
+cd trlx
+pip install torch --extra-index-url https://download.pytorch.org/whl/cu113 # for cuda
+pip install -e .
 ```

-And run `accelerate launch my_script.py`
-
-## References
+For development check out these [guidelines](./CONTRIBUTING.md)
+and also read our [docs](https://trlX.readthedocs.io)

-### Proximal Policy Optimisation
-The PPO implementation largely follows the structure introduced in the paper **"Fine-Tuning Language Models from Human Preferences"** by D. Ziegler et al. \[[paper](https://arxiv.org/pdf/1909.08593.pdf), [code](https://github.com/openai/lm-human-preferences)].
+## Acknowledgements

-### Language models
-The language models utilize the `transformers` library by 🤗 Hugging Face.
+Thanks Leandro for starting the original [trl](https://github.com/lvwerra/trl/)
````
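
As a sketch only (not part of this diff), the sentiment task from the removed "How to add a task" walkthrough could roughly be re-expressed with the new one-call API. The reward model and scoring logic are taken from the deleted example above, the trlx.train call from the new README, and the choice of "lvwerra/gpt2-imdb" as the base model from configs/ppo_config.yml; everything else is assumption.

```python
# Sketch: the removed sentiment example re-expressed with the new API.
from typing import List

from transformers import pipeline

import trlx

sentiment_pipe = pipeline("sentiment-analysis", "lvwerra/distilbert-imdb", device=-1)


def reward_fn(samples: List[str]) -> List[float]:
    # output[1]["score"] is the positive-class score, as in the deleted example.
    outputs = sentiment_pipe(samples, return_all_scores=True, function_to_apply=None)
    return [output[1]["score"] for output in outputs]


model = trlx.train("lvwerra/gpt2-imdb", reward_fn=reward_fn)
```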

configs/ilql_config.yml

Lines changed: 17 additions & 23 deletions
```diff
@@ -1,42 +1,36 @@
 model:
   model_path : "gpt2"
-  model_type : "ILQLModel"
-  device : "cuda"
   tokenizer_path: "gpt2"
+  model_type : "ILQLModel"
   num_layers_unfrozen: -1

 train:
-  n_ctx : 512
-  epochs : 1
-  total_steps : 80000
-  batch_size : 80
-  grad_clip : 1.0
+  seq_length: 64
+  batch_size: 128
+  epochs: 10
+  total_steps: 10000

-  lr_ramp_steps : 100
-  lr_decay_steps : 3366
-  weight_decay : 1.0e-6
-  learning_rate_init : 1.0e-3
-  learning_rate_target : 1.0e-3
+  lr_ramp_steps: 100
+  lr_decay_steps: 3366
+  weight_decay: 1e-6
+  learning_rate_init: 1e-4
+  learning_rate_target: 1e-4
+  opt_betas: [0.9, 0.95]

-  log_interval : 25
-  checkpoint_interval : 100
-  eval_interval : 50
-
-  input_size: 1
-  gen_size: 32
+  checkpoint_interval: 1000
+  eval_interval: 16

   pipeline : "OfflinePipeline"
   orchestrator : "OfflineOrchestrator"
-
-  accelerate : true
+  seed: 1000

 method:
   name: "ilqlconfig"
   tau: 0.7
   gamma: 0.99
   cql_scale: 0.1
   awac_scale: 1
-  alpha: 1
-  steps_for_target_q_sync: 10
-  beta: 4
+  alpha: 0.005
+  steps_for_target_q_sync: 1
+  betas: [16]
   two_qs: true
```
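
The example removed from the README above loads a config with TRLConfig.load_yaml and reads values through cfg.train.* and cfg.method.*. Assuming the same pattern applies to this file (which the diff does not show directly), the new ILQL defaults can be inspected like this:

```python
# Sketch: inspecting the updated ILQL defaults. TRLConfig.load_yaml and the
# cfg.train / cfg.method access pattern come from the example removed from the
# README above; applying them to ilql_config.yml is an assumption.
from trlx.data.configs import TRLConfig

cfg = TRLConfig.load_yaml("configs/ilql_config.yml")

print(cfg.train.batch_size)                # 128 (was 80)
print(cfg.method.alpha)                    # 0.005 (was 1)
print(cfg.method.steps_for_target_q_sync)  # 1 (was 10)
print(cfg.method.betas)                    # [16], replacing the scalar `beta: 4`
```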

configs/ppo_config.yml

Lines changed: 36 additions & 44 deletions
```diff
@@ -1,52 +1,44 @@
 model:
-  model_path : "lvwerra/gpt2-imdb" # Name of hf model to load
-  tokenizer_path : "gpt2" # Name of hf tokenizer to load
-  model_type : "AcceleratePPOModel" # Name of accelerate model type to load
-  device : "cuda" # Train device
-  num_layers_unfrozen : 2 # Number of bottom layers to freeze during training
+  model_path: "lvwerra/gpt2-imdb" # Name of hf model to load
+  tokenizer_path: "gpt2" # Name of hf tokenizer to load
+  model_type: "AcceleratePPOModel" # Name of accelerate model type to load
+  num_layers_unfrozen: 2 # Number of bottom layers to freeze during training

 train:
-  n_ctx : 512 # Size of LM context
-  epochs : 10 # Train for max(epochs, total_steps)
-  total_steps : 80000 # Train for max(epochs, total_steps)
-  batch_size : 128 # batch size
-  grad_clip : 1.0 # gradient clipping threshold
+  seq_length: 48 # Size of LM context
+  epochs: 1000 # Train for max(epochs, total_steps)
+  total_steps: 10000 # Train for max(epochs, total_steps)
+  batch_size: 128 # batch size

-  lr_ramp_steps : 100 # learning rate warm up
-  lr_decay_steps : 79000 # learning rate decay
-  weight_decay : 1.0e-6 # weight decay param
-  learning_rate_init : 1.412e-4 # init learning rate
-  learning_rate_target : 1.412e-4 # target final learning rate
+  lr_ramp_steps: 100 # learning rate warm up
+  lr_decay_steps: 79000 # learning rate decay
+  weight_decay: 1.0e-6 # weight decay param
+  learning_rate_init: 1.412e-4 # init learning rate
+  learning_rate_target: 1.412e-4 # target final learning rate
+  opt_betas: [0.9, 0.95] # adam betas

-  log_interval : 25 # log interval
-  checkpoint_interval : 1000000 # checkpoint interval
-  eval_interval : 16 # eval interval
+  checkpoint_interval: 10000 # checkpoint interval
+  eval_interval: 16 # eval interval

-  pipeline : "PPOPipeline" # prompt pipeline to load
-  orchestrator : "PPOOrchestrator" # orchestrator to load
-
-  input_size : 4 # max input size
-  gen_size : 48 # max gen size
-
-  accelerate : True # Use accelerate
-  accelerate_config_path : "" # Path to accelerate config(for logging purposes)
+  pipeline: "PPOPipeline" # prompt pipeline to load
+  orchestrator: "PPOOrchestrator" # orchestrator to load

 method:
-  name : 'ppoconfig' # Name of RL method config
-  num_rollouts : 128 # Number of rollouts to collect per epoch
-  chunk_size : 128 # Number of rollouts to collect in one loop of orchestrator
-  ppo_epochs : 4 # Number of ppo epochs
-  init_kl_coef : 0.2 # init kl coefficient
-  target : 6 # target kl coefficient, set None for fixed kl coef
-  horizon : 10000 # PPO horizon
-  gamma : 1 # PPO discount
-  lam : 0.95 # PPO lambda
-  cliprange : 0.2 # clip range
-  cliprange_value : 0.2 # clip range
-  vf_coef : 2.3 # value term weight
-  gen_kwargs :
-    max_length : 48 # LM max sample gen length
-    min_length : 48 # LM min sample gen length
-    top_k : 0.0 # top k
-    top_p : 1.0 # top p
-    do_sample : True # sample
+  name: 'ppoconfig' # Name of RL method config
+  num_rollouts: 128 # Number of rollouts to collect per epoch
+  chunk_size: 128 # Number of rollouts to collect in one loop of orchestrator
+  ppo_epochs: 4 # Number of ppo epochs
+  init_kl_coef: 0.2 # init kl coefficient
+  target: 6 # target kl coefficient, set None for fixed kl coef
+  horizon: 10000 # PPO horizon
+  gamma: 1 # PPO discount
+  lam: 0.95 # PPO lambda
+  cliprange: 0.2 # clip range
+  cliprange_value: 0.2 # clip range
+  vf_coef: 2.3 # value term weight
+  gen_kwargs:
+    max_length: 48 # LM max sample gen length
+    min_length: 48 # LM min sample gen length
+    top_k: 0.0 # top k
+    top_p: 1.0 # top p
+    do_sample: True # sample
```
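
The gen_kwargs block holds ordinary Hugging Face generation parameters. As a sketch of what they amount to outside trlx (whether trlx forwards them verbatim is an assumption), they can be passed straight to transformers' generate(), mirroring the model.generate(...) call shown in the new README:

```python
# Sketch: the gen_kwargs above as a plain Hugging Face generate() call.
# Whether trlx forwards them exactly like this is an assumption.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("lvwerra/gpt2-imdb")

gen_kwargs = dict(
    max_length=48,   # LM max sample gen length
    min_length=48,   # LM min sample gen length
    top_k=0,         # the config's 0.0 as an int; 0 disables top-k filtering
    top_p=1.0,       # top p
    do_sample=True,  # sample
)

inputs = tokenizer("This movie was", return_tensors="pt")
samples = model.generate(**inputs, **gen_kwargs)
print(tokenizer.decode(samples[0], skip_special_tokens=True))
```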

0 commit comments
