@@ -893,17 +893,17 @@ def compute_policy_loss_gspo(
     log_prob: torch.Tensor,
     advantages: torch.Tensor,
     response_mask: torch.Tensor,
-    loss_agg_mode: str = "token-mean",
+    loss_agg_mode: str = "seq-mean-token-mean",
     config: Optional[DictConfig | ActorConfig] = None,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Compute the clipped policy objective for vanilla GSPO (sequence-level optimization).
     See https://arxiv.org/pdf/2507.18071 for more details.
-
+
     Implements equations 6-8 from the paper:
     J_GSPO(θ) = E[x~D, {y_i}_{i=1}^G ~ π_θold(·|x)] [1/G Σ_{i=1}^G min(s_i(θ)Â_i, clip(s_i(θ), 1-ε, 1+ε)Â_i)]
     where s_i(θ) = (π_θ(y_i|x)/π_θold(y_i|x))^(1/|y_i|)
-
+
     Args:
         old_log_prob (torch.Tensor):
             Log-probabilities of actions under the old policy, shape (batch_size, response_length).
@@ -914,8 +914,8 @@ def compute_policy_loss_gspo(
         response_mask (torch.Tensor):
             Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
         loss_agg_mode (str, optional):
-            Aggregation mode for `agg_loss`. Defaults to "token-mean".
-        config (Optional[DictConfig | ActorConfig]):
+            Aggregation mode for `agg_loss`. For GSPO, it is recommended to use "seq-mean-token-mean".
+        config (Optional[DictConfig | ActorConfig]):
             Configuration parameters
     """

@@ -929,7 +929,9 @@ def compute_policy_loss_gspo(
     # Clamp negative_approx_kl for numerical stability
     negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)

-    # compute sequence-level importance ratio: si(θ) = (π_θ(yi|x)/π_θold(yi|x))^(1/|yi|) = exp [(1/|y_i|) * Σ_t log(π_θ(y_i,t|x,y_i,<t)/π_θold(y_i,t|x,y_i,<t))]
+    # compute sequence-level importance ratio:
+    # si(θ) = (π_θ(yi|x)/π_θold(yi|x))^(1/|yi|)
+    #       = exp [(1/|y_i|) * Σ_t log(π_θ(y_i,t|x,y_i,<t)/π_θold(y_i,t|x,y_i,<t))]
     seq_lengths = torch.sum(response_mask, dim=-1).clamp(min=1)
     negative_approx_kl_seq = torch.sum(negative_approx_kl * response_mask, dim=-1) / seq_lengths

@@ -940,27 +942,26 @@ def compute_policy_loss_gspo(
     seq_ratio_expanded = seq_importance_ratio.unsqueeze(-1).expand_as(log_prob)

     pg_losses1 = -advantages * seq_ratio_expanded
-    pg_losses2 = -advantages * torch.clamp(
-        seq_ratio_expanded, 1 - clip_ratio_low, 1 + clip_ratio_high
-    )
+    pg_losses2 = -advantages * torch.clamp(seq_ratio_expanded, 1 - clip_ratio_low, 1 + clip_ratio_high)
     pg_losses = torch.maximum(pg_losses1, pg_losses2)
-    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
+    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode="seq-mean-token-mean")

     # For compatibility, return zero for pg_clipfrac_lower (not used in standard GSPO)
     pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)
     pg_clipfrac_lower = torch.tensor(0.0, device=pg_loss.device)
-
+
     ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)

     return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower

+
 @register_policy_loss("gspo_token")
 def compute_policy_loss_gspo_token(
     old_log_prob: torch.Tensor,
     log_prob: torch.Tensor,
     advantages: torch.Tensor,
     response_mask: torch.Tensor,
-    loss_agg_mode: str = "token-mean",
+    loss_agg_mode: str = "seq-mean-token-mean",
     config: Optional[DictConfig | ActorConfig] = None,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     """
@@ -978,7 +979,7 @@ def compute_policy_loss_gspo_token(
         response_mask (torch.Tensor):
             Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
         loss_agg_mode (str, optional):
-            Aggregation mode for `agg_loss`. Defaults to "token-mean".
+            Aggregation mode for `agg_loss`. For GSPO, it is recommended to use "seq-mean-token-mean".
     """

     assert config is not None
@@ -991,28 +992,29 @@ def compute_policy_loss_gspo_token(
     # Clamp negative_approx_kl for numerical stability
     negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)

-    # compute sequence-level importance ratio: si(θ) = (π_θ(yi|x)/π_θold(yi|x))^(1/|yi|) = exp [(1/|y_i|) * Σ_t log(π_θ(y_i,t|x,y_i,<t)/π_θold(y_i,t|x,y_i,<t))]
+    # compute sequence-level importance ratio:
+    # si(θ) = (π_θ(yi|x)/π_θold(yi|x))^(1/|yi|) =
+    #     exp [(1/|y_i|) * Σ_t log(π_θ(y_i,t|x,y_i,<t)/π_θold(y_i,t|x,y_i,<t))]
     seq_lengths = torch.sum(response_mask, dim=-1).clamp(min=1)
     negative_approx_kl_seq = torch.sum(negative_approx_kl * response_mask, dim=-1) / seq_lengths

-    # Combined ratio at token level: s_i,t(θ) = sg[s_i(θ)] · π_θ(y_i,t|x, y_i,<t) / sg[π_θ(y_i,t|x, y_i,<t)]
+    # Combined ratio at token level:
+    # s_i,t(θ) = sg[s_i(θ)] · π_θ(y_i,t|x, y_i,<t) / sg[π_θ(y_i,t|x, y_i,<t)]
     # In log space: log(s_i,t(θ)) = sg[log(s_i(θ))] + log_prob - sg[log_prob]
     log_seq_importance_ratio = negative_approx_kl_seq.detach().unsqueeze(-1) + log_prob - log_prob.detach()

     # finally exp() to remove the log
     seq_importance_ratio = torch.exp(log_seq_importance_ratio)

     pg_losses1 = -advantages * seq_importance_ratio
-    pg_losses2 = -advantages * torch.clamp(
-        seq_importance_ratio, 1 - clip_ratio_low, 1 + clip_ratio_high
-    )
+    pg_losses2 = -advantages * torch.clamp(seq_importance_ratio, 1 - clip_ratio_low, 1 + clip_ratio_high)
     pg_losses = torch.maximum(pg_losses1, pg_losses2)
-    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
+    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode="seq-mean-token-mean")

     # For compatibility, return zero for pg_clipfrac_lower (not used in standard GSPO)
     pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)
     pg_clipfrac_lower = torch.tensor(0.0, device=pg_loss.device)
-
+
     ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)

     return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
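
For readers who want to sanity-check the math outside the diff, here is a small, self-contained sketch (not part of the PR) of the sequence-level importance ratio s_i(θ) and the stop-gradient token-level variant used by the gspo_token path. Tensor names mirror the diff; the helper function name, toy inputs, and random seed are invented purely for illustration.

import torch

def seq_importance_ratio_sketch(old_log_prob, log_prob, response_mask):
    # s_i(θ) = exp[(1/|y_i|) * Σ_t (log π_θ - log π_θold)], masked to response tokens
    negative_approx_kl = torch.clamp(log_prob - old_log_prob, min=-20.0, max=20.0)
    seq_lengths = response_mask.sum(dim=-1).clamp(min=1)
    log_s = (negative_approx_kl * response_mask).sum(dim=-1) / seq_lengths
    # gspo_token variant: broadcast sg[log s_i] back to tokens while keeping
    # per-token gradient flow through log_prob
    log_s_token = log_s.detach().unsqueeze(-1) + log_prob - log_prob.detach()
    return torch.exp(log_s), torch.exp(log_s_token)

# toy example: batch of 2 responses of length 4, second response has one padded token
torch.manual_seed(0)
old_lp = -torch.rand(2, 4)                      # stand-in log-probs under π_θold
new_lp = old_lp + 0.05 * torch.randn(2, 4)      # slightly perturbed log-probs under π_θ
mask = torch.tensor([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0]])

s_seq, s_tok = seq_importance_ratio_sketch(old_lp, new_lp, mask)
print(s_seq)  # one ratio per sequence, close to 1.0 for a small policy update
print(s_tok)  # same values broadcast per token (only the gradient bookkeeping differs)

Numerically, the token-level variant equals the sequence ratio broadcast across tokens; the stop-gradient terms only change which factors carry gradients, which is what lets gspo_token keep per-token gradient flow while clipping still acts on the sequence-level ratio.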