
Commit da2d521

jhinpan authored and PrinsYin committed
Feature: Add SGLang support for GRPO Trainer
Squashed commit history:

- Turn to the online server API Usage
- add test and fix bugs in result parsing
- Pass First test with fixing _update_sglang_weights
- Remove checkpoints from tracking and add to .gitignore
- config to run on single gpu successfully
- Update code to align with vllm save model and update weight
- save model only main process
- A runnable update_from_tensor version
- fix performance issue
- resolve comment: help strings (×2)
- Update trl/trainer/grpo_config.py (×6)
- Update trl/trainer/grpo_trainer.py
- call raise_for_status
- remove duplicate doc string
- formatting
- add sglang to extras
- formatting
- import requests only when sglang is available
- formatting
- undo formatting
- undo formatting more
- undo last one!
- add initial docs
- add sglang
- last one now
- new line
- delete test scripts
- Update setup.cfg (×2)
- initial sglang-serve cli script
- Update trl/trainer/grpo_trainer.py
- remove dead code
- debug GRPO trainer
- change num_processes
- update how to run sglang

Co-authored-by: Kashif Rasul <[email protected]>
1 parent 559a99f commit da2d521

8 files changed (+498, -5 lines)


docs/source/grpo_trainer.md

Lines changed: 60 additions & 0 deletions
@@ -243,6 +243,66 @@ If the recommended value does not work in your environment, we suggest adding a

For more information, see [Speeding up training with vLLM](speeding_up_training#vllm-for-fast-generation-in-online-methods).

### Speed up training with SGLang-powered generation

Another alternative to vLLM is [SGLang](https://sglang.ai/), which also enables fast generation. To enable it, first install the package with:

```shell
git clone git@github.com:huggingface/trl.git
cd trl
python3 -m uv pip install -e ".[sglang]"

# start the SGLang server
python3 -m sglang.launch_server --model-path qwen/qwen2.5-7b-instruct

# export CUDA_VISIBLE_DEVICES to select the GPU(s) for training
# run the training script
python3 grpo_test.py
```
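Before starting training, it can be worth confirming that the server is reachable. A minimal sanity check, assuming the server listens on the default port 30000 and exposes SGLang's `/health` and `/get_model_info` endpoints:

```shell
# check that the SGLang server is up and which model it loaded
curl http://127.0.0.1:30000/health
curl http://127.0.0.1:30000/get_model_info
```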
Then, pass `use_sglang=True` in the training arguments and point to the SGLang server via `sglang_server_url`:

```python
import os

from datasets import load_dataset

from trl import GRPOConfig, GRPOTrainer


dataset = load_dataset("trl-lib/tldr", split="train[:10%]")

checkpoint_dir = os.path.join("/sgl-workspace/ryang/trl", "checkpoints/sgl")
os.makedirs(checkpoint_dir, exist_ok=True)


def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]


training_args = GRPOConfig(
    output_dir=os.path.join(checkpoint_dir, "Qwen2.5_output"),
    logging_steps=10,
    use_sglang=True,
    sglang_device="cuda:0",
    sglang_gpu_memory_utilization=0.9,
    sglang_server_url="http://127.0.0.1:30000",
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-7B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)

training_args.checkpoint_path = checkpoint_dir

trainer.train()
```
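How the script is launched depends on your GPU layout: the GPU serving SGLang is typically kept separate from the training GPUs. A minimal launch sketch, assuming the example above is saved as `grpo_test.py` and a single training GPU; the exact `CUDA_VISIBLE_DEVICES` value and `--num_processes` depend on your hardware (a plain `python3 grpo_test.py` run, as shown earlier, also works for quick tests):

```shell
# reserve a GPU for training; the SGLang server keeps the device it was launched on
export CUDA_VISIBLE_DEVICES=0
accelerate launch --num_processes 1 grpo_test.py
```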
### GRPO at scale: train a 70B+ Model on multiple nodes

When training large models like **Qwen2.5-72B**, you need several key optimizations to make the training efficient and scalable across multiple GPUs and nodes. These include:

setup.cfg

Lines changed: 5 additions & 0 deletions
@@ -74,6 +74,10 @@ vllm =
     requests; python_version < "3.13"
     uvicorn; python_version < "3.13"
 
+sglang =
+    sglang>=0.4.6post2
+    requests
+
 vlm =
     Pillow
 dev =
@@ -87,6 +91,7 @@ dev =
     %(scikit)s
     %(test)s
     %(vlm)s
+    %(sglang)s
 
 [options.entry_points]
 console_scripts =

trl/cli.py

Lines changed: 6 additions & 0 deletions
@@ -24,6 +24,8 @@
 from .scripts.grpo import make_parser as make_grpo_parser
 from .scripts.kto import make_parser as make_kto_parser
 from .scripts.sft import make_parser as make_sft_parser
+from .scripts.sglang_serve import main as sglang_serve_main
+from .scripts.sglang_serve import make_parser as make_sglang_serve_parser
 from .scripts.utils import TrlParser
 from .scripts.vllm_serve import main as vllm_serve_main
 from .scripts.vllm_serve import make_parser as make_vllm_serve_parser
@@ -42,6 +44,7 @@ def main():
     make_kto_parser(subparsers)
     make_sft_parser(subparsers)
     make_vllm_serve_parser(subparsers)
+    make_sglang_serve_parser(subparsers)
 
     # Parse the arguments; the remaining ones (`launch_args`) are passed to the 'accelerate launch' subparser.
     # Duplicates may occur if the same argument is provided in both the config file and CLI.
@@ -131,6 +134,9 @@ def main():
             )
 
         vllm_serve_main(script_args)
+    elif args.command == "sglang-serve":
+        (script_args,) = parser.parse_args_and_config()
+        sglang_serve_main(script_args)
 
 
 if __name__ == "__main__":
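With the parser registered, the new subcommand becomes available from the `trl` CLI. A minimal sketch of how it would be invoked, assuming the `sglang-serve` script mirrors the existing `vllm-serve` interface; its actual flags are defined in `trl/scripts/sglang_serve.py`, which is not shown in this diff, so the `--model` flag here is illustrative:

```shell
# launch an SGLang generation server for GRPO training (flags are illustrative)
trl sglang-serve --model qwen/qwen2.5-7b-instruct
```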

trl/import_utils.py

Lines changed: 5 additions & 0 deletions
@@ -38,6 +38,8 @@
 _vllm_available = _is_package_available("vllm")
 _vllm_ascend_available = _is_package_available("vllm_ascend")
 _joblib_available = _is_package_available("joblib")
+_sglang_available = _is_package_available("sglang")
+
 
 
 def is_deepspeed_available() -> bool:
@@ -91,6 +93,9 @@ def is_vllm_ascend_available() -> bool:
 def is_joblib_available() -> bool:
     return _joblib_available
 
+def is_sglang_available() -> bool:
+    return _sglang_available
+
 
 class _LazyModule(ModuleType):
     """

trl/scripts/env.py

Lines changed: 2 additions & 0 deletions
@@ -27,6 +27,7 @@
     is_diffusers_available,
     is_liger_kernel_available,
     is_llm_blender_available,
+    is_sglang_available,
     is_vllm_available,
 )
 from .utils import get_git_commit_hash
@@ -74,6 +75,7 @@ def print_env():
         "OpenAI version": version("openai") if is_openai_available() else "not installed",
         "PEFT version": version("peft") if is_peft_available() else "not installed",
         "vLLM version": version("vllm") if is_vllm_available() else "not installed",
+        "SGLang version": version("sglang") if is_sglang_available() else "not installed",
     }
 
     info_str = "\n".join([f"- {prop}: {val}" for prop, val in info.items()])
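After this change, `trl env` reports the installed SGLang version alongside the other optional dependencies, using the `- {prop}: {val}` format built above. A sketch of the relevant output lines (the version number shown is illustrative):

```shell
$ trl env
# ...
# - vLLM version: not installed
# - SGLang version: 0.4.6.post2
# ...
```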
