Merged
Commits
81 commits
765891c
🚀allow GRPO to connect to VLLM in remote/local node with NCCL communi…
binary-husky Mar 16, 2025
42f2131
Update trl/extras/remote_vllm_helper.py
binary-husky Mar 17, 2025
715d486
use argparse for options
kashif Mar 17, 2025
60a6753
add imports for remote vllm helper
kashif Mar 17, 2025
f784a8c
formatting
kashif Mar 17, 2025
5628b60
fix arguments
kashif Mar 17, 2025
8bfc313
use cli options
kashif Mar 17, 2025
d63c94a
vllm serve
qgallouedec Mar 18, 2025
c2e970f
clean server
qgallouedec Mar 18, 2025
e50d288
better naming
qgallouedec Mar 18, 2025
c723685
client
qgallouedec Mar 18, 2025
5d19cf1
style
qgallouedec Mar 18, 2025
b5ff472
new params in generate
qgallouedec Mar 18, 2025
e5fe142
this method is the new default
qgallouedec Mar 18, 2025
73853fc
update config
qgallouedec Mar 18, 2025
1fbdf69
Merge branch 'main' into main
qgallouedec Mar 18, 2025
94625f9
do not use asserts
kashif Mar 18, 2025
9335e68
update config
qgallouedec Mar 18, 2025
06aca0a
separate host and post
qgallouedec Mar 18, 2025
a7af2e2
Merge branch 'main' of https://github.com/binary-husky/trl into pr/bi…
qgallouedec Mar 18, 2025
a92b296
proper deprectation
qgallouedec Mar 18, 2025
714a833
deprecated arg in the vllm server
qgallouedec Mar 18, 2025
71024d6
simplify moving
qgallouedec Mar 18, 2025
bbf99f1
document host and port
qgallouedec Mar 18, 2025
a7e9dea
style
qgallouedec Mar 18, 2025
2b7fb1a
update trainer
qgallouedec Mar 18, 2025
5fee194
new generate args
qgallouedec Mar 18, 2025
508bd90
update doc
qgallouedec Mar 19, 2025
75bd4e3
Fix for zero3
qgallouedec Mar 19, 2025
5a8138c
Better naming
qgallouedec Mar 19, 2025
5f19c70
Remove remote_vllm_helper
qgallouedec Mar 19, 2025
4ae6cb4
remove grpo_with_remote_vllm
qgallouedec Mar 19, 2025
9ca4dde
remove cloudpickle from deps
qgallouedec Mar 19, 2025
5d1398e
Some consistency
qgallouedec Mar 19, 2025
e85c7bb
Merge branch 'main' into main
qgallouedec Mar 19, 2025
44ae792
Update docs/source/grpo_trainer.md
kashif Mar 19, 2025
060c4a6
Update setup.py
kashif Mar 19, 2025
d581a5f
add revision argument to vllm server
kashif Mar 19, 2025
128b503
Update docs/source/grpo_trainer.md
kashif Mar 19, 2025
724e013
Update docs/source/grpo_trainer.md
kashif Mar 19, 2025
daf2cde
Reset the prefix cache after updating weights
kashif Mar 19, 2025
75bfbc4
Merge remote-tracking branch 'refs/remotes/binary-husky/main'
kashif Mar 19, 2025
bb1fb55
Update vllm_client.py
qgallouedec Mar 19, 2025
415f3ca
Update vllm_client.py
qgallouedec Mar 19, 2025
1053197
Update vllm_serve.py
qgallouedec Mar 19, 2025
e6a4901
Add health check endpoint to vLLM server
qgallouedec Mar 19, 2025
e763064
connection timeout
qgallouedec Mar 19, 2025
4554af9
style
qgallouedec Mar 19, 2025
92a154f
fix doc langauge hint
qgallouedec Mar 19, 2025
666a6e4
Merge branch 'main' into main
kashif Mar 20, 2025
6537a7e
move reset_prefix_cache to its own endpoint
kashif Mar 20, 2025
821d37e
Merge branch 'main' of https://github.com/binary-husky/trl into pr/bi…
qgallouedec Mar 20, 2025
c38e79f
async
qgallouedec Mar 20, 2025
92cf3e0
merge peft adaptor to send to vllm
kashif Mar 20, 2025
0ffec9f
Looks simple. Wasn't.
qgallouedec Mar 20, 2025
d9d28db
Peft compatibility
qgallouedec Mar 21, 2025
d452a2f
Update docs/source/speeding_up_training.md
kashif Mar 21, 2025
dd873cf
Update docs/source/speeding_up_training.md
kashif Mar 21, 2025
7e11184
Update trl/extras/vllm_client.py
kashif Mar 21, 2025
6c4bf00
GatheredParameters can be disabled
kashif Mar 21, 2025
c431b0f
gather and ungather peft weights within the same deepseed context
kashif Mar 21, 2025
67c4e68
use is_vllm_available
kashif Mar 21, 2025
15fcaaf
minor consistency fixes
qgallouedec Mar 21, 2025
09ec2a1
fix error when deepspeed is not installed
kashif Mar 21, 2025
bc2f902
fix deepspeed import when not peft
kashif Mar 21, 2025
db8d5fd
Merge branch 'main' of https://github.com/binary-husky/trl into pr/bi…
qgallouedec Mar 21, 2025
89812ee
simpler
kashif Mar 21, 2025
bb66c91
multinode doc
qgallouedec Mar 21, 2025
b23c23f
minor code and comments changes
qgallouedec Mar 21, 2025
5111c8f
Merge branch 'main' of https://github.com/binary-husky/trl into pr/bi…
qgallouedec Mar 21, 2025
657cb21
style
qgallouedec Mar 21, 2025
8670c35
optional deps
qgallouedec Mar 21, 2025
7955a39
vllm_server_timeout as arg
qgallouedec Mar 21, 2025
5a37647
small refinement in doc
qgallouedec Mar 21, 2025
10d26ef
update deps
qgallouedec Mar 21, 2025
d759c9c
Fix VLLMClient argument in grpo_trainer; Add zero3+peft vllm transfer…
binary-husky Mar 21, 2025
4fc8790
Revert "Fix VLLMClient argument in grpo_trainer; Add zero3+peft vllm …
qgallouedec Mar 21, 2025
fb28f62
log num_tokens
qgallouedec Mar 21, 2025
3a211e6
disable vllm test (in the future we'll add a mock for vllm server for…
qgallouedec Mar 21, 2025
7a81655
style
qgallouedec Mar 21, 2025
716a822
fix ds3_gather_for_generation
qgallouedec Mar 21, 2025
29 changes: 19 additions & 10 deletions docs/source/grpo_trainer.md
@@ -68,11 +68,11 @@ At each training step, we sample a batch of prompts and generate a set of \\( G

### Computing the advantage

For each of the \\( G \\) sequences, we compute the reward using a reward model. To align with the comparative nature of reward models—typically trained on datasets of comparisons between outputs for the same question—the advantage is calculated to reflect these relative comparisons. It is normalized as follows:

$$\hat{A}_{i,t} = \frac{r_i - \text{mean}(\mathbf{r})}{\text{std}(\mathbf{r})}$$

This approach gives the method its name: **Group Relative Policy Optimization (GRPO)**.
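
As a minimal sketch (not TRL's internal implementation), the group-relative advantage for a single prompt can be computed like this, assuming `rewards` holds the \\( G \\) scalar rewards of the sampled completions:

```python
import torch

# Hypothetical rewards for G = 4 completions of the same prompt.
rewards = torch.tensor([0.2, 0.7, 0.1, 0.9])

# Normalize each reward by the mean and std of its group; the small epsilon
# guards against division by zero when all rewards in the group are identical.
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-4)
```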

### Estimating the KL divergence

@@ -83,15 +83,15 @@ $$

### Computing the loss

The objective is to maximize the advantage while ensuring that the model remains close to the reference policy. Consequently, the loss is defined as follows:

$$
\mathcal{L}_{\text{GRPO}}(\theta) = -\frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \left[ \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,< t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right],
$$

where the first term represents the scaled advantage and the second term penalizes deviations from the reference policy through KL divergence.
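
The following is a simplified sketch of this loss for one group of completions (illustrative only, not TRL's actual implementation; the KL term uses the unbiased estimator from the previous section, and `beta=0.04` is just a placeholder default):

```python
import torch

def grpo_loss(logps, ref_logps, mask, advantages, beta=0.04):
    """Sketch of the GRPO loss above for one group of G completions.

    logps, ref_logps: (G, T) log-probs of the sampled tokens under the current
    and reference policies; mask: (G, T) with 1 for completion tokens, 0 for padding;
    advantages: (G,) group-relative advantages.
    """
    # pi_theta / [pi_theta]_no_grad equals 1 in value but keeps the gradient.
    ratio = torch.exp(logps - logps.detach())
    # Unbiased estimator of KL(pi_theta || pi_ref).
    per_token_kl = torch.exp(ref_logps - logps) - (ref_logps - logps) - 1
    per_token_loss = -(ratio * advantages.unsqueeze(1) - beta * per_token_kl)
    # Average over the tokens of each sequence, then over the group.
    return ((per_token_loss * mask).sum(dim=1) / mask.sum(dim=1)).mean()
```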

In the original paper, this formulation is generalized to account for multiple updates after each generation (denoted \\( \mu \\), can be set with `num_iterations` in [`GRPOConfig`]) by leveraging the **clipped surrogate objective**:

$$
\mathcal{L}_{\text{GRPO}}(\theta) = - \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \left[ \min \left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})} \hat{A}_{i,t}, \, \text{clip}\left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}, 1 - \epsilon, 1 + \epsilon \right) \hat{A}_{i,t} \right) - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right],
@@ -112,17 +112,26 @@ The GRPO Trainer logs the following metrics:

## Customization

## Speed up training with vLLM-powered generation
### Speed up training with vLLM-powered generation

Generation is often the main bottleneck that makes training slow with online methods. To accelerate generation, you can use [vLLM](https://github.com/vllm-project/vllm), a library that enables fast generation. To enable it, pass `use_vllm=True` in the training arguments.
Generation is often the main bottleneck that makes training slow with online methods. To accelerate generation, you can use [vLLM](https://github.com/vllm-project/vllm), a library that enables fast generation. To enable it, first install the package with

```shell
pip install "vllm==0.7.2"
```

Then, start a vLLM server by running:

```bash
trl vllm-serve --model <model_name>
```

Then, pass `use_vllm=True` in the training arguments and run the training script:

```python
from trl import GRPOConfig

training_args = GRPOConfig(..., use_vllm=True)
```
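
Putting this together, a minimal training script could look like the following sketch (the dataset, model, and reward function below are illustrative placeholders, not prescribed by this guide):

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

# Toy reward: prefer completions close to 20 characters.
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]

training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO", use_vllm=True)
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```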

For more information, see [Speeding up training with vLLM](speeding_up_training#vllm-for-fast-generation-in-online-methods).

### Using a custom reward function

65 changes: 47 additions & 18 deletions docs/source/speeding_up_training.md
@@ -37,39 +37,68 @@ training_args = OnlineDPOConfig(..., use_vllm=True)
</hfoption>
<hfoption id="GRPO">

Then, enable it by passing `use_vllm=True` in the training arguments.
First, start a vLLM server by running:

```bash
trl vllm-serve --model <model_name>
```

Then, run the training script and pass `use_vllm=True` in the training arguments.

```python
from trl import GRPOConfig

training_args = GRPOConfig(..., use_vllm=True)
```

The strategy here is to use a dedicated GPU for generation powered by vLLM, while using the remainder for training.
You can customize the server configuration by passing additional arguments.

```sh
$ trl vllm-serve --help
usage: trl vllm-serve [-h] --model MODEL [--revision REVISION] [--tensor_parallel_size TENSOR_PARALLEL_SIZE] [--host HOST]
[--port PORT] [--gpu_memory_utilization GPU_MEMORY_UTILIZATION] [--dtype DTYPE]
[--max_model_len MAX_MODEL_LEN] [--enable_prefix_caching ENABLE_PREFIX_CACHING]

options:
-h, --help Show this help message and exit
--model MODEL Model name or path to load the model from. (default: None)
--revision REVISION Revision to use for the model. If not specified, the default branch will be used. (default: None)
--tensor_parallel_size TENSOR_PARALLEL_SIZE, --tensor-parallel-size TENSOR_PARALLEL_SIZE
Number of tensor parallel workers to use. (default: 1)
--host HOST Host address to run the server on. (default: 0.0.0.0)
--port PORT Port to run the server on. (default: 8000)
--gpu_memory_utilization GPU_MEMORY_UTILIZATION, --gpu-memory-utilization GPU_MEMORY_UTILIZATION
Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the device
dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus improve the
model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors during
initialization. (default: 0.9)
--dtype DTYPE Data type to use for vLLM generation. If set to 'auto', the data type will be automatically determined based on
the model configuration. Find the supported values in the vLLM documentation. (default: auto)
--max_model_len MAX_MODEL_LEN, --max-model-len MAX_MODEL_LEN
If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced
`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model context
size, which might be much larger than the KV cache, leading to inefficiencies. (default: None)
--enable_prefix_caching ENABLE_PREFIX_CACHING, --enable-prefix-caching ENABLE_PREFIX_CACHING
Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the hardware support this
feature. (default: None)
```
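
If the server runs on a different host or port than the defaults, the trainer can be pointed at it from the training side. The sketch below is an assumption based on the configuration fields introduced in this PR (`vllm_server_host`, `vllm_server_port`, `vllm_server_timeout`); check [`GRPOConfig`] for the exact names in your version:

```python
from trl import GRPOConfig

# Sketch: point the trainer at a vLLM server running on another node.
# The host address is a placeholder; the field names follow this PR's additions.
training_args = GRPOConfig(
    output_dir="Qwen2-0.5B-GRPO",
    use_vllm=True,
    vllm_server_host="192.168.1.42",  # node running `trl vllm-serve`
    vllm_server_port=8000,
    vllm_server_timeout=120.0,        # seconds to wait for the server to become reachable
)
```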

<Tip warning={true}>

When using vLLM, an additional GPU is required exclusively for generation. This means you need at least two available GPUs and must ensure that one remains unused by the trainer. To achieve this, run the training with `--num_processes <NUMBER_OF_GPUs - 1>`.
When using vLLM, ensure that the GPUs assigned for training and generation are separate to avoid resource conflicts. For instance, if you plan to use 4 GPUs for training and another 4 for vLLM generation, you can specify GPU allocation using `CUDA_VISIBLE_DEVICES`.

For example, if you have 4 GPUs, set `--num_processes 3` to allocate three GPUs for training while reserving one for generation.
```bash
accelerate launch --multi_gpu --num_processes 3 train_grpo.py
```
Set GPUs **0-3** for vLLM generation:
```sh
CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model <model_name>
```

![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/1_gpu_for_generation.png)
And GPUs **4-7** for training:
```sh
CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py
```

</Tip>

You can further tune the vLLM configuration by setting a specific `vllm_device` and `vllm_gpu_memory_utilization` in the [`GRPOConfig`].

```python
training_args = GRPOConfig(
...,
use_vllm=True,
vllm_device="cuda:4",
vllm_gpu_memory_utilization=0.7,
)
```

</hfoption>
</hfoptions>
2 changes: 1 addition & 1 deletion setup.py
@@ -92,7 +92,7 @@
"test": ["parameterized", "pytest-cov", "pytest-rerunfailures", "pytest-xdist", "pytest"],
# vllm is not available on Windows
# vllm 0.7.3 causes hanging while gathering. temporary pinning the version until the issue is resolved
"vllm": ["vllm==0.7.2; sys_platform != 'win32'"],
"vllm": ["vllm==0.7.2; sys_platform != 'win32'", "fastapi", "requests"],
"vlm": ["Pillow"],
}
EXTRAS["dev"] = []
7 changes: 7 additions & 0 deletions trl/cli.py
@@ -25,6 +25,8 @@
from .scripts.kto import make_parser as make_kto_parser
from .scripts.sft import make_parser as make_sft_parser
from .scripts.utils import TrlParser
from .scripts.vllm_serve import main as vllm_serve_main
from .scripts.vllm_serve import make_parser as make_vllm_serve_parser


def main():
@@ -40,6 +42,7 @@ def main():
make_grpo_parser(subparsers)
make_kto_parser(subparsers)
make_sft_parser(subparsers)
make_vllm_serve_parser(subparsers)

# Parse the arguments
args = parser.parse_args()
@@ -87,6 +90,10 @@ def main():
args.training_script_args = sys.argv[2:] # remove "trl" and "sft"
launch_command(args) # launch training

elif args.command == "vllm-serve":
(script_args,) = parser.parse_args_and_config()
vllm_serve_main(script_args)


if __name__ == "__main__":
main()