Commit a216a3c

chore(readme): update instructions (#93)
1 parent 59b3c65 commit a216a3c

2 files changed: +17 -37 lines changed

Makefile

Lines changed: 0 additions & 14 deletions
This file was deleted.

README.md

Lines changed: 17 additions & 23 deletions
@@ -3,53 +3,47 @@
 
 # Transformer Reinforcement Learning X
 
-`trlx` allows you to fine-tune 🤗 Hugging Face supported language models (`gpt2`, `gpt-j`, `gpt-neo` and `gpt-neox` based) up to 20B parameters using reinforcement learning via either a provided reward function or reward-labeled dataset. Proximal Policy Optimization ([PPO](https://arxiv.org/pdf/1909.08593.pdf)) and Implicit Language Q-Learning ([ILQL](https://sea-snell.github.io/ILQL_site/)) are implemented.
+TRLX allows you to fine-tune 🤗 Hugging Face supported language models (`gpt2`, `gpt-j`, `gpt-neo` and `gpt-neox` based) up to 20B parameters using reinforcement learning via either a provided reward function or reward-labeled dataset. Proximal Policy Optimization ([PPO](https://arxiv.org/pdf/1909.08593.pdf)) and Implicit Language Q-Learning ([ILQL](https://sea-snell.github.io/ILQL_site/)) are implemented.
 
-You can read more about trlX in our [documentation](https://trlX.readthedocs.io).
+You can read more about TRLX in our [documentation](https://trlX.readthedocs.io).
 
 ## Installation
-### From Source
 ```bash
 git clone https://github.com/CarperAI/trlx.git
 cd trlx
-pip install torch --extra-index-url https://download.pytorch.org/whl/cu113 # for cuda
+pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 # for cuda
 pip install -e .
 ```
 
 ## How to Train
-You can train your model using a reward function or a reward-labeled dataset.
+You can train a model using a reward function or a reward-labeled dataset.
 
-### Using a reward function
+#### Using a reward function
 ```python
-import trlx
-
-# optimize some reward function
 model = trlx.train('gpt2', reward_fn=lambda samples: [sample.count('cats') for sample in samples])
-
-# model is a wrapper with some logit preprocessing
-model.generate(**tokenizer('Q: Who rules the world? A:', return_tensors='pt'), do_sample=True)
 ```
-
-### Using a reward-labeled dataset
-
+#### Using a reward-labeled dataset
 ```python
-import trlx
-
-# Steer a model with a collection of rated samples
 model = trlx.train('EleutherAI/gpt-j-6B', dataset=[('dolphins', 'geese'), (1.0, 100.0)])
+```
 
-# model is a wrapper with some logit preprocessing
+#### Trained model is a wrapper over a given autoregressive model
+```python
 model.generate(**tokenizer('Q: Who rules the world? A:', return_tensors='pt'), do_sample=True)
 ```
 
-### Using 🤗 Accelerate to speed up the training
-Launch distributed training with 🤗 Accelerate (only DeepSpeed integration is tested)
+#### Use 🤗 Accelerate to launch distributed training
 
 ```bash
-accelerate config
+accelerate config # choose DeepSpeed option
 accelerate launch examples/simulacra.py
 ```
 
+#### Use Ray Tune to launch hyperparameter sweep
+```bash
+python train_sweep.py --config configs/ray_tune_configs/ppo_config.yml --example-name ppo_sentiments
+```
+
 For more usage see [examples](./examples)
 
 ## Contributing
@@ -59,4 +53,4 @@ and also read our [docs](https://trlX.readthedocs.io)
 
 ## Acknowledgements
 
-Thanks Leandro for starting the original [trl](https://github.com/lvwerra/trl/)
+Many thanks to Leandro von Werra for hacking on the [trl](https://github.com/lvwerra/trl/)
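For reference, the reward-function snippet in the updated README no longer carries its `import trlx` line or tokenizer setup (both are removed in this diff). A minimal self-contained sketch of that path, assembled from the snippets above, could look like the following; the `AutoTokenizer` setup is an illustrative assumption rather than something the README or trlx prescribes.

```python
# Minimal sketch of the reward-function workflow described in the README diff above.
# Assumes trlx has been installed per the Installation section; the tokenizer setup
# is an illustrative assumption, not part of the trlx API as documented here.
from transformers import AutoTokenizer

import trlx

# Toy reward from the README: score each sampled string by how often it mentions 'cats'.
model = trlx.train(
    'gpt2',
    reward_fn=lambda samples: [sample.count('cats') for sample in samples],
)

# The trained model wraps the underlying autoregressive model, so generation is
# invoked exactly as the README shows it.
tokenizer = AutoTokenizer.from_pretrained('gpt2')
model.generate(**tokenizer('Q: Who rules the world? A:', return_tensors='pt'), do_sample=True)
```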

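Likewise, a sketch of the offline, reward-labeled (ILQL) path: the dataset literal is copied from the README example above, and reading its two entries as samples and reward labels is an interpretation made here for clarity, not a documented signature.

```python
# Sketch of the reward-labeled (ILQL) workflow from the README diff above. The dataset
# value mirrors the README example; naming its entries as samples and reward labels is
# an assumption made here for readability.
import trlx

rated_samples = ('dolphins', 'geese')  # candidate texts to steer the model with
reward_labels = (1.0, 100.0)           # 'geese' is rated far higher than 'dolphins'

model = trlx.train('EleutherAI/gpt-j-6B', dataset=[rated_samples, reward_labels])
```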