Commit 4b247b2

refactor: update loss aggregation mode in GSPO functions to seq-mean-token-mean
1 parent: 7e95f77 · commit: 4b247b2

File tree

1 file changed: +22 −20 lines


verl/trainer/ppo/core_algos.py

Lines changed: 22 additions & 20 deletions
@@ -893,17 +893,17 @@ def compute_policy_loss_gspo(
     log_prob: torch.Tensor,
     advantages: torch.Tensor,
     response_mask: torch.Tensor,
-    loss_agg_mode: str = "token-mean",
+    loss_agg_mode: str = "seq-mean-token-mean",
     config: Optional[DictConfig | ActorConfig] = None,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Compute the clipped policy objective for vanilla GSPO (sequence-level optimization).
     See https://arxiv.org/pdf/2507.18071 for more details.
-
+
     Implements equations 6-8 from the paper:
     J_GSPO(θ) = E[x~D, {y_i}_{i=1}^G ~ π_θold(·|x)] [1/G Σ_{i=1}^G min(s_i(θ)Â_i, clip(s_i(θ), 1-ε, 1+ε)Â_i)]
     where s_i(θ) = (π_θ(y_i|x)/π_θold(y_i|x))^(1/|y_i|)
-
+
     Args:
         old_log_prob (torch.Tensor):
             Log-probabilities of actions under the old policy, shape (batch_size, response_length).
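A quick illustration of the s_i(θ) formula in this docstring: it is the geometric mean of the per-token probability ratios, computed in log space for numerical stability. The following is a minimal standalone sketch (illustration only, not code from this commit; the helper name gspo_seq_ratio and the toy numbers are made up):

import torch

def gspo_seq_ratio(old_log_prob, log_prob, response_mask):
    # token-level log-ratio log(pi_theta / pi_theta_old), clamped as in the diff below
    neg_approx_kl = torch.clamp(log_prob - old_log_prob, min=-20.0, max=20.0)
    # (1/|y_i|) * sum of the log-ratios over valid tokens
    seq_lengths = response_mask.sum(dim=-1).clamp(min=1)
    seq_log_ratio = (neg_approx_kl * response_mask).sum(dim=-1) / seq_lengths
    # s_i(theta) = exp of that average, i.e. the geometric mean of the token ratios
    return torch.exp(seq_log_ratio)

# two responses of length 3; the last token of the second response is padding
old_lp = torch.log(torch.tensor([[0.5, 0.5, 0.5], [0.25, 0.25, 1.0]]))
new_lp = torch.log(torch.tensor([[0.6, 0.6, 0.6], [0.50, 0.50, 1.0]]))
mask = torch.tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 0.0]])
print(gspo_seq_ratio(old_lp, new_lp, mask))  # approximately tensor([1.2, 2.0])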
@@ -914,8 +914,8 @@ def compute_policy_loss_gspo(
         response_mask (torch.Tensor):
             Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
         loss_agg_mode (str, optional):
-            Aggregation mode for `agg_loss`. Defaults to "token-mean".
-        config (Optional[DictConfig | ActorConfig]):
+            Aggregation mode for `agg_loss`. For GSPO, it is recommended to use "seq-mean-token-mean".
+        config (Optional[DictConfig | ActorConfig]):
             Configuration parameters
     """
@@ -929,7 +929,9 @@ def compute_policy_loss_gspo(
     # Clamp negative_approx_kl for numerical stability
    negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)

-    # compute sequence-level importance ratio: si(θ) = (π_θ(yi|x)/π_θold(yi|x))^(1/|yi|) = exp [(1/|y_i|) * Σ_t log(π_θ(y_i,t|x,y_i,<t)/π_θold(y_i,t|x,y_i,<t))]
+    # compute sequence-level importance ratio:
+    # si(θ) = (π_θ(yi|x)/π_θold(yi|x))^(1/|yi|)
+    #       = exp [(1/|y_i|) * Σ_t log(π_θ(y_i,t|x,y_i,<t)/π_θold(y_i,t|x,y_i,<t))]
     seq_lengths = torch.sum(response_mask, dim=-1).clamp(min=1)
     negative_approx_kl_seq = torch.sum(negative_approx_kl * response_mask, dim=-1) / seq_lengths
@@ -940,27 +942,26 @@
     seq_ratio_expanded = seq_importance_ratio.unsqueeze(-1).expand_as(log_prob)

     pg_losses1 = -advantages * seq_ratio_expanded
-    pg_losses2 = -advantages * torch.clamp(
-        seq_ratio_expanded, 1 - clip_ratio_low, 1 + clip_ratio_high
-    )
+    pg_losses2 = -advantages * torch.clamp(seq_ratio_expanded, 1 - clip_ratio_low, 1 + clip_ratio_high)
     pg_losses = torch.maximum(pg_losses1, pg_losses2)
-    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
+    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode="seq-mean-token-mean")

     # For compatibility, return zero for pg_clipfrac_lower (not used in standard GSPO)
     pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)
     pg_clipfrac_lower = torch.tensor(0.0, device=pg_loss.device)
-
+
     ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)

     return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower

+
 @register_policy_loss("gspo_token")
 def compute_policy_loss_gspo_token(
     old_log_prob: torch.Tensor,
     log_prob: torch.Tensor,
     advantages: torch.Tensor,
     response_mask: torch.Tensor,
-    loss_agg_mode: str = "token-mean",
+    loss_agg_mode: str = "seq-mean-token-mean",
     config: Optional[DictConfig | ActorConfig] = None,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     """
@@ -978,7 +979,7 @@ def compute_policy_loss_gspo_token(
         response_mask (torch.Tensor):
             Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
         loss_agg_mode (str, optional):
-            Aggregation mode for `agg_loss`. Defaults to "token-mean".
+            Aggregation mode for `agg_loss`. For GSPO, it is recommended to use "seq-mean-token-mean".
     """

     assert config is not None
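The substantive change in this commit is the aggregation mode. With "seq-mean-token-mean", the per-token loss is averaged within each sequence and the per-sequence means are then averaged, so every sequence carries equal weight regardless of its length, matching the 1/G sum over sequences in the GSPO objective; plain "token-mean" pools all unmasked tokens, so longer responses weigh more. A rough sketch of the distinction, assuming this is what verl's agg_loss computes for these mode strings (the helper names below are made up for illustration):

import torch

def token_mean(loss_mat, mask):
    # one global mean over all unmasked tokens: longer responses contribute more weight
    return (loss_mat * mask).sum() / mask.sum().clamp(min=1)

def seq_mean_token_mean(loss_mat, mask):
    # mean over the tokens of each sequence, then mean over sequences:
    # every sequence gets equal weight regardless of its length
    seq_lengths = mask.sum(dim=-1).clamp(min=1)
    per_seq = (loss_mat * mask).sum(dim=-1) / seq_lengths
    return per_seq.mean()

loss = torch.tensor([[1.0, 1.0, 1.0, 1.0], [3.0, 0.0, 0.0, 0.0]])
mask = torch.tensor([[1.0, 1.0, 1.0, 1.0], [1.0, 0.0, 0.0, 0.0]])
print(token_mean(loss, mask))           # 1.4 (the long response dominates)
print(seq_mean_token_mean(loss, mask))  # 2.0 (both responses weighted equally)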
@@ -991,28 +992,29 @@ def compute_policy_loss_gspo_token(
     # Clamp negative_approx_kl for numerical stability
     negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)

-    # compute sequence-level importance ratio: si(θ) = (π_θ(yi|x)/π_θold(yi|x))^(1/|yi|) = exp [(1/|y_i|) * Σ_t log(π_θ(y_i,t|x,y_i,<t)/π_θold(y_i,t|x,y_i,<t))]
+    # compute sequence-level importance ratio:
+    # si(θ) = (π_θ(yi|x)/π_θold(yi|x))^(1/|yi|) =
+    #     exp [(1/|y_i|) * Σ_t log(π_θ(y_i,t|x,y_i,<t)/π_θold(y_i,t|x,y_i,<t))]
     seq_lengths = torch.sum(response_mask, dim=-1).clamp(min=1)
     negative_approx_kl_seq = torch.sum(negative_approx_kl * response_mask, dim=-1) / seq_lengths

-    # Combined ratio at token level: s_i,t(θ) = sg[s_i(θ)] · π_θ(y_i,t|x, y_i,<t) / sg[π_θ(y_i,t|x, y_i,<t)]
+    # Combined ratio at token level:
+    # s_i,t(θ) = sg[s_i(θ)] · π_θ(y_i,t|x, y_i,<t) / sg[π_θ(y_i,t|x, y_i,<t)]
     # In log space: log(s_i,t(θ)) = sg[log(s_i(θ))] + log_prob - sg[log_prob]
     log_seq_importance_ratio = negative_approx_kl_seq.detach().unsqueeze(-1) + log_prob - log_prob.detach()

     # finaly exp() to remove log
     seq_importance_ratio = torch.exp(log_seq_importance_ratio)

     pg_losses1 = -advantages * seq_importance_ratio
-    pg_losses2 = -advantages * torch.clamp(
-        seq_importance_ratio, 1 - clip_ratio_low, 1 + clip_ratio_high
-    )
+    pg_losses2 = -advantages * torch.clamp(seq_importance_ratio, 1 - clip_ratio_low, 1 + clip_ratio_high)
     pg_losses = torch.maximum(pg_losses1, pg_losses2)
-    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
+    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode="seq-mean-token-mean")

     # For compatibility, return zero for pg_clipfrac_lower (not used in standard GSPO)
     pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)
     pg_clipfrac_lower = torch.tensor(0.0, device=pg_loss.device)
-
+
     ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)

     return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
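For the token-level variant, the sequence-level ratio is reused at every token through a stop-gradient, as the comment in the hunk above spells out: the forward value of s_i,t(θ) equals s_i(θ), but the gradient flows only through the current token's log-probability. A minimal standalone sketch of that construction (illustration only; gspo_token_ratio is a made-up name, not verl API):

import torch

def gspo_token_ratio(old_log_prob, log_prob, response_mask):
    # sequence-level log-ratio as in the vanilla variant, detached so only its value is reused
    neg_approx_kl = torch.clamp(log_prob - old_log_prob, min=-20.0, max=20.0)
    seq_lengths = response_mask.sum(dim=-1).clamp(min=1)
    seq_log_ratio = (neg_approx_kl * response_mask).sum(dim=-1) / seq_lengths
    # log s_i,t(theta) = sg[log s_i(theta)] + log_prob - sg[log_prob]
    log_token_ratio = seq_log_ratio.detach().unsqueeze(-1) + log_prob - log_prob.detach()
    return torch.exp(log_token_ratio)

In the forward pass every token of sequence i sees the same value s_i(θ); in the backward pass each token contributes a gradient only through its own log_prob, which is what distinguishes gspo_token from the vanilla sequence-level loss above.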
