
Commit 8e9bc73

zhaochenyang20, zhaochen20, PopSoda2002, ChangyiYang, and PrinsYin committed
[rollout] feat: support over sampling rollout in SGLang Rollout (volcengine#2929)
This PR introduces an **over-sample strategy** for verl's SGLang multi-turn rollout to address the long-tail problem, where a few slow requests disproportionately increase the overall rollout time. The core idea is to over-sample the number of requests at the start of the rollout and then aggressively cancel any requests that haven't finished once a target number of completions is met.

- **Improves rollout efficiency** for multi-turn conversations by reducing the total time spent waiting for slow requests.
- **Implements a new request monitoring and cancellation mechanism** to cut off unnecessary computation.

The wandb results are as follows: https://wandb.ai/zhaochenyang20/benchmark_over_sample_2/workspace?nw=nwuserzhaochenyang20

-----

Of course, this strategy has its share of issues. For example, many might question why the over-long requests that are dropped aren't simply saved and continued in the next round. This is certainly possible (it amounts to a partial rollout strategy), but it would require verl to have a data buffer, which is beyond the scope of this PR. Furthermore, saving and continuing these requests would introduce an off-policy problem.

There is also a valid concern that this rather "brutal" dropping strategy could unfairly ignore very long requests. I agree this is a very reasonable point, but we currently don't have a lossless solution. However, our dropping strategy is very flexible and could even change with our curriculum learning. For instance, in the example I gave, I just directly dropped the last 20% of requests. **In practice, we can dynamically adjust this drop rate and even set different dropping methods. For example, we could record the return time (t) for the first 80% of requests and then drop any requests that haven't returned after 1.5t.**

We've provided an initial, validated idea and have completed its implementation. We welcome everyone to join the discussion on how to accelerate multi-turn rollouts with acceptable losses.

The new over-sample strategy was tested with an 8-GPU setup on the **gsm8k** dataset, yielding the following results:

- **Rollout Time:** Significant reduction in overall rollout time per step.
- **Training Rewards:**
  - The reward metric for training steps shows a positive bias. This is because we exclude the aborted requests (which are typically more difficult and have lower rewards) from the reward calculation.
  - The reward metric for validation steps remains accurate and aligns with the baseline. This is because the cancellation logic is not triggered during validation, ensuring a fair and complete evaluation.

This feature modifies `sglang_rollout.py` and `metric_utils.py`. To use it, follow the standard setup and then run the training script with the over-sample parameters.

https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/verl/multi-turn/release_log/over_sample.md

The design is centered on three main functions that orchestrate the over-sampling logic: `run_with_cancellation`, `process_request_with_monitoring`, and `monitor_and_cancel`. These functions rely on global variables, such as `all_tasks` and `completion_lock`, to manage state. A minimal sketch of this pattern follows the function list below.

- **`run_with_cancellation`:** This is the entry point. It launches all requests as `process_request_with_monitoring` tasks concurrently with a single `monitor_and_cancel` task. It uses `asyncio.gather` to wait for all tasks to complete (or be canceled) and converts any exceptions from canceled tasks into padding requests before returning the final output.
- **`process_request_with_monitoring`:** This async function handles a single request. It waits for the request to complete using `_async_rollout_a_request` and then checks a shared counter, `completed_count`, using a `completion_lock` for thread safety. If the target completion count has not been reached, it returns the real result. If the target has been met, it returns padding data instead, effectively "discarding" the late result.
- **`monitor_and_cancel`:** This is a separate async task that polls the `completed_count`. Once the count reaches the `target_completion` threshold, it immediately cancels all remaining tasks and sends an `abort_requests` signal to the SGLang engine, halting any ongoing GPU computation for those requests.
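As a rough illustration of the pattern described above (not the actual verl implementation), the sketch below uses a hypothetical `over_sample_rollout` entry point and a stand-in `rollout_one_request` in place of `_async_rollout_a_request`; the SGLang `abort_requests` call is only indicated in a comment.

```python
import asyncio
import random

# Hypothetical stand-ins: the real code uses _async_rollout_a_request and
# AsyncEngine.abort_request; names and signatures here are illustrative only.
async def rollout_one_request(req_id: int) -> str:
    await asyncio.sleep(random.uniform(0.1, 2.0))  # simulated multi-turn generation latency
    return f"result-{req_id}"

async def over_sample_rollout(num_requests: int, over_sample_rate: float) -> list[str]:
    # Abort the stragglers once (1 - over_sample_rate) * num_requests have completed.
    target_completion = int(num_requests * (1 - over_sample_rate))
    completed_count = 0
    completion_lock = asyncio.Lock()

    async def process_request_with_monitoring(req_id: int) -> str:
        nonlocal completed_count
        result = await rollout_one_request(req_id)
        async with completion_lock:
            completed_count += 1
            if completed_count <= target_completion:
                return result
        return "PADDING"  # late finisher: discard the real result, return padding

    async def monitor_and_cancel(tasks: list[asyncio.Task]) -> None:
        while True:
            await asyncio.sleep(0.05)
            async with completion_lock:
                if completed_count >= target_completion:
                    break
        for task in tasks:
            if not task.done():
                task.cancel()  # the real code also sends abort_requests to the SGLang engine here

    all_tasks = [asyncio.create_task(process_request_with_monitoring(i)) for i in range(num_requests)]
    monitor = asyncio.create_task(monitor_and_cancel(all_tasks))
    results = await asyncio.gather(*all_tasks, return_exceptions=True)
    monitor.cancel()
    # Cancelled tasks surface as exceptions; convert them into padding entries.
    return ["PADDING" if isinstance(r, BaseException) else r for r in results]

if __name__ == "__main__":
    print(asyncio.run(over_sample_rollout(num_requests=10, over_sample_rate=0.2)))
```

With `over_sample_rate=0.2`, roughly the slowest 20% of requests are cancelled once 80% have returned, and their slots are filled with padding so the batch keeps a fixed shape.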
Key code changes:

- **`sglang_rollout.py`**:
  - Adds the three core asynchronous functions for the over-sample strategy.
  - The `AsyncEngine` class now includes a new `abort_request` method that calls the synchronous `abort_request` in the `tokenizer_manager`.
- **`metric_utils.py`**:
  - The `compute_data_metrics` function is updated to exclude the aborted requests (identified by padding) from the denominator when calculating average rewards during training. This prevents the training reward from being artificially lowered by the zero-reward aborted requests.

This implementation is designed to be a straightforward and effective solution for the long-tail problem, though some aspects of the asynchronous design and the impact on training variance require further investigation.

---------

Co-authored-by: zhaochenyang <[email protected]>
Co-authored-by: PopSoda2002 <[email protected]>
Co-authored-by: ChangyiYang <[email protected]>
Co-authored-by: PrinsYin <[email protected]>
Co-authored-by: WindowsXp-Beta <[email protected]>
Co-authored-by: zhaochenyang20 <[email protected]>
1 parent 8fcfdf7 commit 8e9bc73

File tree

7 files changed: +242 -48 lines changed


examples/sglang_multiturn/run_qwen2_3b_dapo_multiturn.sh renamed to examples/sglang_multiturn/run_qwen3_4b_dapo_multiturn.sh

Lines changed: 14 additions & 7 deletions
@@ -17,6 +17,10 @@ hf download \
     --repo-type dataset \
     --local-dir $HOME/data/Maxwell-Jia/AIME_2024
 
+# Note that this script is using AgentLoop instead of SGLang Multi-Turn
+# We are concerned that the reward is not actually converge, since the
+# reward of retool is encouraging the model to generate more turns to
+# call more tools. The answers are not actually correct.
 
 python3 -m verl.trainer.main_ppo \
     algorithm.adv_estimator=grpo \
@@ -34,7 +38,7 @@ python3 -m verl.trainer.main_ppo \
     data.custom_cls.name=CustomRLHFDataset \
     custom_reward_function.path=$PROJECT_DIR/recipe/retool/retool.py \
     custom_reward_function.name=compute_score \
-    actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \
+    actor_rollout_ref.model.path=Qwen/Qwen3-4B-Instruct-2507 \
     actor_rollout_ref.model.use_remove_padding=True \
     actor_rollout_ref.model.enable_gradient_checkpointing=True \
     actor_rollout_ref.actor.use_kl_loss=False \
@@ -43,13 +47,16 @@ python3 -m verl.trainer.main_ppo \
     actor_rollout_ref.actor.clip_ratio_high=0.28 \
     actor_rollout_ref.actor.clip_ratio_c=10.0 \
     actor_rollout_ref.actor.optim.lr=1e-6 \
-    actor_rollout_ref.actor.use_dynamic_bsz=True \
-    actor_rollout_ref.actor.ppo_mini_batch_size=8 \
-    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=1024 \
+    actor_rollout_ref.actor.ppo_mini_batch_size=32 \
+    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \
+    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=32768 \
+    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \
+    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
     actor_rollout_ref.rollout.name=sglang \
     actor_rollout_ref.rollout.mode=async \
     actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
-    actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \
+    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
+    actor_rollout_ref.rollout.update_weights_bucket_megabytes=512 \
     actor_rollout_ref.rollout.multi_stage_wake_up=True \
     actor_rollout_ref.rollout.multi_turn.enable=True \
     actor_rollout_ref.rollout.multi_turn.max_user_turns=16 \
@@ -62,8 +69,8 @@ python3 -m verl.trainer.main_ppo \
     actor_rollout_ref.rollout.val_kwargs.n=30 \
     trainer.logger=['console','wandb'] \
     trainer.project_name=sglang-dapo-multiturn \
-    trainer.experiment_name=qwen2_5-3b_dapo_multiturn \
-    trainer.n_gpus_per_node=4 \
+    trainer.experiment_name=qwen3-4b_dapo_multiturn \
+    trainer.n_gpus_per_node=8 \
     trainer.log_val_generations=20 \
     trainer.val_before_train=True \
     trainer.nnodes=1 \

verl/trainer/config/_generated_ppo_megatron_trainer.yaml

Lines changed: 1 addition & 0 deletions
@@ -163,6 +163,7 @@ actor_rollout_ref:
     disable_log_stats: true
     do_sample: true
     'n': 1
+    over_sample_rate: 0
     multi_stage_wake_up: false
     engine_kwargs:
       vllm:

verl/trainer/config/_generated_ppo_trainer.yaml

Lines changed: 1 addition & 0 deletions
@@ -138,6 +138,7 @@ actor_rollout_ref:
     disable_log_stats: true
     do_sample: true
     'n': 1
+    over_sample_rate: 0
     multi_stage_wake_up: false
     engine_kwargs:
       vllm:

verl/trainer/config/rollout/rollout.yaml

Lines changed: 7 additions & 1 deletion
@@ -82,7 +82,13 @@ do_sample: True
 # number of responses (i.e. num sample times). > 1 for grpo
 n: 1
 
-# Whether to wake up inference engine in multi-stage to reduce peak memory during training-rollout transition.
+# The over_sample_rate parameter controls the early termination threshold for training rollouts,
+# where the system will abort remaining requests when (1 - over_sample_rate) * total_requests completions are reached.
+over_sample_rate: 0
+
+# Whether to wake up inference engine in multi-stage for SGLang
+# to reduce peak memory during training-rollout transition.
+# This is only effective for SGLang rollout.
 multi_stage_wake_up: false
 
 # Extra inference engine arguments (vllm, sglang).
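The comment above defines the completion target as (1 - over_sample_rate) * total_requests. A small, hypothetical arithmetic example (the request count is invented; only `over_sample_rate` comes from this PR):

```python
# Hypothetical numbers illustrating the over_sample_rate threshold.
total_requests = 256        # requests launched for one training rollout step
over_sample_rate = 0.2      # drop the slowest ~20% of requests
target_completion = int(total_requests * (1 - over_sample_rate))
print(target_completion)    # 204 -> remaining requests are aborted once this many complete
```

With the default `over_sample_rate: 0`, the threshold equals the full request count, so no request is aborted early.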

verl/trainer/ppo/metric_utils.py

Lines changed: 44 additions & 6 deletions
@@ -118,6 +118,20 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str,
     prompt_length = response_info["prompt_length"]
     response_length = response_info["response_length"]
 
+    aborted_mask = (response_length == 0).bool()
+    non_aborted_mask = ~aborted_mask
+
+    non_aborted_sequence_score = sequence_score[non_aborted_mask]
+    non_aborted_sequence_reward = sequence_reward[non_aborted_mask]
+
+    score_mean = torch.mean(non_aborted_sequence_score).detach().item()
+    score_max = torch.max(non_aborted_sequence_score).detach().item()
+    score_min = torch.min(non_aborted_sequence_score).detach().item()
+
+    reward_mean = torch.mean(non_aborted_sequence_reward).detach().item()
+    reward_max = torch.max(non_aborted_sequence_reward).detach().item()
+    reward_min = torch.min(non_aborted_sequence_reward).detach().item()
+
     valid_adv = torch.masked_select(advantages, response_mask)
     valid_returns = torch.masked_select(returns, response_mask)
 
@@ -127,15 +141,30 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str,
     return_diff_var = torch.var(valid_returns - valid_values)
     return_var = torch.var(valid_returns)
 
+    # Aborted samples and non-aborted response length statistics
+    # response_length_non_aborted/*: statistics computed on non-aborted samples only
+    aborted_ratio = torch.mean(aborted_mask.float()).detach().item()
+
+    non_aborted_response_length = response_length[non_aborted_mask]
+    if non_aborted_response_length.numel() > 0:
+        non_aborted_response_length_mean = torch.mean(non_aborted_response_length).detach().item()
+        non_aborted_response_length_max = torch.max(non_aborted_response_length).detach().item()
+        non_aborted_response_length_min = torch.min(non_aborted_response_length).detach().item()
+        non_aborted_response_length_clip_ratio = (
+            torch.mean(torch.eq(non_aborted_response_length, max_response_length).float()).detach().item()
+        )
+    else:
+        raise ValueError("All samples are aborted, this should not happen.")
+
     metrics = {
         # score
-        "critic/score/mean": torch.mean(sequence_score).detach().item(),
-        "critic/score/max": torch.max(sequence_score).detach().item(),
-        "critic/score/min": torch.min(sequence_score).detach().item(),
+        "critic/score/mean": score_mean,
+        "critic/score/max": score_max,
+        "critic/score/min": score_min,
         # reward
-        "critic/rewards/mean": torch.mean(sequence_reward).detach().item(),
-        "critic/rewards/max": torch.max(sequence_reward).detach().item(),
-        "critic/rewards/min": torch.min(sequence_reward).detach().item(),
+        "critic/rewards/mean": reward_mean,
+        "critic/rewards/max": reward_max,
+        "critic/rewards/min": reward_min,
         # adv
         "critic/advantages/mean": torch.mean(valid_adv).detach().item(),
         "critic/advantages/max": torch.max(valid_adv).detach().item(),
@@ -163,6 +192,15 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str,
         "response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float())
         .detach()
         .item(),
+        # response length (non-aborted only)
+        # These statistics exclude aborted samples to avoid skew from zeros
+        "response_length_non_aborted/mean": non_aborted_response_length_mean,
+        "response_length_non_aborted/max": non_aborted_response_length_max,
+        "response_length_non_aborted/min": non_aborted_response_length_min,
+        "response_length_non_aborted/clip_ratio": non_aborted_response_length_clip_ratio,
+        # aborted ratio
+        # Fraction of samples whose response length is zero
+        "response/aborted_ratio": aborted_ratio,
         # prompt length
         "prompt_length/mean": torch.mean(prompt_length).detach().item(),
         "prompt_length/max": torch.max(prompt_length).detach().item(),

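A small, hypothetical example of why this masking matters for the reported training reward (the numbers are invented; the real code identifies aborted samples by `response_length == 0`):

```python
import torch

# Two of five samples were aborted (padding, response_length == 0) and carry zero reward.
sequence_reward = torch.tensor([0.9, 0.7, 0.8, 0.0, 0.0])
response_length = torch.tensor([310, 295, 402, 0, 0])

aborted_mask = (response_length == 0)
non_aborted_reward = sequence_reward[~aborted_mask]

print(sequence_reward.mean().item())     # ~0.48 -> biased low by the zero-reward aborted samples
print(non_aborted_reward.mean().item())  # ~0.80 -> what the updated compute_data_metrics reports
```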
verl/workers/config/rollout.py

Lines changed: 4 additions & 0 deletions
@@ -86,6 +86,10 @@ class RolloutConfig(BaseConfig):
     do_sample: bool = True
     n: int = 1
 
+    # Early termination threshold for multi-turn rollout in sglang.
+    # Abort remaining requests when (1 - over_sample_rate) * total_requests are completed.
+    over_sample_rate: float = 0.0
+
     prompt_length: int = 512
     response_length: int = 512
 
