@@ -975,18 +975,19 @@ def compute_policy_loss_vanilla(
975
975
return pg_loss , pg_clipfrac , ppo_kl , pg_clipfrac_lower
976
976
977
977
978
- @register_policy_loss ("adc " )
979
- def compute_policy_loss_adc (
978
+ @register_policy_loss ("archer " )
979
+ def compute_policy_loss_archer (
980
980
old_log_prob : torch .Tensor ,
981
981
log_prob : torch .Tensor ,
982
982
advantages : torch .Tensor ,
983
983
response_mask : torch .Tensor ,
984
984
loss_agg_mode : str = "token-mean" ,
985
985
config : Optional [DictConfig | AlgoConfig ] = None ,
986
986
rollout_log_probs : torch .Tensor | None = None ,
987
+ entropy : torch .Tensor | None = None ,
987
988
) -> tuple [torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor ]:
988
989
"""
989
- ADC policy loss (Asymmetric Dual-Clipping):
990
+ ancher policy loss (Asymmetric Dual-Clipping):
990
991
- For advantages > 0: invert importance ratio (use 1/ratio).
991
992
- Extend dual-clip to advantages > 0 as well.
992
993
- Use Soft Clip in dual-clip region to limit weight while preserving gradients.
@@ -995,71 +996,43 @@ def compute_policy_loss_adc(
995
996
996
997
assert config is not None
997
998
assert not isinstance (config , AlgoConfig )
998
- clip_ratio = config .clip_ratio
999
- clip_ratio_low = config .clip_ratio_low if config .clip_ratio_low is not None else clip_ratio
1000
- clip_ratio_high = config .clip_ratio_high if config .clip_ratio_high is not None else clip_ratio
999
+ clip_ratio_low = config .clip_ratio_low
1000
+ clip_ratio_high = config .clip_ratio_hig
1001
1001
clip_ratio_c = config .get ("clip_ratio_c" , 3.0 )
1002
+ token_entropy_quantile = config .get ("token_entropy_quantile" , 0.8 )
1003
+ masked_entropy = torch .where (response_mask .bool (), entropy .detach (), torch .nan ) # (bsz, response_length)
1004
+ q80 = torch .nanquantile (masked_entropy , q = token_entropy_quantile , dim = - 1 , keepdim = True ) # (bsz, 1)
1005
+ high_entropy_mask = (masked_entropy <= q80 ) & response_mask # only low entropy token is True
1002
1006
1003
- cliprange = clip_ratio
1004
- cliprange_low = clip_ratio_low
1005
- cliprange_high = clip_ratio_high
1006
-
1007
- assert clip_ratio_c > 1.0 , (
1008
- "The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0," + f" but get the value: { clip_ratio_c } ."
1009
- )
1010
-
1011
- negative_approx_kl = log_prob - old_log_prob
1012
- # Clamp for stability
1013
- negative_approx_kl = torch .clamp (negative_approx_kl , min = - 20.0 , max = 20.0 )
1014
-
1015
- # Standard ratio
1016
- ratio = torch .exp (negative_approx_kl )
1017
- ppo_kl = verl_F .masked_mean (- negative_approx_kl , response_mask )
1018
-
1019
- # For A>0: invert IS (use 1 / ratio). For A<=0: use standard ratio.
1020
- ratio_adc = torch .where (advantages > 0 , 1.0 / ratio , ratio )
1021
-
1022
- # Standard PPO loss (base branch for gradient direction)
1023
- pg_losses1 = - advantages * ratio_adc
1024
-
1025
- if cliprange_low is None :
1026
- cliprange_low = cliprange
1027
- if cliprange_high is None :
1028
- cliprange_high = cliprange
1029
-
1030
- pos_mask = advantages > 0
1031
- neg_mask = advantages < 0
1007
+ ratio = torch .exp (torch .clamp (log_prob - old_log_prob , min = - 20.0 , max = 20.0 ))
1032
1008
1033
- # Standard PPO clip (Hard Clip) on the ratio driving gradients
1034
- pg_losses2 = - advantages * torch .clamp (ratio_adc , 1 - cliprange_low , 1 + cliprange_high )
1035
- clip_pg_losses_base = torch .maximum (pg_losses1 , pg_losses2 )
1009
+ negative_clip_ratio = torch .where (high_entropy_mask , torch .clamp (ratio , min = 1 - clip_ratio_low , max = None ), torch .clamp (ratio , min = 1 - clip_ratio_high , max = None ))
1010
+ positive_clip_ratio = torch .where (high_entropy_mask , torch .clamp (ratio , min = None , max = 1 + clip_ratio_low ), torch .clamp (ratio , min = None , max = 1 + clip_ratio_high ))
1036
1011
1037
- pg_losses_dual = - advantages * clip_ratio_c
1012
+ clip_ratio = torch . where ( advantages < 0 , negative_clip_ratio , positive_clip_ratio )
1038
1013
1039
- # Apply asymmetric dual-clip selection:
1040
- # - adv > 0: cap magnitude from below via min(base, -A * clip_ratio_c)
1041
- # - adv < 0: cap magnitude from above via max(base, -A * clip_ratio_c)
1042
- pg_losses_pos = torch .maximum (clip_pg_losses_base , pg_losses_dual )
1043
- pg_losses_neg = torch .minimum (clip_pg_losses_base , pg_losses_dual )
1044
- pg_losses = torch .where (pos_mask , pg_losses_pos , torch .where (neg_mask , pg_losses_neg , clip_pg_losses_base ))
1014
+ pg_clipfrac_upper = verl_F .masked_mean (torch .gt (ratio , clip_ratio ).float (), response_mask )
1015
+ pg_clipfrac_lower = verl_F .masked_mean (torch .lt (ratio , clip_ratio ).float (), response_mask )
1045
1016
1046
- # Metrics
1047
- pg_clipfrac = verl_F . masked_mean ( torch . gt ( pg_losses2 , pg_losses1 ). float (), response_mask )
1017
+ negative_pg_losses_clip = - advantages * negative_clip_ratio
1018
+ positive_pg_losses_clip = - advantages * ( positive_clip_ratio / positive_clip_ratio . detach ()) / positive_clip_ratio . detach ( )
1048
1019
1049
- # Dual-clip trigger fraction (both sides), measured by ratio exceeding clip_ratio_c
1050
- # For adv>0 we monitor the inverted ratio (ratio_adc); for adv<0 we monitor the standard ratio.
1051
- lower_clip_pos = pos_mask & (ratio_adc > clip_ratio_c )
1052
- lower_clip_neg = neg_mask & ( ratio > clip_ratio_c )
1053
- pg_clipfrac_lower = verl_F . masked_mean (( lower_clip_pos | lower_clip_neg ). float (), response_mask )
1020
+ negative_dual_clip_ratio = torch . clamp ( negative_clip_ratio , min = None , max = clip_ratio_c )
1021
+ negative_clipped_mask = torch . gt ( negative_clip_ratio , negative_dual_clip_ratio )
1022
+ negative_pg_clipfrac_dual = verl_F . masked_mean ( negative_clipped_mask . float (), response_mask & (advantages < 0 ) )
1023
+ negative_pg_losses_dual = - advantages * negative_dual_clip_ratio . detach () * log_prob
1024
+ negative_pg_losses = torch . where ( negative_clipped_mask , negative_pg_losses_dual , negative_pg_losses_clip )
1054
1025
1055
- if config .tis_imp_ratio_cap > 0 and rollout_log_probs is not None :
1056
- tis_imp_ratio = torch .exp (old_log_prob - rollout_log_probs )
1057
- tis_imp_ratio = torch .clamp (tis_imp_ratio , max = config .tis_imp_ratio_cap )
1058
- pg_losses = pg_losses * tis_imp_ratio
1026
+ positive_dual_clip_ratio = torch .clamp (1 / positive_clip_ratio , min = None , max = clip_ratio_c )
1027
+ positive_clipped_mask = torch .gt (1 / positive_clip_ratio , positive_dual_clip_ratio )
1028
+ positive_pg_clipfrac_dual = verl_F .masked_mean (positive_clipped_mask .float (), response_mask & (advantages > 0 ))
1029
+ positive_pg_losses_dual = - advantages * positive_dual_clip_ratio .detach () * log_prob
1030
+ positive_pg_losses = torch .where (positive_clipped_mask , positive_pg_losses_dual , positive_pg_losses_clip )
1059
1031
1032
+ pg_losses = torch .where (advantages < 0 , negative_pg_losses , positive_pg_losses )
1060
1033
pg_loss = agg_loss (loss_mat = pg_losses , loss_mask = response_mask , loss_agg_mode = loss_agg_mode )
1061
1034
1062
- return pg_loss , pg_clipfrac , ppo_kl , pg_clipfrac_lower
1035
+ return pg_loss , pg_clipfrac_upper , pg_clipfrac_lower , negative_pg_clipfrac_dual , positive_pg_clipfrac_dual
1063
1036
1064
1037
@register_policy_loss ("gspo" )
1065
1038
def compute_policy_loss_gspo (
0 commit comments