Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
51ef1b1
Feature: Add SGLang support for GRPO Trainer
jhinpan Feb 18, 2025
bcbed19
Turn to the online server API Usage
jhinpan Feb 19, 2025
b1b92fc
add test and fix bugs in result parsing
Jayon02 Feb 19, 2025
ed115af
Pass First test with fixing _update_sglang_weights
jhinpan Feb 20, 2025
941db60
Remove checkpoints from tracking and add to .gitignore
jhinpan Feb 21, 2025
e622ba9
config to run on single gpu successfully
ryang-max Apr 22, 2025
9fba5f0
Merge branch 'main' into sglang-server
ryang-max Apr 23, 2025
7de7ddb
Update code to align with vllm
ryang-max Apr 23, 2025
029402e
Merge remote-tracking branch 'origin/main' into sglang-server
ryang-max Apr 23, 2025
8debe2a
save model and update weight
ryang-max Apr 23, 2025
26d34c3
save model only main process
ryang-max Apr 24, 2025
69ebec8
A runnable update_from_tensor version
ryang-max Apr 24, 2025
0fcdd83
fix performance issue
ryang-max Apr 27, 2025
35e05f0
Merge branch 'main' into sglang-server
ryang-max Apr 27, 2025
8d75a8f
resolve comment: help strings
renxinx May 1, 2025
ddf67e9
resolve comment: help strings
renxinx May 1, 2025
6745e6b
Update trl/trainer/grpo_config.py
kashif May 2, 2025
6887ed5
Update trl/trainer/grpo_config.py
kashif May 2, 2025
4e020b4
Update trl/trainer/grpo_config.py
kashif May 2, 2025
5787bfc
Update trl/trainer/grpo_config.py
kashif May 2, 2025
4f8021a
Update trl/trainer/grpo_config.py
kashif May 2, 2025
f73e652
Update trl/trainer/grpo_config.py
kashif May 2, 2025
62dc22e
Update trl/trainer/grpo_trainer.py
kashif May 2, 2025
3a95d13
call raise_for_status
kashif May 2, 2025
f733428
remove duplicate
kashif May 2, 2025
e91e7d8
doc string
kashif May 2, 2025
1f2fada
formatting
kashif May 2, 2025
88ad6af
add sglang to extras
kashif May 2, 2025
9a2db24
formatting
kashif May 2, 2025
e139430
import requests only when sglang is available
kashif May 2, 2025
4693aa0
formatting
kashif May 2, 2025
cf2e1ff
undo formatting
kashif May 2, 2025
0a079cc
undo formatting
kashif May 2, 2025
9ddf9a0
more undo
kashif May 2, 2025
45214e9
last one!
kashif May 2, 2025
ccbf97b
add initial docs
kashif May 2, 2025
fe94157
Merge branch 'main' into sglang-server
kashif May 2, 2025
10af891
add sglang
kashif May 2, 2025
829ae41
last one now
kashif May 2, 2025
f48c7e6
new line
kashif May 2, 2025
2bcf24c
Merge branch 'main' into sglang-server
kashif May 4, 2025
a6158fa
delete test scripts
renxinx May 7, 2025
6380ce5
Merge branch 'main' into sglang-server
renxinx May 7, 2025
85d1906
Merge branch 'main' into sglang-server
kashif May 9, 2025
865afb4
Update setup.cfg
kashif May 9, 2025
6e94e53
Update setup.cfg
kashif May 9, 2025
8e3697d
intiial sglang-serve cli script
kashif May 13, 2025
e7149a0
Update trl/trainer/grpo_trainer.py
ryang-max May 21, 2025
94c1c9b
debug GRPO trainer
renxinx May 21, 2025
a665a17
change num_processes
renxinx May 22, 2025
9e634d1
update how to run sglang
renxinx May 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions docs/source/grpo_trainer.md
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,67 @@ Depending on the model size and the overall GPU memory requirements for training

