
[algo] feat: add GSPO-token policy loss computation function #2775


Merged: 15 commits into volcengine:main on Jul 30, 2025

Conversation

@0x404 (Collaborator) commented Jul 28, 2025

What does this PR do?

This PR implements the GSPO-token policy loss calculation proposed in the paper https://arxiv.org/pdf/2507.18071

Test

(image: training curves comparing GRPO and GSPO)

Compared GRPO and GSPO under the same settings. GRPO uses the following script:

python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.train_files=$HOME/data/gsm8k/train.parquet \
    data.val_files=$HOME/data/gsm8k/test.parquet \
    data.train_batch_size=512 \
    data.max_prompt_length=512 \
    data.max_response_length=1024 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=128 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=40 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.policy_loss.loss_mode="vanilla" \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=40 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.rollout.n=10 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=40 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger='["console","wandb"]' \
    trainer.project_name='verl_gspo_cmp' \
    trainer.experiment_name='qwen2.5-3B-GRPO' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=20 \
    trainer.test_freq=5 \
    trainer.total_epochs=15 $@

GSPO uses the following script:

python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.train_files=$HOME/data/gsm8k/train.parquet \
    data.val_files=$HOME/data/gsm8k/test.parquet \
    data.train_batch_size=512 \
    data.max_prompt_length=512 \
    data.max_response_length=1024 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=128 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=40 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.policy_loss.loss_mode="gspo" \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=40 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.rollout.n=10 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=40 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger='["console","wandb"]' \
    trainer.project_name='verl_gspo_cmp' \
    trainer.experiment_name='qwen2.5-3B-GRPO' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=20 \
    trainer.test_freq=5 \
    trainer.total_epochs=15 $@

API and Usage Example

To use GSPO, users only need to set actor_rollout_ref.actor.policy_loss.loss_mode to gspo.

python3 -m verl.trainer.main_ppo \
  ... \
  actor_rollout_ref.actor.policy_loss.loss_mode="gspo" \
  ...
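For readers who want to see what this loss mode computes, here is a compact, self-contained sketch of the GSPO-token loss assembled from the snippets reviewed in the conversation below (illustrative only, not the exact merged code; the default clip ratios follow the 3e-4/4e-4 values suggested later in the thread):

```python
import torch

def gspo_token_policy_loss(old_log_prob, log_prob, advantages, response_mask,
                           clip_ratio_low=3e-4, clip_ratio_high=4e-4):
    """Illustrative GSPO-token loss. All tensors are [batch, response_len];
    response_mask is a 0/1 float mask over valid response tokens."""
    # Per-token log-ratio: log(pi_theta(y_t) / pi_theta_old(y_t))
    negative_approx_kl = log_prob - old_log_prob
    # Length-normalized sequence log-ratio, i.e. log s_i(theta)
    seq_lengths = response_mask.sum(dim=-1).clamp(min=1)
    negative_approx_kl_seq = (negative_approx_kl * response_mask).sum(dim=-1) / seq_lengths
    # Token-level combined ratio s_{i,t}(theta) = sg[s_i(theta)] * pi_theta / sg[pi_theta]
    # (log_prob - log_prob.detach() comes first, per the precision fix discussed below)
    log_seq_ratio = log_prob - log_prob.detach() + negative_approx_kl_seq.detach().unsqueeze(-1)
    seq_importance_ratio = torch.exp(log_seq_ratio)

    pg_losses1 = -advantages * seq_importance_ratio
    pg_losses2 = -advantages * torch.clamp(seq_importance_ratio, 1 - clip_ratio_low, 1 + clip_ratio_high)
    pg_losses = torch.maximum(pg_losses1, pg_losses2)

    # seq-mean-token-mean aggregation, as agreed in the review discussion
    seq_losses = (pg_losses * response_mask).sum(dim=-1) / seq_lengths
    return seq_losses.mean()
```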

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the Contribute Guide: https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md
- [x] Apply pre-commit checks: `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [x] Add / Update the documentation: https://github.com/volcengine/verl/tree/main/docs
- [x] Add unit or end-to-end test(s) to the CI workflow (https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in the `ci-request` channel in the `verl` Slack workspace.

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces the GSPO policy loss function. I've identified potential division-by-zero and floating-point overflows that could disrupt training. Addressing these will make the implementation more robust.

@0x404 0x404 changed the title [algo] feat: add GSPO policy loss computation function [WIP] [algo] feat: add GSPO-token policy loss computation function Jul 28, 2025
@0x404 0x404 marked this pull request as ready for review July 28, 2025 12:11
@vermouth1992 (Collaborator)

Could you create a folder under recipe named gspo, and add two recipe scripts using Qwen2.5 7B Math and Qwen3 30B A3B, following this?

@0x404 (Collaborator, Author) commented Jul 28, 2025

> Could you create a folder under recipe named gspo, and add two recipe scripts using Qwen2.5 7B Math and Qwen3 30B A3B, following this?

Of course, working on it.

@0x404 0x404 marked this pull request as draft July 28, 2025 14:09
pg_losses1 = -advantages * seq_importance_ratio
pg_losses2 = -advantages * torch.clamp(seq_importance_ratio, 1 - clip_ratio_low, 1 + clip_ratio_high)
pg_losses = torch.maximum(pg_losses1, pg_losses2)
pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
Collaborator:

Should agg_loss always be sentence-level here?

Collaborator (Author):

Yes, "seq-mean-token-sum" looks most suitable. But going by the paper's definition, maybe a "batch-mean-group-mean-token-sum" would be more accurate: sum the token losses within each sequence, average the sequence losses within each group, then take a mean/sum again at the batch level. There is a similar discussion in #2776.

@mokeevdmitrii commented Jul 28, 2025

This is GSPO-token, so it should be exactly seq-mean-token-mean.

EDIT: hope I haven't hallucinated in my nightly calcs.

(The example in #2776 is about token-mean vs seq-mean.)

Referencing the article (screenshot of the GSPO-token objective from the paper):
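For readers without the screenshot, the GSPO-token objective as reconstructed from the paper and the code comments in this PR (notation may differ slightly from the article):

$$\mathcal{J}_{\text{GSPO-token}}(\theta) = \mathbb{E}\left[\frac{1}{G}\sum_{i=1}^{G}\frac{1}{|y_i|}\sum_{t=1}^{|y_i|}\min\Big(s_{i,t}(\theta)\,\hat{A}_{i,t},\ \operatorname{clip}\big(s_{i,t}(\theta),\,1-\varepsilon,\,1+\varepsilon\big)\,\hat{A}_{i,t}\Big)\right],$$

$$s_{i,t}(\theta) = \operatorname{sg}\big[s_i(\theta)\big]\cdot\frac{\pi_\theta(y_{i,t}\mid x, y_{i,<t})}{\operatorname{sg}\big[\pi_\theta(y_{i,t}\mid x, y_{i,<t})\big]},\qquad s_i(\theta) = \left(\frac{\pi_\theta(y_i\mid x)}{\pi_{\theta_\text{old}}(y_i\mid x)}\right)^{1/|y_i|}.$$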
  1. First, we compute $s_i(\theta)$ and the token-level loss (inside the last sum; let's name it $pg_{i,t}$).

  2. Second, we sum this loss over every sequence and divide it by the sequence length $|y_i|$. This corresponds to:

seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / torch.sum(loss_mask, dim=-1)

in the current verl implementation of seq-mean-token-mean. Let's name this $pg_i$.

  3. Finally, we notice that there is a common denominator, as the loss for a single group is $\dfrac{1}{G} \sum\limits_{i=1}^G pg_i$. So "batch-mean-group-mean-token-mean" is not needed: each "group-mean" operation has the same denominator, which makes it equivalent to "batch-group-mean-token-mean", or simply "seq-mean-token-mean".
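In symbols, the common-denominator point (assuming $B$ prompts per batch, each with the same group size $G$, and writing $pg_{b,i}$ for the token-mean loss of the $i$-th response to prompt $b$):

$$\frac{1}{B}\sum_{b=1}^{B}\frac{1}{G}\sum_{i=1}^{G} pg_{b,i} \;=\; \frac{1}{BG}\sum_{b=1}^{B}\sum_{i=1}^{G} pg_{b,i},$$

i.e. the batch-mean of group-means is just the plain mean over all $BG$ sequences, which is exactly what seq-mean-token-mean computes.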

Notes:

  1. seq-mean-token-sum would be wrong: since we are summing up the $s_i(\theta)$ terms for each token in GSPO-token, we must divide by the sequence length here.

  2. The fact that token-mean is a valid loss_agg_mode value here makes me a bit sad, because the article is about sequence-level optimization.

Collaborator (Author):

@mokeevdmitrii You are right! We should use seq-mean-token-mean when using GSPO. Thanks very much!

@CLAassistant commented Jul 28, 2025

CLA assistant check: all committers have signed the CLA.

@BounharAbdelaziz (Contributor)

Just added the implementation of the sequence-level optimization (GSPO vanilla) and improved the GSPO-token implementation by removing a redundant clamp on negative_approx_kl (applied, I believe, at the wrong late step).

I also added some details about the equations we are calculating at each step.

Will share some experiment results once ready.

@mokeevdmitrii

This PR uses the default clip-range settings (0.2 as of writing this comment); however, the authors mention

> that the clipping ranges in GSPO and in previous algorithms (e.g., GRPO) typically differ in order of magnitude due to the distinct definitions of importance ratios

I think this is worth testing if possible and (at least) mentioning this fact to users of "loss_mode: gspo".

@BounharAbdelaziz (Contributor)

> This PR uses the default clip-range settings (0.2 as of writing this comment); however, the authors mention that the clipping ranges in GSPO and in previous algorithms (e.g., GRPO) typically differ in order of magnitude due to the distinct definitions of importance ratios. I think this is worth testing if possible and (at least) mentioning this fact to users of "loss_mode: gspo".

Indeed, it's recommended (https://x.com/ChujieZheng/status/1948933507696525392) to be between 3e-4 and 4e-4.
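A quick back-of-the-envelope illustration (hypothetical numbers, not from this PR) of why the sequence-level ratio calls for a much tighter clip range than the token-level one:

```python
import torch

torch.manual_seed(0)
# Suppose each of 256 response tokens has a log-ratio log(pi_theta / pi_old)
# fluctuating with std 0.1 around 0 (i.e. roughly +/-10% per token).
token_log_ratio = 0.1 * torch.randn(256)

token_ratio = token_log_ratio.exp()       # token-level ratios: spread on the order of 10-30%
seq_ratio = token_log_ratio.mean().exp()  # length-normalized sequence ratio: its deviation from 1
                                          # shrinks roughly like 1/sqrt(len), typically ~1e-2 or less here

print(token_ratio.min().item(), token_ratio.max().item(), seq_ratio.item())
```

A token-level clip range like 0.2 would essentially never trigger on such a sequence-level ratio, which is why much smaller ranges are recommended for GSPO.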

Comment on lines 941 to 945
# Expand sequence ratio to match token dimensions: each token in a sequence gets the same ratio
seq_ratio_expanded = seq_importance_ratio.unsqueeze(-1).expand_as(log_prob)

pg_losses1 = -advantages * seq_ratio_expanded
pg_losses2 = -advantages * torch.clamp(seq_ratio_expanded, 1 - clip_ratio_low, 1 + clip_ratio_high)
Collaborator (Author):

Hi @BounharAbdelaziz, do you think we really need to keep both GSPO and GSPO-token?

For sentence-level GSPO, there should be no need to expand seq_importance_ratio; it would be multiplied directly with a sentence-level advantage. However, our advantage here is token-level.

I think we might only need to implement GSPO-token, because GSPO-token and GSPO are equivalent when the advantage has equal elements across the seq_length dimension. What do you think?
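A sketch of why (numerically $s_{i,t}(\theta) = s_i(\theta)$, since the extra factor $\pi_\theta/\operatorname{sg}[\pi_\theta]$ equals 1 in value): with $\hat{A}_{i,t} = \hat{A}_i$ for all $t$, the token-mean of the GSPO-token terms collapses to the single GSPO sequence-level term,

$$\frac{1}{|y_i|}\sum_{t=1}^{|y_i|}\min\Big(s_{i,t}(\theta)\hat{A}_i,\ \operatorname{clip}\big(s_{i,t}(\theta),1-\varepsilon,1+\varepsilon\big)\hat{A}_i\Big) = \min\Big(s_i(\theta)\hat{A}_i,\ \operatorname{clip}\big(s_i(\theta),1-\varepsilon,1+\varepsilon\big)\hat{A}_i\Big).$$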

Contributor:

I prefer GSPO-token, as it is theoretically equal to GSPO and offers higher flexibility for potential future features.

Contributor:

Hi @0x404, you are right. They come down to the same thing, and you can technically make GSPO-token equal GSPO by setting $A_{i,t}=A_i$. I'm pushing the update.

Thank you @chujiezheng for your input!

Contributor:

@0x404 I see that you already removed it, thanks!

@0x404 (Collaborator, Author) commented Jul 29, 2025

(image: training-curve results for GSPO with the 4e-4 clip range)

> I think this is worth testing if possible

Results based on commit 4b247b2, with GSPO using the 4e-4 clip-range settings:

python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.train_files=$HOME/data/gsm8k/train.parquet \
    data.val_files=$HOME/data/gsm8k/test.parquet \
    data.train_batch_size=512 \
    data.max_prompt_length=512 \
    data.max_response_length=1024 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=128 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=40 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.policy_loss.loss_mode="gspo_token" \
    actor_rollout_ref.actor.clip_ratio_low=0.0004 \
    actor_rollout_ref.actor.clip_ratio_high=0.0004 \
    actor_rollout_ref.actor.clip_ratio=0.0004 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=40 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.rollout.n=10 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=40 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger='["console","wandb"]' \
    trainer.project_name='verl_gspo_cmp' \
    trainer.experiment_name='qwen2.5-3B-GSPO-clip1e-4' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=20 \
    trainer.test_freq=5 \
    trainer.total_epochs=15 $@

@chujiezheng (Contributor)

Thanks for your support of the GSPO algorithm!

We have updated the experimental details in the GSPO paper: https://arxiv.org/abs/2507.18071

(image: updated experimental details from the paper)

I would suggest trying the following hyperparameters:

  • actor_rollout_ref.actor.clip_ratio_low=3e-4
  • actor_rollout_ref.actor.clip_ratio_high=4e-4

You may also try turning off the KL loss by setting actor_rollout_ref.actor.kl_loss_coef=0.

@0x404 (Collaborator, Author) commented Jul 29, 2025

@chujiezheng Thanks, will try it out!

@BounharAbdelaziz (Contributor) commented Jul 29, 2025

I ran an experiment on GSM8K comparing GSPO-token vs GRPO, using the following "core" parameters:

GSPO-token:

clip_ratio_low=0.0003 # as recommended by the paper, see Sec. 5.1
clip_ratio_high=0.0004 # as recommended by the paper, see Sec. 5.1
train_batch_size=512
ppo_mini_batch_size=128 # maintain 4 mini-batches as recommended by the paper, see Sec. 5.1
ppo_micro_batch_size_per_gpu=8 # setup depending on your GPU memory
n_resp_per_prompt=16

use_kl_in_reward=false
kl_coef=0.0
use_kl_loss=false
kl_loss_coef=0.0

max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 8)) # yes this is a lot, and actually the response length mean converges towards 25x

GRPO shares the same setup with the main differences being:

clip_ratio_low=0.2
clip_ratio_high=0.28

Observations

Both algorithms perform very similarly on the test set, though GRPO seems to learn a bit faster.

When examining the pg clipping factor, we can observe that it is huge, sitting at 30% (double the value mentioned in the paper) and stabilizing around 10%, while GRPO almost never clips (0.03% on average). I believe the $\epsilon$ needs to be tuned carefully; perhaps @chujiezheng can share some insights on this?

With the above observations, I am currently running two other experiments for GSPO-token with clip_ratio_high=6e-4 and clip_ratio_high=3.5e-4. The goal is to assess whether allowing more clipping might help, or whether tighter clipping thresholds improve performance, though higher thresholds could increase the clipping factor even further.

Logs are attached for reference.

(screenshots: training logs)

@chujiezheng (Contributor) commented Jul 29, 2025

Thanks for your test.

Based on the feedback from friends in other labs, the advantage of GSPO over GRPO (w/ Routing Replay) is mainly manifested on MoE models, while they perform similarly on dense models.

We did not do much testing of GSPO on dense models internally, but our experience shows that dense models are much easier to RL-train, so the results you obtained may be expected.

@chujiezheng (Contributor)

Also, we did not test in the zero setting, and we previously found that zero RL is much easier than the cold-start setting (perhaps because the response length in the zero setting is usually shorter).

# Calculate log ratio for each token: log(π_θ(y_i,t|x,y_i,<t)/π_θold(y_i,t|x,y_i,<t))
negative_approx_kl = log_prob - old_log_prob
# Clamp negative_approx_kl for numerical stability
negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
Collaborator:

Is this clamp unnecessary? Could you please confirm, @chujiezheng? Thanks!

@chujiezheng (Contributor) commented Jul 30, 2025

Agreed, it is unnecessary here.

# Combined ratio at token level:
# s_i,t(θ) = sg[s_i(θ)] · π_θ(y_i,t|x, y_i,<t) / sg[π_θ(y_i,t|x, y_i,<t)]
# In log space: log(s_i,t(θ)) = sg[log(s_i(θ))] + log_prob - sg[log_prob]
log_seq_importance_ratio = negative_approx_kl_seq.detach().unsqueeze(-1) + log_prob - log_prob.detach()
@chujiezheng (Contributor) commented Jul 30, 2025

You should change the calculation order to avoid precision error:

log_seq_importance_ratio = log_prob - log_prob.detach() + negative_approx_kl_seq.detach().unsqueeze(-1)
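For intuition, a tiny float32 example of the precision issue (values are hypothetical):

```python
import torch

log_prob = torch.tensor([-11.7], dtype=torch.float32)    # hypothetical per-token log-prob
seq_term = torch.tensor([-3.3e-5], dtype=torch.float32)  # hypothetical sequence-level log-ratio

a = (seq_term + log_prob) - log_prob   # original order: rounding near |log_prob| loses low bits of seq_term
b = (log_prob - log_prob) + seq_term   # suggested order: log_prob cancels exactly, so b equals seq_term
print(a.item(), b.item())
```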

Collaborator (Author):

Thanks @chujiezheng! All resolved.

@BounharAbdelaziz (Contributor)

> Also, we did not test in the zero setting, and we previously found that zero RL is much easier than the cold-start setting (perhaps because the response length in the zero setting is usually shorter).

Great! Thanks a lot for sharing your insights!

@vermouth1992 (Collaborator)

I guess this PR should be ready for merge?

@0x404 0x404 marked this pull request as ready for review July 30, 2025 10:38
@BounharAbdelaziz (Contributor)

> (quoting the GSM8K GSPO-token vs GRPO setup and observations from my earlier comment above)

Sharing my experiments with the following two extra tests for GSPO-token:

clip_ratio_low=0.0003
clip_ratio_high=0.00035

clip_ratio_low=0.0003
clip_ratio_high=0.0006

Results: no noticeable difference. This aligns with the insights shared by @chujiezheng, i.e. GSPO is more suitable for MoEs.

(screenshots: training logs)

@vermouth1992 vermouth1992 merged commit b75b1f0 into volcengine:main Jul 30, 2025
44 of 53 checks passed
yellowbee686 pushed a commit to yellowbee686/verl that referenced this pull request Jul 31, 2025
@xihuai18 (Contributor) commented Aug 3, 2025

Can someone explain why pg_clipfrac_lower is not used?


    # For compatibility, return zero for pg_clipfrac_lower (not used in standard GSPO)
    pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)
    pg_clipfrac_lower = torch.tensor(0.0, device=pg_loss.device)

Juniper1021 pushed a commit to Juniper1021/verl that referenced this pull request Aug 7, 2025
@TristanMeng

Should we use dual-clip here?
