Closed
Changes from 40 commits
Commits
51 commits
51ef1b1
Feature: Add SGLang support for GRPO Trainer
jhinpan Feb 18, 2025
bcbed19
Turn to the online server API Usage
jhinpan Feb 19, 2025
b1b92fc
add test and fix bugs in result parsing
Jayon02 Feb 19, 2025
ed115af
Pass First test with fixing _update_sglang_weights
jhinpan Feb 20, 2025
941db60
Remove checkpoints from tracking and add to .gitignore
jhinpan Feb 21, 2025
e622ba9
config to run on single gpu successfully
ryang-max Apr 22, 2025
9fba5f0
Merge branch 'main' into sglang-server
ryang-max Apr 23, 2025
7de7ddb
Update code to align with vllm
ryang-max Apr 23, 2025
029402e
Merge remote-tracking branch 'origin/main' into sglang-server
ryang-max Apr 23, 2025
8debe2a
save model and update weight
ryang-max Apr 23, 2025
26d34c3
save model only main process
ryang-max Apr 24, 2025
69ebec8
A runnable update_from_tensor version
ryang-max Apr 24, 2025
0fcdd83
fix performance issue
ryang-max Apr 27, 2025
35e05f0
Merge branch 'main' into sglang-server
ryang-max Apr 27, 2025
8d75a8f
resolve comment: help strings
renxinx May 1, 2025
ddf67e9
resolve comment: help strings
renxinx May 1, 2025
6745e6b
Update trl/trainer/grpo_config.py
kashif May 2, 2025
6887ed5
Update trl/trainer/grpo_config.py
kashif May 2, 2025
4e020b4
Update trl/trainer/grpo_config.py
kashif May 2, 2025
5787bfc
Update trl/trainer/grpo_config.py
kashif May 2, 2025
4f8021a
Update trl/trainer/grpo_config.py
kashif May 2, 2025
f73e652
Update trl/trainer/grpo_config.py
kashif May 2, 2025
62dc22e
Update trl/trainer/grpo_trainer.py
kashif May 2, 2025
3a95d13
call raise_for_status
kashif May 2, 2025
f733428
remove duplicate
kashif May 2, 2025
e91e7d8
doc string
kashif May 2, 2025
1f2fada
formatting
kashif May 2, 2025
88ad6af
add sglang to extras
kashif May 2, 2025
9a2db24
formatting
kashif May 2, 2025
e139430
import requests only when sglang is available
kashif May 2, 2025
4693aa0
formatting
kashif May 2, 2025
cf2e1ff
undo formatting
kashif May 2, 2025
0a079cc
undo formatting
kashif May 2, 2025
9ddf9a0
more undo
kashif May 2, 2025
45214e9
last one!
kashif May 2, 2025
ccbf97b
add initial docs
kashif May 2, 2025
fe94157
Merge branch 'main' into sglang-server
kashif May 2, 2025
10af891
add sglang
kashif May 2, 2025
829ae41
last one now
kashif May 2, 2025
f48c7e6
new line
kashif May 2, 2025
2bcf24c
Merge branch 'main' into sglang-server
kashif May 4, 2025
a6158fa
delete test scripts
renxinx May 7, 2025
6380ce5
Merge branch 'main' into sglang-server
renxinx May 7, 2025
85d1906
Merge branch 'main' into sglang-server
kashif May 9, 2025
865afb4
Update setup.cfg
kashif May 9, 2025
6e94e53
Update setup.cfg
kashif May 9, 2025
8e3697d
intiial sglang-serve cli script
kashif May 13, 2025
e7149a0
Update trl/trainer/grpo_trainer.py
ryang-max May 21, 2025
94c1c9b
debug GRPO trainer
renxinx May 21, 2025
a665a17
change num_processes
renxinx May 22, 2025
9e634d1
update how to run sglang
renxinx May 24, 2025
18 changes: 18 additions & 0 deletions docs/source/grpo_trainer.md
@@ -225,6 +225,24 @@ Depending on the model size and the overall GPU memory requirements for training

For more information, see [Speeding up training with vLLM](speeding_up_training#vllm-for-fast-generation-in-online-methods).


### Speed up training with SGLang-powered generation

As an alternative to vLLM, you can use [SGLang](https://sglang.ai/) for fast generation. To enable it, first install the package with:

```shell
pip install trl[sglang]
```

Then, pass `use_sglang=True` in the training arguments and point the trainer to a running SGLang server via `sglang_server_url`:

```python
from trl import GRPOConfig

training_args = GRPOConfig(..., use_sglang=True, sglang_server_url="http://127.0.0.1:30000")
```
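
For orientation, the request the trainer sends to the server's `/generate` endpoint can be sketched as below. This is a minimal illustration based on the sampling parameters set in `grpo_trainer.py` in this PR; the helper name `build_generate_payload` and its default values are hypothetical and not part of TRL or SGLang.

```python
from typing import Optional


def build_generate_payload(
    prompts: list[str],
    temperature: float = 1.0,
    max_new_tokens: int = 256,
    num_generations: int = 8,
    repetition_penalty: float = 1.0,
    top_p: float = 1.0,
    top_k: Optional[int] = None,
    min_p: Optional[float] = None,
) -> dict:
    """Build the JSON body for a POST to `<sglang_server_url>/generate` (hypothetical helper)."""
    sampling_params = {
        "temperature": temperature,
        "max_new_tokens": max_new_tokens,
        "n": num_generations,
        "repetition_penalty": repetition_penalty,
        "top_p": top_p,
        # Unset values are mapped to "disabled" sentinels, mirroring the trainer code.
        "top_k": -1 if top_k is None else top_k,
        "min_p": 0.0 if min_p is None else min_p,
    }
    return {"text": prompts, "sampling_params": sampling_params}


payload = build_generate_payload(["Summarize: ..."], num_generations=4)
```

The trainer then posts a body of this shape with `requests.post(f"{sglang_server_url}/generate", json=payload)`.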


### GRPO at scale: train a 70B+ Model on multiple nodes

When training large models like **Qwen2.5-72B**, you need several key optimizations to make the training efficient and scalable across multiple GPUs and nodes. These include:
4 changes: 4 additions & 0 deletions setup.cfg
@@ -70,6 +70,9 @@ vllm =
pydantic
requests
uvicorn
sglang =
sglang>=0.4.6post2
requests
vlm =
Pillow
dev =
@@ -85,6 +88,7 @@ dev =
%(test)s
%(vllm)s
%(vlm)s
%(sglang)s

[options.entry_points]
console_scripts =
5 changes: 5 additions & 0 deletions trl/import_utils.py
@@ -38,6 +38,7 @@
_vllm_available = _is_package_available("vllm")
_vllm_ascend_available = _is_package_available("vllm_ascend")
_joblib_available = _is_package_available("joblib")
_sglang_available = _is_package_available("sglang")


def is_deepspeed_available() -> bool:
@@ -92,6 +93,10 @@ def is_joblib_available() -> bool:
return _joblib_available


def is_sglang_available() -> bool:
return _sglang_available


class _LazyModule(ModuleType):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
56 changes: 56 additions & 0 deletions trl/scripts/grpo_test/grpo_sgl_test.py
Member


In the final version, the files in scripts should be removed

@@ -0,0 +1,56 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from datasets import load_dataset

from trl import GRPOConfig, GRPOTrainer



dataset = load_dataset("trl-lib/tldr", split="train[:1%]")

checkpoint_dir = os.path.join("/sgl-workspace/ryang/trl", "checkpoints/sgl")
os.makedirs(checkpoint_dir, exist_ok=True)


# Define the reward function, which rewards completions that are close to 20 characters
def reward_len(completions, **kwargs):
return [-abs(20 - len(completion)) for completion in completions]


training_args = GRPOConfig(
output_dir=os.path.join(checkpoint_dir, "Qwen2.5_output"),
logging_steps=10,
# report_to="wandb",
# use_vllm=True,
use_sglang=True,
sglang_device="cuda:1",
sglang_gpu_memory_utilization=0.9,
sglang_server_url="http://127.0.0.1:30000",
)


trainer = GRPOTrainer(
model="Qwen/Qwen2-0.5B-Instruct",
reward_funcs=reward_len,
args=training_args,
train_dataset=dataset,
)

training_args.checkpoint_path = checkpoint_dir # Set the checkpoint path for later use


trainer.train()
17 changes: 17 additions & 0 deletions trl/scripts/grpo_test/grpo_sgl_test.yaml
@@ -0,0 +1,17 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 3
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
main_process_port: 29600
4 changes: 4 additions & 0 deletions trl/scripts/grpo_test/run.sh
@@ -0,0 +1,4 @@
#!/bin/bash
export CUDA_VISIBLE_DEVICES=5,6,7
export PYTHONPATH="/sgl-workspace/ryang/trl:$PYTHONPATH"
accelerate launch --config_file=trl/scripts/grpo_test/grpo_sgl_test.yaml trl/scripts/grpo_test/grpo_sgl_test.py
38 changes: 38 additions & 0 deletions trl/trainer/grpo_config.py
@@ -119,6 +119,16 @@ class GRPOConfig(TrainingArguments):
`"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when
launching the vLLM server via the `--vllm_tensor_parallel_size` flag.

> Parameters that control generation acceleration powered by SGLang

use_sglang (`bool`, *optional*, defaults to `False`):
Whether to use SGLang for generating completions. If set to `True`, an SGLang server must be running.
sglang_server_url (`str` or `None`, *optional*, defaults to `"http://localhost:32232"`):
The URL of the SGLang server (e.g. `"http://localhost:32232"`). Required if `use_sglang` is `True`.
sglang_device (`str`, *optional*, defaults to `"auto"`):
GPU device to be used for SGLang generation if launching from this code. This is optional if the server is
managed externally.
sglang_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`):
Ratio of GPU memory reserved for SGLang generation.

> Parameters that control the training

learning_rate (`float`, *optional*, defaults to `1e-6`):
@@ -348,6 +358,34 @@ class GRPOConfig(TrainingArguments):
},
)

# When running the trainer, set the following command-line arguments (or JSON configuration) so that SGLang is used:
# • --use_sglang True
# • --sglang_server_url "http://localhost:30033"
# • Optionally, --sglang_device "cuda:1" if you wish to assign a specific GPU.
# Parameters for generation acceleration powered by SGLang
use_sglang: Optional[bool] = field(
default=False,
metadata={
"help": "Whether to use SGLang for generating completions. If `True`, an SGLang server must be running."
},
)
sglang_server_url: Optional[str] = field(
default="http://localhost:32232",
metadata={
"help": "The URL of the SGLang server (e.g., 'http://localhost:32232'). Required if `use_sglang` is `True`."
},
)
sglang_device: Optional[str] = field(
default="auto",
metadata={
"help": "The GPU device to be used for SGLang generation if launching internally. Optional if the server is managed externally."
},
)
sglang_gpu_memory_utilization: float = field(
default=0.9,
metadata={"help": "Ratio of GPU memory reserved for SGLang generation."},
)

# Parameters that control the training
learning_rate: float = field(
default=1e-6,
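
The SGLang fields above carry an implicit constraint: `use_sglang=True` only makes sense together with a non-empty `sglang_server_url`. The trainer checks this at initialization; the same check could also run when the config is built. The sketch below shows that idea with a standalone, hypothetical `SGLangSettings` dataclass rather than `GRPOConfig` itself. The field names follow this PR, while the validation rules (including the `(0, 1]` range for the memory ratio) are assumptions.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SGLangSettings:
    """Hypothetical stand-in for the SGLang-related fields of `GRPOConfig`."""

    use_sglang: bool = False
    sglang_server_url: Optional[str] = None
    sglang_gpu_memory_utilization: float = 0.9

    def __post_init__(self):
        # Assumption: an external server is required whenever SGLang is enabled.
        if self.use_sglang and not self.sglang_server_url:
            raise ValueError("`use_sglang=True` requires `sglang_server_url`.")
        # Assumption: the ratio is a fraction of GPU memory, so it must lie in (0, 1].
        if not 0.0 < self.sglang_gpu_memory_utilization <= 1.0:
            raise ValueError("`sglang_gpu_memory_utilization` must be in (0, 1].")


settings = SGLangSettings(use_sglang=True, sglang_server_url="http://127.0.0.1:30000")
```

Failing at construction time would surface a missing URL before any model is loaded, at the cost of duplicating a check the trainer already performs.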
115 changes: 111 additions & 4 deletions trl/trainer/grpo_trainer.py
@@ -48,7 +48,7 @@
from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from ..extras.profiling import profiling_context, profiling_decorator
from ..extras.vllm_client import VLLMClient
from ..import_utils import is_liger_kernel_available, is_vllm_available
from ..import_utils import is_liger_kernel_available, is_sglang_available, is_vllm_available
from ..models import create_reference_model, prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation
from ..models.utils import _ForwardRedirection
from .callbacks import SyncRefModelCallback
@@ -76,6 +76,10 @@
if is_wandb_available():
import wandb

if is_sglang_available():
import requests
from sglang.srt.utils import MultiprocessingSerializer

# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
@@ -475,6 +479,7 @@ def data_collator(features): # No data collation is needed in GRPO
self.min_p = args.min_p
self.repetition_penalty = args.repetition_penalty
self.use_vllm = args.use_vllm
self.use_sglang = args.use_sglang
self.vllm_mode = args.vllm_mode
self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization # only applies to colocation mode
self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size # only applies to colocation mode
@@ -618,7 +623,32 @@ def data_collator(features): # No data collation is needed in GRPO
# it's safer to set it in all cases.
set_seed(args.seed, device_specific=True)

if self.use_vllm:
# Initialization for the inference backend
if self.use_sglang:
if not is_sglang_available():
raise ImportError(
"SGLang is not available and `use_sglang` is set to True. Please install SGLang with "
"`pip install sglang` to use it."
)
# Use externally managed SGLang server.
# The server URL is provided via configuration, e.g., "http://localhost:32232"
if not args.sglang_server_url:
raise ValueError("SGLang is enabled but no server URL was provided (use --sglang_server_url).")
self.sglang_server_url = args.sglang_server_url
if self.accelerator.is_main_process:
self.sglang_sampling_params = {
"temperature": self.temperature,
"max_new_tokens": self.max_completion_length,
"n": self.num_generations,
"repetition_penalty": self.repetition_penalty,
"top_p": self.top_p,
"top_k": -1 if self.top_k is None else self.top_k,
"min_p": 0.0 if self.min_p is None else self.min_p,
}

self._last_loaded_step = -1
self.accelerator.wait_for_everyone()
elif self.use_vllm:
if not is_vllm_available():
raise ImportError(
"vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
@@ -893,6 +923,35 @@ def _sync_fsdp_params_to_vllm(self, module: nn.Module, prefix: str = "", visited
llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
llm_model.load_weights([(full_name, param.data)])

@profiling_decorator
def _update_sglang_weights(self):
"""
Update the model weights on the SGLang server via its tensor-based update API.
This function should only be called on the main process.
"""
payload = {
"serialized_named_tensors": [
MultiprocessingSerializer.serialize(list(self.model.named_parameters()), output_str=True)
],
"flush_cache": True, # flush the server cache after the weights are updated
}
try:
response = requests.post(
f"{self.sglang_server_url}/update_weights_from_tensor",
json=payload,
timeout=60,
)
response.raise_for_status() # Raise HTTPError for bad responses (4xx or 5xx)
except requests.RequestException as e:
# Keep the original exception context
raise RuntimeError(f"Weight update request failed: {e}") from e
res_json = response.json()
if not res_json.get("success", False):
# No underlying exception to chain from here, as it's a logic error from the server response
raise RuntimeError(
f"Failed to update weights on SGLang server: {res_json.get('message', 'No message provided')}"
)

@profiling_decorator
def _move_model_to_vllm(self):
# For DeepSpeed ZeRO-3 and FSDP, we need to gather all parameters before operations
@@ -1006,8 +1065,56 @@ def _generate_and_score_completions(
prompt_ids = prompt_ids[:, -self.max_prompt_length :]
prompt_mask = prompt_mask[:, -self.max_prompt_length :]

# Generate completions using either vLLM or regular generation
if self.use_vllm:
# Generation branch: choose SGLang, vLLM, or default generation.
if self.use_sglang:
# Update weights if the training step has advanced.
if self.state.global_step != self._last_loaded_step:
if self.accelerator.is_main_process:
self._update_sglang_weights()
self._last_loaded_step = self.state.global_step

# Gather all prompt texts from all processes.
all_prompts_text = gather_object(prompts_text)
if self.accelerator.is_main_process:
ordered_set_of_prompts = all_prompts_text[:: self.num_generations]
payload = {
"text": ordered_set_of_prompts,
"sampling_params": self.sglang_sampling_params,
}
response = requests.post(f"{self.sglang_server_url}/generate", json=payload)
response.raise_for_status() # Raise HTTPError for bad responses (4xx or 5xx)
generated_texts = [item.get("text") for item in response.json()]
completion_ids = [self.processing_class.encode(text) for text in generated_texts]
else:
completion_ids = [None] * len(all_prompts_text)
# # Broadcast and slice the generated completions.
# completion_ids = broadcast_object_list(completion_ids, from_process=0)
# process_slice = slice(
# self.accelerator.process_index * len(prompts),
# (self.accelerator.process_index + 1) * len(prompts),
# )
# completion_ids = completion_ids[process_slice]
# completion_ids = [
# torch.tensor(ids, device=device) for ids in completion_ids
# ]
# completion_ids = pad(
# completion_ids, padding_value=self.processing_class.pad_token_id
# )
# prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)

# Broadcast the completions from the main process to all processes, ensuring each process receives its
# corresponding slice.
completion_ids = broadcast_object_list(completion_ids, from_process=0)
process_slice = slice(
self.accelerator.process_index * len(prompts),
(self.accelerator.process_index + 1) * len(prompts),
)
completion_ids = completion_ids[process_slice]

# Pad the completions, and concatenate them with the prompts
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
elif self.use_vllm:
# First, update the vLLM weights if needed
if self.state.global_step != self._last_loaded_step:
self._move_model_to_vllm()
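
The SGLang branch relies on two pieces of index arithmetic. `all_prompts_text[:: self.num_generations]` keeps one copy of each prompt, since the gathered list repeats every prompt `num_generations` times in a row. Each process later takes the slice from `process_index * len(prompts)` to `(process_index + 1) * len(prompts)` of the broadcast completions. Below is a minimal pure-Python sketch of both steps, using hypothetical helper names and toy data rather than the trainer's tensors.

```python
def unique_prompts(all_prompts: list[str], num_generations: int) -> list[str]:
    """Keep one copy of each prompt, assuming each is repeated `num_generations` times consecutively."""
    return all_prompts[::num_generations]


def slice_for_process(items: list, process_index: int, per_process: int) -> list:
    """Return the contiguous chunk of `items` that belongs to one process."""
    start = process_index * per_process
    return items[start : start + per_process]


# Toy setup: 2 processes, each holding 2 prompt entries, num_generations=2.
gathered = ["a", "a", "b", "b"]
completions = ["a1", "a2", "b1", "b2"]

prompts_to_send = unique_prompts(gathered, num_generations=2)
rank1_completions = slice_for_process(completions, process_index=1, per_process=2)
```

The slice only lines up if the server returns `num_generations` completions per unique prompt, in prompt order. That ordering is an assumption here, not something this sketch verifies.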