Commit cbbd82d: support archer
Author: huangjunyi.0
1 parent: 09dac73

File tree: 2 files changed (+76, -79 lines)

verl/trainer/ppo/core_algos.py

Lines changed: 30 additions & 57 deletions

@@ -975,18 +975,19 @@ def compute_policy_loss_vanilla(
     return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower


-@register_policy_loss("adc")
-def compute_policy_loss_adc(
+@register_policy_loss("archer")
+def compute_policy_loss_archer(
     old_log_prob: torch.Tensor,
     log_prob: torch.Tensor,
     advantages: torch.Tensor,
     response_mask: torch.Tensor,
     loss_agg_mode: str = "token-mean",
     config: Optional[DictConfig | AlgoConfig] = None,
     rollout_log_probs: torch.Tensor | None = None,
+    entropy: torch.Tensor | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     """
-    ADC policy loss (Asymmetric Dual-Clipping):
+    Archer policy loss (Asymmetric Dual-Clipping):
     - For advantages > 0: invert importance ratio (use 1/ratio).
     - Extend dual-clip to advantages > 0 as well.
     - Use Soft Clip in dual-clip region to limit weight while preserving gradients.
@@ -995,71 +996,43 @@ def compute_policy_loss_adc(

     assert config is not None
     assert not isinstance(config, AlgoConfig)
-    clip_ratio = config.clip_ratio
-    clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio
-    clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio
+    clip_ratio_low = config.clip_ratio_low
+    clip_ratio_high = config.clip_ratio_high
     clip_ratio_c = config.get("clip_ratio_c", 3.0)
+    token_entropy_quantile = config.get("token_entropy_quantile", 0.8)
+    masked_entropy = torch.where(response_mask.bool(), entropy.detach(), torch.nan)  # (bsz, response_length)
+    q80 = torch.nanquantile(masked_entropy, q=token_entropy_quantile, dim=-1, keepdim=True)  # (bsz, 1)
+    high_entropy_mask = (masked_entropy <= q80) & response_mask  # only low-entropy tokens (at or below the quantile) are True

-    cliprange = clip_ratio
-    cliprange_low = clip_ratio_low
-    cliprange_high = clip_ratio_high
-
-    assert clip_ratio_c > 1.0, (
-        "The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0," + f" but get the value: {clip_ratio_c}."
-    )
-
-    negative_approx_kl = log_prob - old_log_prob
-    # Clamp for stability
-    negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
-
-    # Standard ratio
-    ratio = torch.exp(negative_approx_kl)
-    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)
-
-    # For A>0: invert IS (use 1 / ratio). For A<=0: use standard ratio.
-    ratio_adc = torch.where(advantages > 0, 1.0 / ratio, ratio)
-
-    # Standard PPO loss (base branch for gradient direction)
-    pg_losses1 = -advantages * ratio_adc
-
-    if cliprange_low is None:
-        cliprange_low = cliprange
-    if cliprange_high is None:
-        cliprange_high = cliprange
-
-    pos_mask = advantages > 0
-    neg_mask = advantages < 0
+    ratio = torch.exp(torch.clamp(log_prob - old_log_prob, min=-20.0, max=20.0))

-    # Standard PPO clip (Hard Clip) on the ratio driving gradients
-    pg_losses2 = -advantages * torch.clamp(ratio_adc, 1 - cliprange_low, 1 + cliprange_high)
-    clip_pg_losses_base = torch.maximum(pg_losses1, pg_losses2)
+    negative_clip_ratio = torch.where(high_entropy_mask, torch.clamp(ratio, min=1 - clip_ratio_low, max=None), torch.clamp(ratio, min=1 - clip_ratio_high, max=None))
+    positive_clip_ratio = torch.where(high_entropy_mask, torch.clamp(ratio, min=None, max=1 + clip_ratio_low), torch.clamp(ratio, min=None, max=1 + clip_ratio_high))

-    pg_losses_dual = -advantages * clip_ratio_c
+    clip_ratio = torch.where(advantages < 0, negative_clip_ratio, positive_clip_ratio)

-    # Apply asymmetric dual-clip selection:
-    # - adv > 0: cap magnitude from below via min(base, -A * clip_ratio_c)
-    # - adv < 0: cap magnitude from above via max(base, -A * clip_ratio_c)
-    pg_losses_pos = torch.maximum(clip_pg_losses_base, pg_losses_dual)
-    pg_losses_neg = torch.minimum(clip_pg_losses_base, pg_losses_dual)
-    pg_losses = torch.where(pos_mask, pg_losses_pos, torch.where(neg_mask, pg_losses_neg, clip_pg_losses_base))
+    pg_clipfrac_upper = verl_F.masked_mean(torch.gt(ratio, clip_ratio).float(), response_mask)
+    pg_clipfrac_lower = verl_F.masked_mean(torch.lt(ratio, clip_ratio).float(), response_mask)

-    # Metrics
-    pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)
+    negative_pg_losses_clip = -advantages * negative_clip_ratio
+    positive_pg_losses_clip = -advantages * (positive_clip_ratio / positive_clip_ratio.detach()) / positive_clip_ratio.detach()

-    # Dual-clip trigger fraction (both sides), measured by ratio exceeding clip_ratio_c
-    # For adv>0 we monitor the inverted ratio (ratio_adc); for adv<0 we monitor the standard ratio.
-    lower_clip_pos = pos_mask & (ratio_adc > clip_ratio_c)
-    lower_clip_neg = neg_mask & (ratio > clip_ratio_c)
-    pg_clipfrac_lower = verl_F.masked_mean((lower_clip_pos | lower_clip_neg).float(), response_mask)
+    negative_dual_clip_ratio = torch.clamp(negative_clip_ratio, min=None, max=clip_ratio_c)
+    negative_clipped_mask = torch.gt(negative_clip_ratio, negative_dual_clip_ratio)
+    negative_pg_clipfrac_dual = verl_F.masked_mean(negative_clipped_mask.float(), response_mask & (advantages < 0))
+    negative_pg_losses_dual = -advantages * negative_dual_clip_ratio.detach() * log_prob
+    negative_pg_losses = torch.where(negative_clipped_mask, negative_pg_losses_dual, negative_pg_losses_clip)

-    if config.tis_imp_ratio_cap > 0 and rollout_log_probs is not None:
-        tis_imp_ratio = torch.exp(old_log_prob - rollout_log_probs)
-        tis_imp_ratio = torch.clamp(tis_imp_ratio, max=config.tis_imp_ratio_cap)
-        pg_losses = pg_losses * tis_imp_ratio
+    positive_dual_clip_ratio = torch.clamp(1 / positive_clip_ratio, min=None, max=clip_ratio_c)
+    positive_clipped_mask = torch.gt(1 / positive_clip_ratio, positive_dual_clip_ratio)
+    positive_pg_clipfrac_dual = verl_F.masked_mean(positive_clipped_mask.float(), response_mask & (advantages > 0))
+    positive_pg_losses_dual = -advantages * positive_dual_clip_ratio.detach() * log_prob
+    positive_pg_losses = torch.where(positive_clipped_mask, positive_pg_losses_dual, positive_pg_losses_clip)

+    pg_losses = torch.where(advantages < 0, negative_pg_losses, positive_pg_losses)
     pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

-    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
+    return pg_loss, pg_clipfrac_upper, pg_clipfrac_lower, negative_pg_clipfrac_dual, positive_pg_clipfrac_dual


 @register_policy_loss("gspo")
 def compute_policy_loss_gspo(
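Note (not part of the commit): the following is a minimal, self-contained sketch of the entropy-quantile clipping that compute_policy_loss_archer applies, run on dummy tensors so the shapes and the low/high-entropy split are easy to see. The function and variable names here (archer_clip_demo, low_entropy_mask) and the numeric values are illustrative, not verl API.

import torch


def archer_clip_demo(old_log_prob, log_prob, advantages, response_mask, entropy,
                     clip_ratio_low=0.2, clip_ratio_high=0.28,
                     token_entropy_quantile=0.8):
    # Importance ratio, clamped in log space for numerical stability (as in the diff above).
    ratio = torch.exp(torch.clamp(log_prob - old_log_prob, min=-20.0, max=20.0))

    # Per-sequence entropy quantile: tokens at or below the quantile
    # (the low-entropy tokens) fall into one bucket, the rest into the other.
    nan_fill = torch.full_like(entropy, float("nan"))
    masked_entropy = torch.where(response_mask.bool(), entropy, nan_fill)
    q = torch.nanquantile(masked_entropy, q=token_entropy_quantile, dim=-1, keepdim=True)
    low_entropy_mask = (masked_entropy <= q) & response_mask.bool()

    # One-sided clipping, asymmetric in the clip range: low-entropy tokens get the
    # tighter clip_ratio_low bound, high-entropy tokens the wider clip_ratio_high
    # bound; only the pessimistic side of each token's ratio is clipped.
    neg_clip = torch.where(low_entropy_mask,
                           torch.clamp(ratio, min=1 - clip_ratio_low),
                           torch.clamp(ratio, min=1 - clip_ratio_high))
    pos_clip = torch.where(low_entropy_mask,
                           torch.clamp(ratio, max=1 + clip_ratio_low),
                           torch.clamp(ratio, max=1 + clip_ratio_high))
    clipped = torch.where(advantages < 0, neg_clip, pos_clip)
    return -advantages * clipped, clipped


if __name__ == "__main__":
    torch.manual_seed(0)
    bsz, resp_len = 2, 8
    old_lp = -torch.rand(bsz, resp_len)            # fake per-token log-probs
    lp = old_lp + 0.1 * torch.randn(bsz, resp_len)
    adv = torch.randn(bsz, resp_len)
    mask = torch.ones(bsz, resp_len)
    ent = torch.rand(bsz, resp_len)
    losses, clipped = archer_clip_demo(old_lp, lp, adv, mask, ent)
    print(losses.shape, clipped.shape)             # torch.Size([2, 8]) twice

The dual-clip and soft-clip branches in the commit (the clip_ratio_c cap, the rescaling by the detached clipped ratio, and the fallback to a -advantages * capped_ratio.detach() * log_prob term) are omitted here to keep the sketch focused on the entropy-dependent clipping.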

verl/workers/actor/dp_actor.py

Lines changed: 46 additions & 22 deletions

@@ -27,7 +27,7 @@

 import verl.utils.torch_functional as verl_F
 from verl import DataProto
-from verl.trainer.ppo.core_algos import agg_loss, get_policy_loss_fn, kl_penalty
+from verl.trainer.ppo.core_algos import agg_loss, get_policy_loss_fn, kl_penalty, compute_policy_loss_archer
 from verl.utils.attention_utils import index_first_axis, pad_input, rearrange, unpad_input
 from verl.utils.device import get_device_id, get_device_name
 from verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_
@@ -427,7 +427,9 @@ def update_policy(self, data: DataProto):
                 loss_scale_factor = 1 / self.gradient_accumulation

                 # all return: (bsz, response_length)
-                calculate_entropy = False
+                loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")
+                is_archer = (loss_mode == "archer")
+                calculate_entropy = is_archer
                 if entropy_coeff != 0:
                     calculate_entropy = True
                 entropy, log_prob, aux_loss = self._forward_micro_batch(
@@ -439,20 +441,31 @@ def update_policy(self, data: DataProto):
                 else:
                     old_log_prob = model_inputs["old_log_probs"]

-                loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")
                 # vanilla -> verl.trainer.ppo.core_algos.compute_policy_loss_vanilla
                 # gpg -> verl.trainer.ppo.core_algos.compute_policy_loss_gpg
                 # clip_cov -> verl.trainer.ppo.core_algos.compute_policy_loss_clip_cov
-                policy_loss_fn = get_policy_loss_fn(loss_mode)
-                pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(
-                    old_log_prob=old_log_prob,
-                    log_prob=log_prob,
-                    advantages=advantages,
-                    response_mask=response_mask,
-                    loss_agg_mode=loss_agg_mode,
-                    config=self.config,
-                    rollout_log_probs=rollout_log_probs,
-                )
+                if is_archer:
+                    pg_loss, pg_clipfrac_upper, pg_clipfrac_lower, negative_pg_clipfrac_dual, positive_pg_clipfrac_dual = compute_policy_loss_archer(
+                        old_log_prob=old_log_prob,
+                        log_prob=log_prob,
+                        advantages=advantages,
+                        response_mask=response_mask,
+                        loss_agg_mode=loss_agg_mode,
+                        config=self.config,
+                        rollout_log_probs=rollout_log_probs,
+                        entropy=entropy,
+                    )
+                else:
+                    policy_loss_fn = get_policy_loss_fn(loss_mode)
+                    pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(
+                        old_log_prob=old_log_prob,
+                        log_prob=log_prob,
+                        advantages=advantages,
+                        response_mask=response_mask,
+                        loss_agg_mode=loss_agg_mode,
+                        config=self.config,
+                        rollout_log_probs=rollout_log_probs,
+                    )

                 if entropy_coeff != 0:
                     entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
@@ -491,15 +504,26 @@ def update_policy(self, data: DataProto):
                 else:
                     loss = policy_loss * loss_scale_factor
                     loss.backward()
-
-                micro_batch_metrics.update(
-                    {
-                        "actor/pg_loss": pg_loss.detach().item() * loss_scale_factor,
-                        "actor/pg_clipfrac": pg_clipfrac.detach().item(),
-                        "actor/ppo_kl": ppo_kl.detach().item(),
-                        "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
-                    }
-                )
+
+                if is_archer:
+                    micro_batch_metrics.update(
+                        {
+                            "actor/pg_loss": pg_loss.detach().item() * loss_scale_factor,
+                            "actor/pg_clipfrac_upper": pg_clipfrac_upper.detach().item(),
+                            "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
+                            "actor/negative_pg_clipfrac_dual": negative_pg_clipfrac_dual.detach().item(),
+                            "actor/positive_pg_clipfrac_dual": positive_pg_clipfrac_dual.detach().item(),
+                        }
+                    )
+                else:
+                    micro_batch_metrics.update(
+                        {
+                            "actor/pg_loss": pg_loss.detach().item() * loss_scale_factor,
+                            "actor/pg_clipfrac": pg_clipfrac.detach().item(),
+                            "actor/ppo_kl": ppo_kl.detach().item(),
+                            "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
+                        }
+                    )
                 append_to_dict(metrics, micro_batch_metrics)

                 grad_norm = self._optimizer_step()
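Note (not part of the commit): a hypothetical actor-config fragment showing the fields the archer branch reads, built with OmegaConf since the actor config is a DictConfig. The key placement inside a full verl run config may differ, and the numeric values are illustrative, not recommendations.

from omegaconf import OmegaConf

# Hypothetical actor config; keys mirror what the diff reads.
actor_config = OmegaConf.create({
    "policy_loss": {"loss_mode": "archer"},  # routes update_policy into compute_policy_loss_archer
    "clip_ratio_low": 0.2,                   # read as config.clip_ratio_low
    "clip_ratio_high": 0.28,                 # read as config.clip_ratio_high
    "clip_ratio_c": 3.0,                     # dual-clip cap, config.get("clip_ratio_c", 3.0)
    "token_entropy_quantile": 0.8,           # per-sequence entropy split, config.get(..., 0.8)
})

loss_mode = actor_config.policy_loss.get("loss_mode", "vanilla")
print(loss_mode == "archer")  # True

Because calculate_entropy is set from is_archer, this mode always computes per-token entropy in the forward pass (even with entropy_coeff == 0) and logs the archer-specific clip-fraction metrics instead of the vanilla ones.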
