[docs-image]: https://readthedocs.org/projects/trlX/badge/?version=latest
[docs-url]: https://trlX.readthedocs.io/en/latest/?badge=latest

-# Welcome to Transformer Reinforcement Learning X (`trlX`)
-> A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF)
+# Transformer Reinforcement Learning X

-[![Docs Status][docs-image]][docs-url]
+`trlx` allows you to fine-tune 🤗 Hugging Face supported language models (`gpt2`, `gpt-j`, `gpt-neo` and `gpt-neox` based) up to 20B parameters using reinforcement learning, via either a provided reward function or a reward-labeled dataset. Proximal Policy Optimization ([PPO](https://arxiv.org/pdf/1909.08593.pdf)) and Implicit Language Q-Learning ([ILQL](https://sea-snell.github.io/ILQL_site/)) are implemented.

-**[Documentation](https://trlX.readthedocs.io)**
+## Train

-## Overview
-Inspired by the popular `trl` library, the `trlX` repo allows you to fine-tune Hugging Face supported language models up to 20B parameters via reinforcement learning, using either a provided scoring function or a reward-labeled dataset. We aim to support a range of both online and offline RL algorithms, including Proximal Policy Optimization (PPO), Natural Language Policy Optimization (NLPO), Advantage Actor-Critic (A2C), and Implicit Language Q-Learning (ILQL).
-
-The library supports `gpt2` and `gptj`, with plans to include `GPT-NeoX`, `T5` and more. The PPO and ILQL algorithms are implemented. Distributed training is handled via HF Accelerate and has been tested on up to two nodes, each with 8 GPUs.
-
-## Structure
-
-The training pipeline is broken into four pieces:
+```python
+import trlx

-- Prompt pipeline: Handles loading of the prompts/text used to prompt the model for exploration in online methods
-- Rollout pipeline: Handles loading and storage of reward-labeled data
-- Orchestrator: Handles exploration/rollout collection for online methods. Pushes collected rollouts to the rollout pipeline.
-- Model: Wraps the supplied base model (e.g. `gpt2`) and implements the desired training method loss (e.g. PPO).
+# optimize some reward function
+model = trlx.train('gpt2', reward_fn=lambda samples: [sample.count('cats') for sample in samples])

-Adding a task for RLHF training depends on the desired training method and pre-existing data. If we are training online and have no reward-labeled data, this is as simple as writing a new prompt pipeline, which supplies prompts for exploration, and a new reward function to be passed into the `PPOOrchestrator` class.
+# or steer a model with a collection of rated samples
+model = trlx.train('EleutherAI/gpt-j-6B', dataset=[('dolphins', 'geese'), (1.0, 100.0)])

-## Installation
-```bash
-git clone https://github.com/CarperAI/trlx.git
-cd trlx
-pip install -e ".[dev]"
-pre-commit install # see .pre-commit-config.yaml
+# model is a wrapper with some logit preprocessing
+model.generate(**tokenizer('Q: Who rules the world? A:', return_tensors='pt'), do_sample=True)
```
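
Two details in the new snippet are implicit: `dataset` is passed as two parallel collections (the samples first, then their rewards, so here 'geese' is the highly rated sample), and `tokenizer` is assumed to already be in scope. Below is a rough sketch of both, assuming the two-collection `dataset` format shown above and a 🤗 `transformers` tokenizer matching the base model; the rated samples are made up for illustration.

```python
from transformers import AutoTokenizer

import trlx

# made-up rated samples as (text, score) pairs
rated = [('dolphins', 1.0), ('geese', 100.0), ('kittens', 10.0)]

# reshape into the [samples, rewards] structure used in the snippet above
samples, rewards = zip(*rated)
model = trlx.train('gpt2', dataset=[list(samples), list(rewards)])

# load a tokenizer matching the base model name passed to trlx.train,
# so the `model.generate(**tokenizer(...))` call above has a `tokenizer` to use
tokenizer = AutoTokenizer.from_pretrained('gpt2')
model.generate(**tokenizer('Q: Who rules the world? A:', return_tensors='pt'), do_sample=True)
```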

-## Example: How to add a task
-
-In the example below, we implement a sentiment learning task.
-
-### Configure `accelerate`
+Launch distributed training with 🤗 Accelerate (only the DeepSpeed integration is tested)

```bash
accelerate config
+accelerate launch examples/simulacra.py
```
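
`examples/simulacra.py` ships with the repo; any script that calls `trlx.train` can be launched the same way. Below is a minimal sketch under that assumption (the filename `my_example.py` is hypothetical, and the reward is just the toy 'cats' counter from the Train section):

```python
# my_example.py -- hypothetical minimal script to pass to `accelerate launch`
import trlx


def reward_fn(samples):
    # toy reward from the Train section: count occurrences of 'cats' in each sample
    return [sample.count('cats') for sample in samples]


if __name__ == '__main__':
    trlx.train('gpt2', reward_fn=reward_fn)
```

After `accelerate config`, it would be launched with `accelerate launch my_example.py`, mirroring the commands above.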

-### Implement a prompt pipeline
-
-```python
-@register_datapipeline
-class PPOPipeline(BasePipeline):
-    def __init__(self, tokenizer, config, prompt_dataset_path=None):
-        super().__init__()
-
-        ds = load_dataset("imdb", split="test")
-        ds = ds.rename_columns({"text": "review", "label": "sentiment"})
-        ds = ds.filter(lambda x: len(x["review"]) < 500, batched=False)
-
-        self.tokens = [
-            tokenizer(
-                text,
-                truncation=True,
-                padding="max_length",
-                max_length=config.train.input_size,
-                return_tensors="pt",
-            )["input_ids"]
-            .long()
-            .flatten()
-            for text in ds["review"]
-        ]
-        self.text = [tokenizer.decode(tokens.tolist()) for tokens in self.tokens]
-
-    def __getitem__(self, index: int) -> PromptElement:
-        return PromptElement(self.text[index], self.tokens[index])
-
-    def __len__(self) -> int:
-        return len(self.text)
-
-    def create_loader(
-        self,
-        batch_size: int,
-        shuffle: bool,
-        prep_fn: Callable = None,
-        num_workers: int = 0,
-    ) -> DataLoader:
-        # TODO(dahoas): Decide how to support varying sizes of prompts without having to tokenize on fly
-        def collate_fn(elems: Iterable[PromptElement]) -> PromptElement:
-            return PromptBatch(
-                [elem.text for elem in elems],
-                torch.stack(
-                    [elem.tokens for elem in elems]
-                ),  # Assumes token tensors all same size
-            )
+For more usage see [examples](./examples)

-        return DataLoader(
-            self, batch_size, shuffle, collate_fn=collate_fn, num_workers=num_workers
-        )
-```
-
-### Launch training
-
-```python
-from typing import List
-
-import torch
-from transformers import pipeline
-
-import wandb
-from trlx.data.configs import TRLConfig
-from trlx.model.accelerate_ppo_model import AcceleratePPOModel
-from trlx.orchestrator.ppo_orchestrator import PPOOrchestrator
-from trlx.pipeline.ppo_pipeline import PPOPipeline
-from trlx.utils.loading import get_model, get_orchestrator, get_pipeline
-
-if __name__ == "__main__":
-    cfg = TRLConfig.load_yaml("configs/ppo_config.yml")
-
-    sentiment_pipe = pipeline(
-        "sentiment-analysis", "lvwerra/distilbert-imdb", device=-1
-    )
-
-    def reward_fn(samples: List[str]):
-        sent_kwargs = {
-            "return_all_scores": True,
-            "function_to_apply": None,
-            "batch_size": cfg.method.chunk_size,
-        }
-        pipe_outputs = sentiment_pipe(samples, **sent_kwargs)
-        scores = torch.tensor([output[1]["score"] for output in pipe_outputs])
-        return scores
-
-    model: AcceleratePPOModel = get_model(cfg.model.model_type)(cfg)
-    if model.accelerator.is_main_process:
-        wandb.watch(model.model)
-
-    pipeline: PPOPipeline = get_pipeline(cfg.train.pipeline)(model.tokenizer, cfg)
-    orch: PPOOrchestrator = get_orchestrator(cfg.train.orchestrator)(
-        model, pipeline, reward_fn=reward_fn, chunk_size=cfg.method.chunk_size
-    )
-    orch.make_experience(cfg.method.num_rollouts)
-    model.learn()
-
-    print("DONE!")
+## Install
+```bash
+git clone https://github.com/CarperAI/trlx.git
+cd trlx
+pip install torch --extra-index-url https://download.pytorch.org/whl/cu113 # for cuda
+pip install -e .
```
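
As a quick sanity check of the editable install (not part of the official steps above), confirm the package resolves to the cloned checkout from Python:

```python
# confirm the editable install resolves to the cloned checkout
import trlx
print(trlx.__file__)
```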

-And run `accelerate launch my_script.py`
-
-## References
+For development, check out these [guidelines](./CONTRIBUTING.md)
+and also read our [docs](https://trlX.readthedocs.io)

-### Proximal Policy Optimisation
-The PPO implementation largely follows the structure introduced in the paper **"Fine-Tuning Language Models from Human Preferences"** by D. Ziegler et al. \[[paper](https://arxiv.org/pdf/1909.08593.pdf), [code](https://github.com/openai/lm-human-preferences)].
+## Acknowledgements

-### Language models
-The language models utilize the `transformers` library by 🤗 Hugging Face.
+Thanks to Leandro for starting the original [trl](https://github.com/lvwerra/trl/)