Merged
Changes from 59 commits
73 commits
953eae6
change the default bash of qwen 2.5
zhaochen20 Aug 4, 2025
0384aef
add engine support
zhaochen20 Aug 5, 2025
a0ab8b0
OVER_SAMPLE_RATE
zhaochen20 Aug 5, 2025
652f0f7
update sgl rollout
zhaochen20 Aug 5, 2025
3fdfd00
is val
zhaochen20 Aug 5, 2025
d37f51d
self.abort
zhaochen20 Aug 5, 2025
51410e7
cancel task
zhaochen20 Aug 5, 2025
de8feb3
cancel then abort
zhaochen20 Aug 5, 2025
da37e6a
fix await
zhaochen20 Aug 5, 2025
5308507
finish over sample
zhaochen20 Aug 5, 2025
25a15b9
Merge branch 'main' of https://github.com/zhaochenyang20/verl
zhaochen20 Aug 5, 2025
07ad3f9
Merge branch 'main' into over_sample_sgl
zhaochen20 Aug 5, 2025
1b8bfa9
add log to over sample
zhaochen20 Aug 5, 2025
519a1b0
Add benchmark script
zhaochen20 Aug 5, 2025
ae43c5a
increase to 45
zhaochen20 Aug 5, 2025
45265ec
increase
zhaochen20 Aug 5, 2025
ef6ab93
add over sample
zhaochen20 Aug 5, 2025
264bed6
Merge branch 'main' of https://github.com/zhaochenyang20/verl
zhaochen20 Aug 5, 2025
0c63925
Merge branch 'main' into over_sample_sgl
zhaochen20 Aug 5, 2025
5d3b970
fix reward and loss cal
zhaochen20 Aug 5, 2025
89da6d7
revert loss agg
zhaochen20 Aug 5, 2025
614f0ad
revert reward padding
zhaochen20 Aug 5, 2025
08b7e02
revert reward padding
zhaochen20 Aug 5, 2025
36abd78
finish
zhaochen20 Aug 5, 2025
b979a73
revert non_aborted_mask
zhaochen20 Aug 5, 2025
b4fdfcf
clean up codes
zhaochen20 Aug 5, 2025
88a23ce
Merge branch 'over_sample' of https://github.com/zhaochenyang20/verl …
zhaochen20 Aug 5, 2025
4e2316b
clean up codes
zhaochen20 Aug 5, 2025
3a37c6e
modify over sample rate
zhaochen20 Aug 5, 2025
94d7c68
update examples
zhaochen20 Aug 5, 2025
5fd6b67
Merge branch 'main' of https://github.com/zhaochenyang20/verl
zhaochen20 Aug 6, 2025
1b6eb79
Merge branch 'main' into over_sample
zhaochen20 Aug 6, 2025
271b493
[tmp file] benchmark
zhaochen20 Aug 6, 2025
6e98696
upd
zhaochen20 Aug 6, 2025
a32cfb4
upd
zhaochen20 Aug 6, 2025
ea29d31
fix
zhaochen20 Aug 6, 2025
cc8fb16
launch
zhaochen20 Aug 6, 2025
60c1610
Merge branch 'main' of https://github.com/zhaochenyang20/verl
zhaochen20 Aug 6, 2025
96a8ada
Merge branch 'main' of https://github.com/zhaochenyang20/verl
zhaochen20 Aug 7, 2025
903c2d3
Merge branch 'main' into over_sample
zhaochen20 Aug 7, 2025
53099d1
discard dynamic batch
zhaochen20 Aug 7, 2025
3d7fa4a
add micro_32
zhaochen20 Aug 7, 2025
5a733f8
log bs
zhaochen20 Aug 7, 2025
4cc54b8
Add log_prob_micro_batch_size_per_gpu
zhaochen20 Aug 7, 2025
c67ebff
Add ppo_max_token_len_per_gpu
zhaochen20 Aug 7, 2025
43a9c31
add qwen3 4b
zhaochen20 Aug 7, 2025
8e17a93
use qwen3 4b
zhaochen20 Aug 7, 2025
71687d7
descrease micro
zhaochen20 Aug 7, 2025
b084dfd
update dapo for benchmarking
zhaochen20 Aug 7, 2025
b66327e
fix hydra
zhaochen20 Aug 7, 2025
23c2957
use 8 gpu
zhaochen20 Aug 7, 2025
b944230
Merge branch 'main' of https://github.com/zhaochenyang20/verl
zhaochen20 Aug 7, 2025
7ee096e
Merge branch 'main' into over_sample
zhaochen20 Aug 7, 2025
97d137d
delete unused variable
PopSoda2002 Aug 7, 2025
35b65eb
use full set to eval
zhaochen20 Aug 7, 2025
6f3e49e
Merge branch 'over_sample' of https://github.com/zhaochenyang20/verl …
zhaochen20 Aug 7, 2025
b0f0245
Merge branch 'main' of https://github.com/zhaochenyang20/verl
zhaochen20 Aug 11, 2025
2f5d4f5
Merge branch 'main' into over_sample
zhaochen20 Aug 11, 2025
4c882d1
finish over sampling
zhaochen20 Aug 11, 2025
4822564
Merge branch 'main' of https://github.com/zhaochenyang20/verl
zhaochen20 Aug 13, 2025
b28f178
Merge branch 'main' into over_sample
zhaochen20 Aug 13, 2025
a794b29
fix rollout config
zhaochen20 Aug 14, 2025
cdc2eff
do not capture all
zhaochen20 Aug 14, 2025
77e8491
clean up validation
zhaochen20 Aug 14, 2025
4fd6b36
fix conflicts register
zhaochen20 Aug 14, 2025
58f3570
fix future
zhaochen20 Aug 14, 2025
23452ed
use event loop
zhaochen20 Aug 14, 2025
3ce9b1f
get event loop
zhaochen20 Aug 14, 2025
d6128a2
upd, fix up
zhaochen20 Aug 14, 2025
ba87964
fix up
zhaochen20 Aug 14, 2025
96c5d75
delete comment
zhaochen20 Aug 14, 2025
c1e8884
Merge branch 'main' of https://github.com/zhaochenyang20/verl
zhaochen20 Aug 14, 2025
06beaf0
Merge branch 'main' into over_sample
zhaochen20 Aug 14, 2025
17 changes: 13 additions & 4 deletions examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh
@@ -8,6 +8,12 @@ ulimit -n 65535
PROJECT_DIR="$(pwd)"
CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config"

function now() {
date '+%d-%H-%M'
}

EXPERIMENT_NAME="qwen2.5-3b_baseline_$(now)"

python3 -m verl.trainer.main_ppo \
--config-path="$CONFIG_PATH" \
--config-name='gsm8k_multiturn_grpo' \
@@ -31,21 +37,24 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.name=sglang \
actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \
actor_rollout_ref.rollout.multi_stage_wake_up=True \
actor_rollout_ref.rollout.n=16 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
actor_rollout_ref.rollout.over_sample_rate=0 \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.logger='["console","wandb"]' \
trainer.project_name='gsm8k_async_rl' \
trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-sgl-multi-w-tool-verify-n16' \
trainer.project_name='multi-turn-grpo-qwen2.5-3b-sglang' \
trainer.experiment_name=$EXPERIMENT_NAME \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.test_freq=20 \
trainer.val_before_train=True \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" \
@@ -34,7 +34,7 @@ python3 -m verl.trainer.main_ppo \
data.custom_cls.name=CustomRLHFDataset \
custom_reward_function.path=$PROJECT_DIR/recipe/retool/retool.py \
custom_reward_function.name=compute_score \
actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \
actor_rollout_ref.model.path=Qwen/Qwen3-4B-Instruct-2507 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.use_kl_loss=False \
@@ -43,13 +43,16 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.actor.clip_ratio_high=0.28 \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.use_dynamic_bsz=True \
actor_rollout_ref.actor.ppo_mini_batch_size=8 \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=1024 \
actor_rollout_ref.actor.ppo_mini_batch_size=32 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=32768 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
actor_rollout_ref.rollout.name=sglang \
actor_rollout_ref.rollout.mode=async \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
actor_rollout_ref.rollout.update_weights_bucket_megabytes=512 \
actor_rollout_ref.rollout.multi_stage_wake_up=True \
actor_rollout_ref.rollout.multi_turn.enable=True \
actor_rollout_ref.rollout.multi_turn.max_user_turns=16 \
@@ -62,8 +65,8 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.rollout.val_kwargs.n=30 \
trainer.logger=['console','wandb'] \
trainer.project_name=sglang-dapo-multiturn \
trainer.experiment_name=qwen2_5-3b_dapo_multiturn \
trainer.n_gpus_per_node=4 \
trainer.experiment_name=qwen3-4b_dapo_multiturn \
trainer.n_gpus_per_node=8 \
trainer.log_val_generations=20 \
trainer.val_before_train=True \
trainer.nnodes=1 \
1 change: 1 addition & 0 deletions verl/trainer/config/_generated_ppo_megatron_trainer.yaml
@@ -161,6 +161,7 @@ actor_rollout_ref:
disable_log_stats: true
do_sample: true
'n': 1
over_sample_rate: 0
multi_stage_wake_up: false
engine_kwargs:
vllm:
1 change: 1 addition & 0 deletions verl/trainer/config/_generated_ppo_trainer.yaml
@@ -136,6 +136,7 @@ actor_rollout_ref:
disable_log_stats: true
do_sample: true
'n': 1
over_sample_rate: 0
multi_stage_wake_up: false
engine_kwargs:
vllm:
8 changes: 7 additions & 1 deletion verl/trainer/config/rollout/rollout.yaml
@@ -79,7 +79,13 @@ do_sample: True
# number of responses (i.e. num sample times). > 1 for grpo
n: 1

# Whether to wake up inference engine in multi-stage to reduce peak memory during training-rollout transition.
# over_sample_rate controls the early-termination threshold for training rollouts:
# the system aborts the remaining in-flight requests once
# (1 - over_sample_rate) * total_requests completions have been received.
over_sample_rate: 0

# Whether to wake up the inference engine in multiple stages to reduce peak memory
# during the training-rollout transition. Only effective for the SGLang rollout.
multi_stage_wake_up: false

# Extra inference engine arguments (vllm, sglang).
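The over-sampling semantics added to `rollout.yaml` can be sketched as a small asyncio loop. This is a simplified illustration, not verl's implementation: `generate_one` is a hypothetical coroutine standing in for the rollout engine's per-request generation, and the real abort path in this PR goes through the engine's abort/cancel machinery.

```python
import asyncio

async def generate_with_over_sampling(requests, generate_one, over_sample_rate=0.0):
    # Abort the stragglers once (1 - over_sample_rate) * total_requests
    # completions have been received.
    total = len(requests)
    threshold = int((1 - over_sample_rate) * total)
    tasks = [asyncio.create_task(generate_one(r)) for r in requests]
    done = []
    for fut in asyncio.as_completed(tasks):
        done.append(await fut)
        if len(done) >= threshold:
            break
    # Cancel whatever is still in flight; aborted samples later surface
    # with response_length == 0 and are masked out of the metrics.
    for t in tasks:
        if not t.done():
            t.cancel()
    return done
```

With `over_sample_rate: 0` (the default above), the threshold equals the full batch size, so no request is ever aborted early.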
50 changes: 44 additions & 6 deletions verl/trainer/ppo/metric_utils.py
@@ -118,6 +118,20 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str,
prompt_length = response_info["prompt_length"]
response_length = response_info["response_length"]

aborted_mask = (response_length == 0).bool()
Review comment (Collaborator, PR author): this may not be correct
non_aborted_mask = ~aborted_mask

non_aborted_sequence_score = sequence_score[non_aborted_mask]
non_aborted_sequence_reward = sequence_reward[non_aborted_mask]

score_mean = torch.mean(non_aborted_sequence_score).detach().item()
score_max = torch.max(non_aborted_sequence_score).detach().item()
score_min = torch.min(non_aborted_sequence_score).detach().item()

reward_mean = torch.mean(non_aborted_sequence_reward).detach().item()
reward_max = torch.max(non_aborted_sequence_reward).detach().item()
reward_min = torch.min(non_aborted_sequence_reward).detach().item()

valid_adv = torch.masked_select(advantages, response_mask)
valid_returns = torch.masked_select(returns, response_mask)

@@ -127,15 +141,30 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str,
return_diff_var = torch.var(valid_returns - valid_values)
return_var = torch.var(valid_returns)

# Aborted samples and non-aborted response length statistics
# response_length_non_aborted/*: statistics computed on non-aborted samples only
aborted_ratio = torch.mean(aborted_mask.float()).detach().item()

non_aborted_response_length = response_length[non_aborted_mask]
if non_aborted_response_length.numel() > 0:
non_aborted_response_length_mean = torch.mean(non_aborted_response_length).detach().item()
non_aborted_response_length_max = torch.max(non_aborted_response_length).detach().item()
non_aborted_response_length_min = torch.min(non_aborted_response_length).detach().item()
non_aborted_response_length_clip_ratio = (
torch.mean(torch.eq(non_aborted_response_length, max_response_length).float()).detach().item()
)
else:
raise ValueError("All samples are aborted, this should not happen.")

metrics = {
# score
"critic/score/mean": torch.mean(sequence_score).detach().item(),
"critic/score/max": torch.max(sequence_score).detach().item(),
"critic/score/min": torch.min(sequence_score).detach().item(),
"critic/score/mean": score_mean,
"critic/score/max": score_max,
"critic/score/min": score_min,
# reward
"critic/rewards/mean": torch.mean(sequence_reward).detach().item(),
"critic/rewards/max": torch.max(sequence_reward).detach().item(),
"critic/rewards/min": torch.min(sequence_reward).detach().item(),
"critic/rewards/mean": reward_mean,
"critic/rewards/max": reward_max,
"critic/rewards/min": reward_min,
Review comment (Contributor): Does it make sense to apply the same non_aborted_mask logic to the response_length metrics as well (mean/max/min response length)? Otherwise, the metrics might still include padded responses from aborted requests.
Review comment (Collaborator, PR author): Makes sense.
Review comment (Collaborator, PR author): I changed it accordingly, adding response length statistics before and after abort, and added the drop-rate metric.

# adv
"critic/advantages/mean": torch.mean(valid_adv).detach().item(),
"critic/advantages/max": torch.max(valid_adv).detach().item(),
@@ -163,6 +192,15 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str,
"response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float())
.detach()
.item(),
# response length (non-aborted only)
# These statistics exclude aborted samples to avoid skew from zeros
"response_length_non_aborted/mean": non_aborted_response_length_mean,
"response_length_non_aborted/max": non_aborted_response_length_max,
"response_length_non_aborted/min": non_aborted_response_length_min,
"response_length_non_aborted/clip_ratio": non_aborted_response_length_clip_ratio,
# aborted ratio
# Fraction of samples whose response length is zero
"response/aborted_ratio": aborted_ratio,
# prompt length
"prompt_length/mean": torch.mean(prompt_length).detach().item(),
"prompt_length/max": torch.max(prompt_length).detach().item(),
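The aborted-sample metric logic in this diff can be illustrated with a plain-Python sketch. The real code operates on torch tensors inside `compute_data_metrics`; the function name here is hypothetical, but it mirrors the same rule — a sample counts as aborted when its response length is zero, and the non-aborted statistics exclude those samples.

```python
def aborted_response_metrics(response_lengths, max_response_length):
    # Aborted samples are detected by response_length == 0,
    # mirroring the aborted_mask in compute_data_metrics.
    aborted = [length == 0 for length in response_lengths]
    non_aborted = [length for length in response_lengths if length > 0]
    if not non_aborted:
        raise ValueError("All samples are aborted, this should not happen.")
    return {
        "response/aborted_ratio": sum(aborted) / len(response_lengths),
        "response_length_non_aborted/mean": sum(non_aborted) / len(non_aborted),
        "response_length_non_aborted/max": max(non_aborted),
        "response_length_non_aborted/min": min(non_aborted),
        # Fraction of non-aborted responses that hit the generation cap.
        "response_length_non_aborted/clip_ratio": sum(
            length == max_response_length for length in non_aborted
        ) / len(non_aborted),
    }
```

For example, a batch of lengths `[0, 4, 8, 8]` with a cap of 8 yields an aborted ratio of 0.25 and a clip ratio of 2/3 over the three surviving samples.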