Closed
Changes from 16 commits
Commits
51 commits
51ef1b1
Feature: Add SGLang support for GRPO Trainer
jhinpan Feb 18, 2025
bcbed19
Turn to the online server API Usage
jhinpan Feb 19, 2025
b1b92fc
add test and fix bugs in result parsing
Jayon02 Feb 19, 2025
ed115af
Pass First test with fixing _update_sglang_weights
jhinpan Feb 20, 2025
941db60
Remove checkpoints from tracking and add to .gitignore
jhinpan Feb 21, 2025
e622ba9
config to run on single gpu successfully
ryang-max Apr 22, 2025
9fba5f0
Merge branch 'main' into sglang-server
ryang-max Apr 23, 2025
7de7ddb
Update code to align with vllm
ryang-max Apr 23, 2025
029402e
Merge remote-tracking branch 'origin/main' into sglang-server
ryang-max Apr 23, 2025
8debe2a
save model and update weight
ryang-max Apr 23, 2025
26d34c3
save model only main process
ryang-max Apr 24, 2025
69ebec8
A runnable update_from_tensor version
ryang-max Apr 24, 2025
0fcdd83
fix performance issue
ryang-max Apr 27, 2025
35e05f0
Merge branch 'main' into sglang-server
ryang-max Apr 27, 2025
8d75a8f
resolve comment: help strings
renxinx May 1, 2025
ddf67e9
resolve comment: help strings
renxinx May 1, 2025
6745e6b
Update trl/trainer/grpo_config.py
kashif May 2, 2025
6887ed5
Update trl/trainer/grpo_config.py
kashif May 2, 2025
4e020b4
Update trl/trainer/grpo_config.py
kashif May 2, 2025
5787bfc
Update trl/trainer/grpo_config.py
kashif May 2, 2025
4f8021a
Update trl/trainer/grpo_config.py
kashif May 2, 2025
f73e652
Update trl/trainer/grpo_config.py
kashif May 2, 2025
62dc22e
Update trl/trainer/grpo_trainer.py
kashif May 2, 2025
3a95d13
call raise_for_status
kashif May 2, 2025
f733428
remove duplicate
kashif May 2, 2025
e91e7d8
doc string
kashif May 2, 2025
1f2fada
formatting
kashif May 2, 2025
88ad6af
add sglang to extras
kashif May 2, 2025
9a2db24
formatting
kashif May 2, 2025
e139430
import requests only when sglang is available
kashif May 2, 2025
4693aa0
formatting
kashif May 2, 2025
cf2e1ff
undo formatting
kashif May 2, 2025
0a079cc
undo formatting
kashif May 2, 2025
9ddf9a0
more undo
kashif May 2, 2025
45214e9
last one!
kashif May 2, 2025
ccbf97b
add initial docs
kashif May 2, 2025
fe94157
Merge branch 'main' into sglang-server
kashif May 2, 2025
10af891
add sglang
kashif May 2, 2025
829ae41
last one now
kashif May 2, 2025
f48c7e6
new line
kashif May 2, 2025
2bcf24c
Merge branch 'main' into sglang-server
kashif May 4, 2025
a6158fa
delete test scripts
renxinx May 7, 2025
6380ce5
Merge branch 'main' into sglang-server
renxinx May 7, 2025
85d1906
Merge branch 'main' into sglang-server
kashif May 9, 2025
865afb4
Update setup.cfg
kashif May 9, 2025
6e94e53
Update setup.cfg
kashif May 9, 2025
8e3697d
intiial sglang-serve cli script
kashif May 13, 2025
e7149a0
Update trl/trainer/grpo_trainer.py
ryang-max May 21, 2025
94c1c9b
debug GRPO trainer
renxinx May 21, 2025
a665a17
change num_processes
renxinx May 22, 2025
9e634d1
update how to run sglang
renxinx May 24, 2025
5 changes: 5 additions & 0 deletions trl/import_utils.py
@@ -39,6 +39,7 @@
_vllm_available = _is_package_available("vllm")
_vllm_ascend_available = _is_package_available("vllm_ascend")
_joblib_available = _is_package_available("joblib")
_sglang_available = _is_package_available("sglang")


def is_deepspeed_available() -> bool:
@@ -97,6 +98,10 @@ def is_joblib_available() -> bool:
    return _joblib_available


def is_sglang_available() -> bool:
    return _sglang_available


class _LazyModule(ModuleType):
    """
    Module class that surfaces all objects but only performs associated imports when the objects are requested.
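As an aside, the new is_sglang_available() helper is what allows the HTTP client to be imported lazily, in line with the "import requests only when sglang is available" commit. A minimal sketch of that pattern, with the guard function name taken from this diff and everything else illustrative rather than the PR's actual code:

from trl.import_utils import is_sglang_available

# Pull in the HTTP client only when the optional SGLang backend can actually be used.
if is_sglang_available():
    import requests
else:
    requests = None


def post_to_sglang_server(url: str, payload: dict) -> dict:
    # Hypothetical helper: fail with a clear message if the optional extra is missing.
    if requests is None:
        raise ImportError("SGLang support requires the optional dependencies: pip install trl[sglang]")
    response = requests.post(url, json=payload)
    response.raise_for_status()
    return response.json()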
56 changes: 56 additions & 0 deletions trl/scripts/grpo_test/grpo_sgl_test.py
Reviewer comment: In the final version, the files in scripts should be removed.

@@ -0,0 +1,56 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from datasets import load_dataset

from trl import GRPOConfig, GRPOTrainer



dataset = load_dataset("trl-lib/tldr", split="train[:1%]")

checkpoint_dir = os.path.join("/sgl-workspace/ryang/trl", "checkpoints/sgl")
os.makedirs(checkpoint_dir, exist_ok=True)


# Define the reward function, which rewards completions that are close to 20 characters
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]


training_args = GRPOConfig(
    output_dir=os.path.join(checkpoint_dir, "Qwen2.5_output"),
    logging_steps=10,
    # report_to="wandb",
    # use_vllm=True,
    use_sglang=True,
    sglang_device="cuda:1",
    sglang_gpu_memory_utilization=0.9,
    sglang_server_url="http://127.0.0.1:30000",
)


trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)

training_args.checkpoint_path = checkpoint_dir # Set the checkpoint path for later use


trainer.train()
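Note that the script above assumes an SGLang server is already listening at the configured sglang_server_url before trainer.train() runs. A minimal launch-and-wait sketch, assuming SGLang's documented sglang.launch_server entry point and its /health endpoint (flag names and endpoints may differ across SGLang versions):

import os
import subprocess
import time

import requests

SERVER_URL = "http://127.0.0.1:30000"  # must match sglang_server_url above

env = os.environ.copy()
env["CUDA_VISIBLE_DEVICES"] = "1"  # keep generation off the training GPU(s)

server = subprocess.Popen(
    ["python", "-m", "sglang.launch_server", "--model-path", "Qwen/Qwen2-0.5B-Instruct", "--port", "30000"],
    env=env,
)

# Poll until the server answers, or give up after roughly five minutes.
for _ in range(300):
    try:
        if requests.get(f"{SERVER_URL}/health", timeout=1).ok:
            break
    except requests.RequestException:
        pass
    time.sleep(1)
else:
    server.terminate()
    raise RuntimeError(f"SGLang server at {SERVER_URL} did not become ready")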
17 changes: 17 additions & 0 deletions trl/scripts/grpo_test/grpo_sgl_test.yaml
@@ -0,0 +1,17 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 3
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
main_process_port: 29600
4 changes: 4 additions & 0 deletions trl/scripts/grpo_test/run.sh
@@ -0,0 +1,4 @@
#!/bin/bash
export CUDA_VISIBLE_DEVICES=5,6,7
export PYTHONPATH="/sgl-workspace/ryang/trl:$PYTHONPATH"
accelerate launch --config_file=trl/scripts/grpo_test/grpo_sgl_test.yaml trl/scripts/grpo_test/grpo_sgl_test.py
88 changes: 86 additions & 2 deletions trl/trainer/grpo_config.py
@@ -96,12 +96,24 @@ class GRPOConfig(TrainingArguments):
            timeout, a `ConnectionError` is raised.
        vllm_guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`):
            Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.

        > Parameters that control generation acceleration powered by SGLang

        use_sglang (`bool`, *optional*, defaults to `False`):
            Whether to use SGLang for generating completions. If set to `True`, an SGLang server must be running.
        sglang_server_url (`str`, *optional*, defaults to `"http://localhost:32232"`):
            URL of the SGLang server (e.g. `"http://localhost:32232"`). Required if `use_sglang` is `True`.
        sglang_device (`str`, *optional*, defaults to `"auto"`):
            GPU device used for SGLang generation when the server is launched by the trainer. Ignored if the
            server is managed externally.
        sglang_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`):
            Ratio of GPU memory reserved for SGLang generation.

        > Parameters that control the training

        learning_rate (`float`, *optional*, defaults to `1e-6`):
            Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
            [`~transformers.TrainingArguments`].
        beta (`float`, *optional*, defaults to `0.04`):
            KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving training
            speed, but may be numerically unstable for long training runs.
@@ -187,13 +199,20 @@ class GRPOConfig(TrainingArguments):
            "it prevents the model from generating different logprobs for the same input."
        },
    )
    disable_dropout: bool = field(
        default=False,
        metadata={
            "help": "Whether to disable dropout in the model. This is useful for training with a reference model, as "
            "it prevents the model from generating different logprobs for the same input."
        },
    )

    # Parameters that control the data preprocessing
    # The default value remove_unused_columns is overwritten from the parent class, because in GRPO we usually rely on
    # additional columns to compute the reward
    remove_unused_columns: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to only keep the column 'prompt' in the dataset. If you use a custom reward function "
            "that requires any column other than 'prompts' and 'completions', you should keep this to `False`."
        },
@@ -211,6 +230,10 @@ class GRPOConfig(TrainingArguments):
            "* gradient_accumulation_steps) must be evenly divisible by this value."
        },
    )
    temperature: Optional[float] = field(
        default=0.9,
        metadata={"help": "Temperature for sampling completions."},
    )
    max_completion_length: Optional[int] = field(
        default=256,
        metadata={"help": "Maximum length of the generated completion."},
@@ -296,6 +319,34 @@ class GRPOConfig(TrainingArguments):
        metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
    )

    # Parameters that control generation acceleration powered by SGLang
    # When running the trainer, set the following command-line arguments (or JSON configuration) so that SGLang is used:
    #   --use_sglang True
    #   --sglang_server_url "http://localhost:32232"
    #   Optionally, --sglang_device "cuda:1" to assign a specific GPU.
    use_sglang: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to use SGLang for generating completions. If `True`, an SGLang server must be running."
        },
    )
    sglang_server_url: Optional[str] = field(
        default="http://localhost:32232",
        metadata={
            "help": "URL of the SGLang server (e.g., 'http://localhost:32232'). Required if `use_sglang` is `True`."
        },
    )
    sglang_device: Optional[str] = field(
        default="auto",
        metadata={
            "help": "GPU device used for SGLang generation when the server is launched by the trainer. Ignored if "
            "the server is managed externally."
        },
    )
    sglang_gpu_memory_utilization: float = field(
        default=0.9,
        metadata={"help": "Ratio of GPU memory reserved for SGLang generation."},
    )

    # Parameters that control the training
    learning_rate: float = field(
        default=1e-6,
@@ -366,6 +417,39 @@ class GRPOConfig(TrainingArguments):
            "a good practice for training stability."
        },
    )
    scale_rewards: bool = field(
        default=True,
        metadata={
            "help": "Whether to scale the rewards by dividing them by their standard deviation. If `True` (default), "
            "the rewards are normalized by the standard deviation, ensuring they have unit variance. If `False`, no "
            "scaling is applied. The Dr. GRPO paper recommends not scaling the rewards, as scaling by the standard "
            "deviation introduces a question-level difficulty bias."
        },
    )
    loss_type: str = field(
        default="bnpo",
        metadata={
            "help": "Specifies the loss formulation to use. Supported values are `grpo`, `bnpo`, and `dr_grpo`. "
            "`'grpo'`: Aggregates token-level losses by normalizing over sequence length. Not recommended due to "
            "length bias—this approach tends to prefer shorter completions with positive advantages and longer ones "
            "with negative advantages. "
            "`'bnpo'`: Aggregates token-level losses by normalizing by the number of active tokens in the local "
            "batch. Note that normalization is performed over the local batch only, so results may slightly vary "
            "depending on the local batch size, despite a constant effective batch size. When using "
            "`per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss. "
            "`'dr_grpo'`: Aggregates token-level losses by normalizing with a global constant. This method was "
            "introduced in the Dr. GRPO paper to eliminate length bias. The value of the constant corresponds to "
            "`max_completion_length`."
        },
    )
    mask_truncated_completions: bool = field(
        default=False,
        metadata={
            "help": "When enabled, truncated completions are excluded from the loss calculation, preventing them from "
            "being incorrectly penalized and introducing noise during training. According to the DAPO paper, this is "
            "a good practice for training stability."
        },
    )
    sync_ref_model: bool = field(
        default=False,
        metadata={
@@ -411,4 +495,4 @@ class GRPOConfig(TrainingArguments):
            "help": "Whether to log unique prompts in wandb. If `True`, only unique prompts are logged. If `False`, "
            "all prompts are logged."
        },
    )
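To make the use_sglang path concrete, the sketch below shows how completions can be requested from a running SGLang server over its native HTTP API, in the spirit of the "Turn to the online server API Usage" and "call raise_for_status" commits. It follows SGLang's documented /generate endpoint, but it is an illustration under those assumptions, not the trainer's actual implementation:

import requests

SGLANG_SERVER_URL = "http://localhost:32232"  # GRPOConfig.sglang_server_url


def generate_completions(prompts, max_completion_length=256, temperature=0.9):
    # Batch all prompts into a single request; SGLang returns one result per prompt.
    payload = {
        "text": prompts,
        "sampling_params": {
            "temperature": temperature,
            "max_new_tokens": max_completion_length,
        },
    }
    response = requests.post(f"{SGLANG_SERVER_URL}/generate", json=payload, timeout=120)
    response.raise_for_status()  # surface HTTP errors early
    return [item["text"] for item in response.json()]

# Example: completions = generate_completions(["The capital of France is"])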