For more information, see [Speeding up training with vLLM](speeding_up_training#vllm-for-fast-generation-in-online-methods).


### Speed up training with SGLang-powered generation

An alternative to vLLM is to use [SGLang](https://sglang.ai/) for fast generation. To enable it, first install the package with:

```shell
git clone git@github.com:huggingface/trl.git
cd trl
python3 -m uv pip install -e ".[sglang]"

# start sglang-server
python3 -m sglang.launch_server --model-path qwen/qwen2.5-7b-instruct

# run "export CUDA_VISIBLE_DEVICES"
# run script
python3 grpo_test.py
```

Then, pass `use_sglang=True` in the training arguments and point to the SGLang server via `sglang_server_url`:

```python
import os

from datasets import load_dataset

from trl import GRPOConfig, GRPOTrainer


dataset = load_dataset("trl-lib/tldr", split="train[:10%]")

checkpoint_dir = os.path.join("/sgl-workspace/ryang/trl", "checkpoints/sgl")
os.makedirs(checkpoint_dir, exist_ok=True)

def reward_len(completions, **kwargs):
    """Score each completion by how close its length is to 20.

    Returns a list with one value per completion: the negated absolute
    difference between ``len(completion)`` and 20, so 0 is the best score.
    Extra keyword arguments are accepted and ignored.
    """
    return [-abs(20 - len(completion)) for completion in completions]


training_args = GRPOConfig(
output_dir=os.path.join(checkpoint_dir, "Qwen2.5_output"),
logging_steps=10,
use_sglang=True,
sglang_device="cuda:0",
sglang_gpu_memory_utilization=0.9,
sglang_server_url="http://127.0.0.1:30000",
)


trainer = GRPOTrainer(
model="Qwen/Qwen2.5-7B-Instruct",
reward_funcs=reward_len,
args=training_args,
train_dataset=dataset,
)

training_args.checkpoint_path = checkpoint_dir


trainer.train()
```


### GRPO at scale: train a 70B+ Model on multiple nodes

When training large models like **Qwen2.5-72B**, you need several key optimizations to make the training efficient and scalable across multiple GPUs and nodes. These include:
Expand Down
4 changes: 4 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ vllm =
pydantic; python_version < "3.13"
requests; python_version < "3.13"
uvicorn; python_version < "3.13"
sglang =
sglang>=0.4.6post2
requests

vlm =
Pillow
Expand All @@ -91,6 +94,7 @@ dev =
%(test)s
%(vllm)s
%(vlm)s
%(sglang)s

[options.entry_points]
console_scripts =
Expand Down
7 changes: 7 additions & 0 deletions trl/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from .scripts.grpo import make_parser as make_grpo_parser
from .scripts.kto import make_parser as make_kto_parser
from .scripts.sft import make_parser as make_sft_parser
from .scripts.sglang_serve import main as sglang_serve_main
from .scripts.sglang_serve import make_parser as make_sglang_serve_parser
from .scripts.utils import TrlParser
from .scripts.vllm_serve import main as vllm_serve_main
from .scripts.vllm_serve import make_parser as make_vllm_serve_parser
Expand All @@ -45,6 +47,7 @@ def main():
make_kto_parser(subparsers)
make_sft_parser(subparsers)
make_vllm_serve_parser(subparsers)
make_sglang_serve_parser(subparsers)

# Parse the arguments; the remaining ones (`launch_args`) are passed to the 'accelerate launch' subparser.
# Duplicates may occur if the same argument is provided in both the config file and CLI.
Expand Down Expand Up @@ -139,6 +142,10 @@ def main():

vllm_serve_main(script_args)

elif args.command == "sglang-serve":
(script_args,) = parser.parse_args_and_config()
sglang_serve_main(script_args)


if __name__ == "__main__":
main()
5 changes: 5 additions & 0 deletions trl/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
_vllm_available = _is_package_available("vllm")
_vllm_ascend_available = _is_package_available("vllm_ascend")
_joblib_available = _is_package_available("joblib")
_sglang_available = _is_package_available("sglang")


def is_deepspeed_available() -> bool:
Expand Down Expand Up @@ -92,6 +93,10 @@ def is_joblib_available() -> bool:
return _joblib_available


def is_sglang_available() -> bool:
    """Return the module-level flag recording whether the `sglang` package was found."""
    return _sglang_available


class _LazyModule(ModuleType):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
Expand Down
2 changes: 2 additions & 0 deletions trl/scripts/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
is_diffusers_available,
is_liger_kernel_available,
is_llm_blender_available,
is_sglang_available,
is_vllm_available,
)
from .utils import get_git_commit_hash
Expand Down Expand Up @@ -69,6 +70,7 @@ def print_env():
"OpenAI version": version("openai") if is_openai_available() else "not installed",
"PEFT version": version("peft") if is_peft_available() else "not installed",
"vLLM version": version("vllm") if is_vllm_available() else "not installed",
"SGLang version": version("sglang") if is_sglang_available() else "not installed",
}

info_str = "\n".join([f"- {prop}: {val}" for prop, val in info.items()])
Expand Down
Loading