Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tests/special_e2e/ppo_trainer/run_function_reward.sh
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ RM_PAD=${RM_PAD:-True}
FUSED_KERNELS=${FUSED_KERNELS:-False}
FUSED_KERNEL_BACKEND=${FUSED_KERNEL_BACKEND:-torch} # or 'triton' for triton backend
ADV_ESTIMATOR=${ADV_ESTIMATOR:-gae}
LOSS_MODE=${LOSS_MODE:-vanilla}
USE_KL=${USE_KL:-False}
CUSTOM_REWARD_FN=${CUSTOM_REWARD_FN:-False}
ENABLE_CHUNKED_PREFILL=${ENABLE_CHUNKED_PREFILL:-True} # For vLLM VLM placeholder issue: https://github.com/vllm-project/vllm/issues/15185
Expand Down Expand Up @@ -112,6 +113,7 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.actor.ulysses_sequence_parallel_size="${SP_SIZE}" \
actor_rollout_ref.actor.checkpoint.save_contents=${CHECKPOINT_CONTENTS} \
actor_rollout_ref.actor.use_kl_loss="${USE_KL}" \
actor_rollout_ref.actor.policy_loss.loss_mode="${LOSS_MODE}" \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name="${ENGINE}" \
Expand Down
59 changes: 59 additions & 0 deletions verl/trainer/ppo/core_algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import verl.utils.torch_functional as verl_F
from verl.trainer.config import AlgoConfig
from verl.utils.import_utils import deprecated
from verl.workers.config import ActorConfig

PolicyLossFn = Callable[
[
Expand Down Expand Up @@ -886,6 +887,64 @@ def compute_policy_loss_vanilla(
return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower


@register_policy_loss("gspo")
def compute_policy_loss_gspo(
old_log_prob: torch.Tensor,
log_prob: torch.Tensor,
advantages: torch.Tensor,
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[DictConfig | ActorConfig] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute the clipped policy objective and related metrics for GSPO.

See https://arxiv.org/pdf/2507.18071 for more details.

Args:
old_log_prob (torch.Tensor):
Log-probabilities of actions under the old policy, shape (batch_size, response_length).
log_prob (torch.Tensor):
Log-probabilities of actions under the current policy, shape (batch_size, response_length).
advantages (torch.Tensor):
Advantage estimates for each action, shape (batch_size, response_length).
response_mask (torch.Tensor):
Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
loss_agg_mode (str, optional):
Aggregation mode for `agg_loss`. Defaults to "token-mean".
"""

assert config is not None
assert isinstance(config, ActorConfig)
clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else config.clip_ratio
clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else config.clip_ratio

negative_approx_kl = log_prob - old_log_prob

# compute sequence-level importance ratio
seq_lengths = torch.sum(response_mask, dim=-1)
negative_approx_kl_seq = torch.sum(negative_approx_kl * response_mask, dim=-1) / seq_lengths.clamp(min=1)
log_seq_importance_ratio = negative_approx_kl_seq.detach().unsqueeze(-1) + log_prob - log_prob.detach()
Copy link
Contributor

@chujiezheng chujiezheng Jul 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should change the calculation order to avoid precision error:

log_seq_importance_ratio = log_prob - log_prob.detach() + negative_approx_kl_seq.detach().unsqueeze(-1)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @chujiezheng ! all resolved.


# Clamp log_seq_importance_ratio for stability
log_seq_importance_ratio = torch.clamp(log_seq_importance_ratio, min=-20.0, max=20.0)
seq_importance_ratio = torch.exp(log_seq_importance_ratio)

pg_losses1 = -advantages * seq_importance_ratio
pg_losses2 = -advantages * torch.clamp(seq_importance_ratio, 1 - clip_ratio_low, 1 + clip_ratio_high)
pg_losses = torch.maximum(pg_losses1, pg_losses2)
pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here agg_loss should always be sentence-level?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, "seq-mean-token-sum" looks most suitable. But from the paper's definition, maybe a "batch-mean-group-mean-token-sum" would be more accurate? Which sums the token loss of seq, then takes a mean of group seq loss at the group-level, then takes a mean/sum again at the batch level. there is a similar discussion in #2776

Copy link

@mokeevdmitrii mokeevdmitrii Jul 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is GSPO-token, so it should be exactly seq-mean-token-mean.

EDIT: hope i haven't hallucinated in my nightly calcs

The example in #2776 is about token-mean vs seq-mean).

Referencing the article:

image
  1. First, we calc $s_i(\theta)$ and find token-level loss (inside the last sum, let's name it $pg_{i,t}$.

  2. Second, we sum this loss over every sequence and divide it by sequence length $|y_i|$. This equals to:

seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / torch.sum(loss_mask, dim=-1)

in current verl implementation for seq-mean-token-mean. Let's name this $pg_{i}$.

  1. Finally, we notice that there is a common denumerator, as the loss for a single group is $\dfrac{1}{G} \sum\limits_{i=1}^G pg_i$. So, "batch-mean-group-mean-token-mean" is not needed, as each "group-mean" operation has the same denumerator - this is equiv. to ""batch-group"-mean-token-mean" or simply "seq-mean-token-mean".

Notes:

  1. seq-mean-token-sum would be wrong, as we are summing up $s_i(\theta)$ for each token in GSPO-token, so we must divide by seq-len here.

  2. the fact that token-mean is a valid loss_agg_mode value here makes me a bit sad bcs the article is about sequence-level optimization)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mokeevdmitrii You are right! we should use seq-mean-token-mean when using GSPO, thanks very much!


# For compatibility, return zero for pg_clipfrac_lower (not used in standard GSPO)
pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)
pg_clipfrac_lower = torch.tensor(0.0, device=pg_loss.device)

negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)

return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower


@register_policy_loss("gpg")
def compute_policy_loss_gpg(old_log_prob, log_prob, advantages, response_mask, loss_agg_mode="token-mean", config=None):
"""Adapted from
Expand Down
Loading