
Commit c9e8dda

binary-husky, kashif, qgallouedec, and lewtun authored
🚀 Scaling GRPO to 70B+ Models and Multi-Node Training with vLLM Server & NCCL Communication (huggingface#3094)
* 🚀 allow GRPO to connect to vLLM in a remote/local node with NCCL communication
* Update trl/extras/remote_vllm_helper.py (Co-authored-by: Kashif Rasul <[email protected]>)
* use argparse for options
* add imports for remote vllm helper
* formatting
* fix arguments
* use cli options
* vllm serve
* clean server
* better naming
* client
* style
* new params in generate
* this method is the new default
* update config
* do not use asserts
* update config
* separate host and port
* proper deprecation
* deprecated arg in the vllm server
* simplify moving
* document host and port
* style
* update trainer
* new generate args
* update doc
* Fix for zero3
* Better naming
* Remove remote_vllm_helper
* remove grpo_with_remote_vllm
* remove cloudpickle from deps
* Some consistency
* Update docs/source/grpo_trainer.md (Co-authored-by: lewtun <[email protected]>)
* Update setup.py (Co-authored-by: lewtun <[email protected]>)
* add revision argument to vllm server
* Update docs/source/grpo_trainer.md (Co-authored-by: lewtun <[email protected]>)
* Update docs/source/grpo_trainer.md (Co-authored-by: lewtun <[email protected]>)
* Reset the prefix cache after updating weights
* Update vllm_client.py
* Update vllm_client.py
* Update vllm_serve.py
* Add health check endpoint to vLLM server
* connection timeout
* style
* fix doc language hint
* move reset_prefix_cache to its own endpoint
* async
* merge peft adapter to send to vllm
* Looks simple. Wasn't.
* Peft compatibility
* Update docs/source/speeding_up_training.md (Co-authored-by: lewtun <[email protected]>)
* Update docs/source/speeding_up_training.md (Co-authored-by: lewtun <[email protected]>)
* Update trl/extras/vllm_client.py (Co-authored-by: lewtun <[email protected]>)
* GatheredParameters can be disabled
* gather and ungather peft weights within the same deepspeed context
* use is_vllm_available
* minor consistency fixes
* fix error when deepspeed is not installed
* fix deepspeed import when not peft
* simpler
* multinode doc
* minor code and comments changes
* style
* optional deps
* vllm_server_timeout as arg
* small refinement in doc
* update deps
* Fix VLLMClient argument in grpo_trainer; Add zero3+peft vllm transfer solution
* Revert "Fix VLLMClient argument in grpo_trainer; Add zero3+peft vllm transfer solution" (reverts commit d759c9c)
* log num_tokens
* disable vllm test (in the future we'll add a mock for vllm server for them)
* style
* fix ds3_gather_for_generation

---------

Co-authored-by: Kashif Rasul <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: lewtun <[email protected]>
1 parent efa0114 commit c9e8dda

File tree: 10 files changed (+1067 −233 lines)


‎docs/source/grpo_trainer.md‎

Lines changed: 97 additions & 11 deletions
@@ -68,11 +68,11 @@ At each training step, we sample a batch of prompts and generate a set of \\( G
 
 ### Computing the advantage
 
-For each of the \\( G \\) sequences, we compute the reward using a reward model. To align with the comparative nature of reward models—typically trained on datasets of comparisons between outputs for the same question—the advantage is calculated to reflect these relative comparisons. It is normalized as follows:
+For each of the \\( G \\) sequences, we compute the reward using a reward model. To align with the comparative nature of reward models—typically trained on datasets of comparisons between outputs for the same question—the advantage is calculated to reflect these relative comparisons. It is normalized as follows:
 
-$$\hat{A}_{i,t} = \frac{r_i - \text{mean}(\mathbf{r})}{\text{std}(\mathbf{r})}$$
+$$\hat{A}_{i,t} = \frac{r_i - \text{mean}(\mathbf{r})}{\text{std}(\mathbf{r})}$$
 
-This approach gives the method its name: **Group Relative Policy Optimization (GRPO)**.
+This approach gives the method its name: **Group Relative Policy Optimization (GRPO)**.
 
 ### Estimating the KL divergence
 
@@ -83,15 +83,15 @@ $$
 
 ### Computing the loss
 
-The objective is to maximize the advantage while ensuring that the model remains close to the reference policy. Consequently, the loss is defined as follows:
+The objective is to maximize the advantage while ensuring that the model remains close to the reference policy. Consequently, the loss is defined as follows:
 
 $$
 \mathcal{L}_{\text{GRPO}}(\theta) = -\frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \left[ \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,< t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right],
 $$
 
-where the first term represents the scaled advantage and the second term penalizes deviations from the reference policy through KL divergence.
+where the first term represents the scaled advantage and the second term penalizes deviations from the reference policy through KL divergence.
 
-In the original paper, this formulation is generalized to account for multiple updates after each generation (denoted \\( \mu \\), can be set with `num_iterations` in [`GRPOConfig`]) by leveraging the **clipped surrogate objective**:
+In the original paper, this formulation is generalized to account for multiple updates after each generation (denoted \\( \mu \\), can be set with `num_iterations` in [`GRPOConfig`]) by leveraging the **clipped surrogate objective**:
 
 $$
 \mathcal{L}_{\text{GRPO}}(\theta) = - \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \left[ \min \left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})} \hat{A}_{i,t}, \, \text{clip}\left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}, 1 - \epsilon, 1 + \epsilon \right) \hat{A}_{i,t} \right) - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right],
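To make the formulas in the two hunks above concrete, here is a small, self-contained Python sketch of the group-relative advantage and the first (non-clipped) per-token objective. It is illustrative only, not the trainer's implementation; the small `eps` term, the k3-style KL approximation, and the `beta` value are assumptions added for the example:

```python
import torch

def group_relative_advantages(rewards: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    """A_hat_i = (r_i - mean(r)) / std(r), computed over the G completions of one prompt."""
    return (rewards - rewards.mean()) / (rewards.std() + eps)  # eps guards against zero std

def grpo_loss_single_completion(logps, ref_logps, advantage, beta=0.04):
    """Per-token loss for ONE completion of length T (first objective above, no clipping).

    logps:     (T,) log pi_theta(o_t | q, o_<t) for the generated tokens (requires grad)
    ref_logps: (T,) log pi_ref(o_t | q, o_<t) under the frozen reference policy
    advantage: scalar group-relative advantage A_hat for this completion
    """
    # pi_theta / [pi_theta]_no-grad == exp(logp - logp.detach()): equal to 1 in value,
    # but it keeps the gradient of the numerator, matching the first loss above.
    ratio = torch.exp(logps - logps.detach())
    # Per-token KL approximation (k3-style estimator) -- an assumption here; the exact
    # estimator is the one described in the "Estimating the KL divergence" section.
    kl = torch.exp(ref_logps - logps) - (ref_logps - logps) - 1
    return -(ratio * advantage - beta * kl).mean()

# Toy usage: G = 4 completions for one prompt, then the loss for the first one (5 tokens)
advantages = group_relative_advantages(torch.tensor([1.0, 0.0, 2.0, 1.0]))
logps = torch.randn(5, requires_grad=True)
loss = grpo_loss_single_completion(logps, torch.randn(5), advantages[0])
loss.backward()
```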
@@ -112,17 +112,103 @@ The GRPO Trainer logs the following metrics:
 
 ## Customization
 
-## Speed up training with vLLM-powered generation
+### Speed up training with vLLM-powered generation
 
-Generation is often the main bottleneck that makes training slow with online methods. To accelerate generation, you can use [vLLM](https://github.com/vllm-project/vllm), a library that enables fast generation. To enable it, pass `use_vllm=True` in the training arguments.
+Generation is often the main bottleneck that makes training slow with online methods. To accelerate generation, you can use [vLLM](https://github.com/vllm-project/vllm), a library that enables fast generation. To enable it, first install the package with
+
+```shell
+pip install trl[vllm]
+```
+
+Then, start the vLLM server with the desired model:
+
+```bash
+trl vllm-serve --model <model_name>
+```
+
+Then, pass `use_vllm=True` in the training arguments and run the training script:
 
 ```python
 from trl import GRPOConfig
 
 training_args = GRPOConfig(..., use_vllm=True)
-```
+```
+
+For more information, see [Speeding up training with vLLM](speeding_up_training#vllm-for-fast-generation-in-online-methods).
+
+### GRPO at scale: train a 70B+ Model on multiple nodes
+
+When training large models like **Qwen2.5-72B**, you need several key optimizations to make the training efficient and scalable across multiple GPUs and nodes. These include:
+
+- **DeepSpeed ZeRO Stage 3**: ZeRO leverages data parallelism to distribute model states (weights, gradients, optimizer states) across multiple GPUs and CPUs, reducing memory and compute requirements on each device. Since large models cannot fit on a single GPU, using ZeRO Stage 3 is required for training such models. For more details, see [DeepSpeed Integration](deepspeed_integration).
+- **Accelerate**: Accelerate is a library that simplifies distributed training across multiple GPUs and nodes. It provides a simple API to launch distributed training and handles the complexities of distributed training, such as data parallelism, gradient accumulation, and distributed data loading. For more details, see [Distributing Training](distributing_training).
+- **vLLM**: See the previous section on how to use vLLM to speed up generation.
+
+Below is an example SLURM script to train a 70B model with GRPO on multiple nodes. This script trains a model on 4 nodes and uses the 5th node for vLLM-powered generation.
 
-For more information, see [Speeding up training with vLLM](speeding_up_training#vllm-for-fast-generation-in-online-methods).
+```sh
+#!/bin/bash
+#SBATCH --nodes=5
+#SBATCH --gres=gpu:8
+
+# Get the list of allocated nodes
+NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))
+
+# Assign the first 4 nodes for training and the 5th node for vLLM
+TRAIN_NODES="${NODELIST[@]:0:4}"  # Nodes 0, 1, 2, 3 for training
+VLLM_NODE="${NODELIST[4]}"        # Node 4 for vLLM
+
+# Run training on the first 4 nodes (Group 1)
+srun --nodes=4 --ntasks=4 --nodelist="${NODELIST[@]:0:4}" accelerate launch \
+     --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
+     --num_processes 32 \
+     --num_machines 4 \
+     --main_process_ip ${NODELIST[0]} \
+     --machine_rank $SLURM_PROCID \
+     --rdzv_backend c10d \
+     train_grpo.py \
+     --vllm_server_host $VLLM_NODE &
+
+# Run vLLM server on the 5th node (Group 2)
+srun --nodes=1 --ntasks=1 --nodelist="${NODELIST[4]}" trl vllm-serve --model Qwen/Qwen2.5-72B --tensor_parallel_size 8 &
+
+wait
+```
+
+```python
+import argparse
+
+from datasets import load_dataset
+from trl import GRPOTrainer, GRPOConfig
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--vllm_server_host", type=str, default="", help="The server IP")
+    args = parser.parse_args()
+
+    # Example dataset from TLDR
+    dataset = load_dataset("trl-lib/tldr", split="train")
+
+    # Dummy reward function: count the number of unique characters in the completions
+    def reward_num_unique_chars(completions, **kwargs):
+        return [len(set(c)) for c in completions]
+
+    training_args = GRPOConfig(
+        output_dir="Qwen2.5-72B-GRPO",
+        per_device_train_batch_size=4,
+        bf16=True,
+        gradient_checkpointing=True,
+        logging_steps=10,
+        use_vllm=True,
+        vllm_server_host=args.vllm_server_host.replace("ip-", "").replace("-", "."),  # from ip-X-X-X-X to X.X.X.X
+    )
+
+    trainer = GRPOTrainer(model="Qwen/Qwen2.5-72B", args=training_args, reward_funcs=reward_num_unique_chars, train_dataset=dataset)
+    trainer.train()
+
+if __name__ == "__main__":
+    main()
+```
 
 ### Using a custom reward function
 
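Assuming the SLURM script above is saved as, say, `train_grpo.slurm` alongside `train_grpo.py`, the whole two-group job can be submitted with a single `sbatch train_grpo.slurm`; the trailing `wait` keeps the allocation alive until both the training job step and the vLLM server step have exited.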
@@ -247,7 +333,7 @@ def math_reward_func(prompts, completions, task, **kwargs):
             # Return None for non-math tasks
             rewards.append(None)
     return rewards
-
+
 
 # Coding-specific reward function
 def coding_reward_func(prompts, completions, task, **kwargs):
     rewards = []
‎docs/source/speeding_up_training.md‎

Lines changed: 47 additions & 18 deletions
@@ -37,39 +37,68 @@ training_args = OnlineDPOConfig(..., use_vllm=True)
 </hfoption>
 <hfoption id="GRPO">
 
-Then, enable it by passing `use_vllm=True` in the training arguments.
+First, start a vLLM server by running:
+
+```bash
+trl vllm-serve --model <model_name>
+```
+
+Then, run the training script and pass `use_vllm=True` in the training arguments.
 
 ```python
 from trl import GRPOConfig
 
 training_args = GRPOConfig(..., use_vllm=True)
 ```
 
-The strategy here is to use a dedicated GPU for generation powered by vLLM, while using the remainder for training.
+You can customize the server configuration by passing additional arguments.
+
+```sh
+$ trl vllm-serve --help
+usage: trl vllm-serve [-h] --model MODEL [--revision REVISION] [--tensor_parallel_size TENSOR_PARALLEL_SIZE] [--host HOST]
+                      [--port PORT] [--gpu_memory_utilization GPU_MEMORY_UTILIZATION] [--dtype DTYPE]
+                      [--max_model_len MAX_MODEL_LEN] [--enable_prefix_caching ENABLE_PREFIX_CACHING]
+
+options:
+  -h, --help            Show this help message and exit
+  --model MODEL         Model name or path to load the model from. (default: None)
+  --revision REVISION   Revision to use for the model. If not specified, the default branch will be used. (default: None)
+  --tensor_parallel_size TENSOR_PARALLEL_SIZE, --tensor-parallel-size TENSOR_PARALLEL_SIZE
+                        Number of tensor parallel workers to use. (default: 1)
+  --host HOST           Host address to run the server on. (default: 0.0.0.0)
+  --port PORT           Port to run the server on. (default: 8000)
+  --gpu_memory_utilization GPU_MEMORY_UTILIZATION, --gpu-memory-utilization GPU_MEMORY_UTILIZATION
+                        Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the device
+                        dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus improve the
+                        model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors during
+                        initialization. (default: 0.9)
+  --dtype DTYPE         Data type to use for vLLM generation. If set to 'auto', the data type will be automatically determined based on
+                        the model configuration. Find the supported values in the vLLM documentation. (default: auto)
+  --max_model_len MAX_MODEL_LEN, --max-model-len MAX_MODEL_LEN
+                        If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced
+                        `vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model context
+                        size, which might be much larger than the KV cache, leading to inefficiencies. (default: None)
+  --enable_prefix_caching ENABLE_PREFIX_CACHING, --enable-prefix-caching ENABLE_PREFIX_CACHING
+                        Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the hardware support this
+                        feature. (default: None)
+```
 
 <Tip warning={true}>
 
-When using vLLM, an additional GPU is required exclusively for generation. This means you need at least two available GPUs and must ensure that one remains unused by the trainer. To achieve this, run the training with `--num_processes <NUMBER_OF_GPUs - 1>`.
+When using vLLM, ensure that the GPUs assigned for training and generation are separate to avoid resource conflicts. For instance, if you plan to use 4 GPUs for training and another 4 for vLLM generation, you can specify GPU allocation using `CUDA_VISIBLE_DEVICES`.
 
-For example, if you have 4 GPUs, set `--num_processes 3` to allocate three GPUs for training while reserving one for generation.
-```bash
-accelerate launch --multi_gpu --num_processes 3 train_grpo.py
-```
+Set GPUs **0-3** for vLLM generation:
+```sh
+CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model <model_name>
+```
 
-![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/1_gpu_for_generation.png)
+And GPUs **4-7** for training:
+```sh
+CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py
+```
 
 </Tip>
 
-You can further tune the vLLM configuration by setting a specific `vllm_device` and `vllm_gpu_memory_utilization` in the [`GRPOConfig`].
-
-```python
-training_args = GRPOConfig(
-    ...,
-    use_vllm=True,
-    vllm_device="cuda:4",
-    vllm_gpu_memory_utilization=0.7,
-)
-```
 
 </hfoption>
 </hfoptions>
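As an illustrative combination of the options documented above (the exact values are placeholders, not recommendations), a dedicated 4-GPU generation setup could be started with `CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model <model_name> --tensor_parallel_size 4 --gpu_memory_utilization 0.85 --max_model_len 4096`, while the remaining GPUs run `accelerate launch` for training as shown in the tip.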

‎setup.py‎

Lines changed: 1 addition & 2 deletions
@@ -91,8 +91,7 @@
     "scikit": ["scikit-learn"],
     "test": ["parameterized", "pytest-cov", "pytest-rerunfailures", "pytest-xdist", "pytest"],
     # vllm is not available on Windows
-    # vllm 0.7.3 causes hanging while gathering. temporary pinning the version until the issue is resolved
-    "vllm": ["vllm==0.7.2; sys_platform != 'win32'"],
+    "vllm": ["vllm>=0.7.0; sys_platform != 'win32'", "fastapi", "pydantic", "requests", "uvicorn"],
     "vlm": ["Pillow"],
 }
 EXTRAS["dev"] = []
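In practice this means `pip install "trl[vllm]"` now installs `fastapi`, `pydantic`, `requests`, and `uvicorn` alongside an unpinned `vllm>=0.7.0` (still skipped on Windows via the `sys_platform` marker); these are presumably the pieces the new `trl vllm-serve` HTTP server and its client depend on.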

‎tests/test_grpo_trainer.py‎

Lines changed: 5 additions & 13 deletions
@@ -20,7 +20,7 @@
 from datasets import load_dataset
 from parameterized import parameterized
 from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
-from transformers.testing_utils import require_peft, require_torch_accelerator
+from transformers.testing_utils import require_peft
 from transformers.utils import is_peft_available
 
 from trl import GRPOConfig, GRPOTrainer
@@ -631,7 +631,7 @@ def reward_func(completions, some_values, **kwargs):
             self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
 
     @unittest.skipIf(not is_vllm_available(), "vLLM is not available")
-    @require_torch_accelerator
+    @unittest.skip("We should add a mock for the vLLM server.")
     def test_training_vllm(self):
         """Test that training works with vLLM for generation."""
         dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
@@ -645,8 +645,6 @@ def test_training_vllm(self):
             max_completion_length=32,  # reduce the completion length to reduce memory usage
             report_to="none",
             use_vllm=True,
-            vllm_device="cuda:0",  # will raise a warning, but allows this test to work with only one GPU
-            vllm_gpu_memory_utilization=0.5,  # reduce since because we use the same device for training and vllm
         )
         trainer = GRPOTrainer(
             model="Qwen/Qwen2.5-0.5B-Instruct",  # tiny is too small for vLLM
@@ -761,7 +759,7 @@ def test_beta_zero_no_ref_model_and_no_kl(self):
             self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
 
     @unittest.skipIf(not is_vllm_available(), "vLLM is not available")
-    @require_torch_accelerator
+    @unittest.skip("We should add a mock for the vLLM server.")
     @require_peft
     def test_training_vllm_and_peft(self):
         """Test that training works with vLLM for generation."""
@@ -778,8 +776,6 @@ def test_training_vllm_and_peft(self):
             max_completion_length=32,  # reduce the completion length to reduce memory usage
             report_to="none",
             use_vllm=True,
-            vllm_device="cuda:0",  # will raise a warning, but allows this test to work with only one GPU
-            vllm_gpu_memory_utilization=0.5,  # reduce since because we use the same device for training and vllm
         )
         lora_config = LoraConfig(
             target_modules="all-linear",
@@ -810,7 +806,7 @@ def test_training_vllm_and_peft(self):
             self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.")
 
     @unittest.skipIf(not is_vllm_available(), "vLLM is not available")
-    @require_torch_accelerator
+    @unittest.skip("We should add a mock for the vLLM server.")
     def test_training_vllm_guided_decoding(self):
         """Test that training works with vLLM for generation with guided decoding."""
         dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
@@ -824,8 +820,6 @@ def test_training_vllm_guided_decoding(self):
             max_completion_length=32,  # reduce the completion length to reduce memory usage
             report_to="none",
             use_vllm=True,
-            vllm_device="cuda:0",  # will raise a warning, but allows this test to work with only one GPU
-            vllm_gpu_memory_utilization=0.5,  # reduce since because we use the same device for training and vllm
             vllm_guided_decoding_regex=r"<reasoning>\n.*\n</reasoning>\n<answer>\n.*\n</answer>",
         )
         trainer = GRPOTrainer(
@@ -883,7 +877,7 @@ def test_training_with_additional_generation_kwargs(self):
             self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
 
     @unittest.skipIf(not is_vllm_available(), "vLLM is not available")
-    @require_torch_accelerator
+    @unittest.skip("We should add a mock for the vLLM server.")
     def test_training_vllm_with_additional_generation_kwargs(self):
         """Test that training works with vLLM and additional generation kwargs."""
         dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
@@ -897,8 +891,6 @@ def test_training_vllm_with_additional_generation_kwargs(self):
             max_completion_length=32,  # reduce the completion length to reduce memory usage
             report_to="none",
             use_vllm=True,
-            vllm_device="cuda:0",  # will raise a warning, but allows this test to work with only one GPU
-            vllm_gpu_memory_utilization=0.5,  # reduce since because we use the same device for training and vllm
             top_p=0.9,
             top_k=10,
             min_p=0.01,

‎trl/cli.py‎

Lines changed: 7 additions & 0 deletions
@@ -25,6 +25,8 @@
 from .scripts.kto import make_parser as make_kto_parser
 from .scripts.sft import make_parser as make_sft_parser
 from .scripts.utils import TrlParser
+from .scripts.vllm_serve import main as vllm_serve_main
+from .scripts.vllm_serve import make_parser as make_vllm_serve_parser
 
 
 def main():
@@ -40,6 +42,7 @@ def main():
     make_grpo_parser(subparsers)
     make_kto_parser(subparsers)
     make_sft_parser(subparsers)
+    make_vllm_serve_parser(subparsers)
 
     # Parse the arguments
     args = parser.parse_args()
@@ -87,6 +90,10 @@ def main():
         args.training_script_args = sys.argv[2:]  # remove "trl" and "sft"
         launch_command(args)  # launch training
 
+    elif args.command == "vllm-serve":
+        (script_args,) = parser.parse_args_and_config()
+        vllm_serve_main(script_args)
+
 
 if __name__ == "__main__":
     main()
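With this wiring, the new subcommand is invoked like any other `trl` command, for example `trl vllm-serve --model <model_name> --port 8000` (the available options are the ones listed in the `--help` output shown earlier); `parse_args_and_config()` produces the `script_args` namespace that is handed to `vllm_serve_main`.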
