[algo] feat: add GSPO-token policy loss computation function #2775
@@ -31,6 +31,7 @@
 import verl.utils.torch_functional as verl_F
 from verl.trainer.config import AlgoConfig
 from verl.utils.import_utils import deprecated
+from verl.workers.config import ActorConfig

 PolicyLossFn = Callable[
     [
@@ -886,6 +887,64 @@ def compute_policy_loss_vanilla(
     return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower

@register_policy_loss("gspo")
def compute_policy_loss_gspo(
    old_log_prob: torch.Tensor,
    log_prob: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    loss_agg_mode: str = "token-mean",
    config: Optional[DictConfig | ActorConfig] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute the clipped policy objective and related metrics for GSPO.

    See https://arxiv.org/pdf/2507.18071 for more details.

    Args:
        old_log_prob (torch.Tensor):
            Log-probabilities of actions under the old policy, shape (batch_size, response_length).
        log_prob (torch.Tensor):
            Log-probabilities of actions under the current policy, shape (batch_size, response_length).
        advantages (torch.Tensor):
            Advantage estimates for each action, shape (batch_size, response_length).
        response_mask (torch.Tensor):
            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
        loss_agg_mode (str, optional):
            Aggregation mode for `agg_loss`. Defaults to "token-mean".
    """
    assert config is not None
    assert isinstance(config, ActorConfig)
    clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else config.clip_ratio
    clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else config.clip_ratio

    negative_approx_kl = log_prob - old_log_prob

    # compute sequence-level importance ratio
    seq_lengths = torch.sum(response_mask, dim=-1)
    negative_approx_kl_seq = torch.sum(negative_approx_kl * response_mask, dim=-1) / seq_lengths.clamp(min=1)
    log_seq_importance_ratio = negative_approx_kl_seq.detach().unsqueeze(-1) + log_prob - log_prob.detach()

    # clamp log_seq_importance_ratio for numerical stability
    log_seq_importance_ratio = torch.clamp(log_seq_importance_ratio, min=-20.0, max=20.0)
    seq_importance_ratio = torch.exp(log_seq_importance_ratio)

    pg_losses1 = -advantages * seq_importance_ratio
    pg_losses2 = -advantages * torch.clamp(seq_importance_ratio, 1 - clip_ratio_low, 1 + clip_ratio_high)
    pg_losses = torch.maximum(pg_losses1, pg_losses2)
    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
Review thread on the `agg_loss` call above:

- Reviewer: Should `agg_loss` here always be sequence-level?
- Reply: Yes, "seq-mean-token-sum" looks most suitable. But going by the paper's definition, a "batch-mean-group-mean-token-sum" might be more accurate: sum the token losses within each sequence, take a mean of the sequence losses at the group level, then a mean/sum again at the batch level. There is a similar discussion in #2776.
- Reply (@mokeevdmitrii): This is GSPO-token, so it should be exactly "seq-mean-token-mean"; the example in #2776 is about token-mean vs. seq-mean. [Derivation screenshot from the paper omitted.]
- Author: @mokeevdmitrii You are right, we should use "seq-mean-token-mean" when using GSPO, thanks very much!
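To make the aggregation discussion concrete, here is a standalone sketch (not verl's actual `agg_loss` implementation, just a simplified reimplementation of two of its modes) showing how "token-mean" and "seq-mean-token-mean" weight sequences of different lengths:

```python
import torch

def token_mean(loss_mat, mask):
    # global average over all valid tokens in the batch
    return (loss_mat * mask).sum() / mask.sum()

def seq_mean_token_mean(loss_mat, mask):
    # average tokens within each sequence, then average over sequences
    per_seq = (loss_mat * mask).sum(dim=-1) / mask.sum(dim=-1).clamp(min=1)
    return per_seq.mean()

# two sequences of very different lengths with different mean losses
loss_mat = torch.tensor([[1.0, 1.0, 1.0, 1.0],
                         [4.0, 0.0, 0.0, 0.0]])
mask = torch.tensor([[1.0, 1.0, 1.0, 1.0],
                     [1.0, 0.0, 0.0, 0.0]])

print(token_mean(loss_mat, mask).item())           # 8/5 = 1.6, long sequence dominates
print(seq_mean_token_mean(loss_mat, mask).item())  # (1 + 4)/2 = 2.5, sequences weighted equally
```

Under "token-mean" the long sequence dominates the batch average; under "seq-mean-token-mean" each sequence contributes equally regardless of length, which is the behavior the thread converges on for GSPO.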
    # For compatibility, return zero for pg_clipfrac_lower (not used in standard GSPO)
    pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)
    pg_clipfrac_lower = torch.tensor(0.0, device=pg_loss.device)

    negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)

    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower

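The `detach` construction in `log_seq_importance_ratio` deserves a note: in the forward pass every token carries the same sequence-level ratio, while in the backward pass gradients flow through each token's own `log_prob` (the `log_prob - log_prob.detach()` term is zero-valued but carries the gradient), which is what makes this the GSPO-token variant. A minimal sketch with made-up log-probabilities (values are illustrative only):

```python
import torch

# made-up per-token log-probs for one sequence of 3 tokens
old_log_prob = torch.tensor([[-1.0, -2.0, -3.0]])
log_prob = torch.tensor([[-0.5, -2.5, -2.0]], requires_grad=True)
response_mask = torch.ones_like(old_log_prob)

negative_approx_kl = log_prob - old_log_prob
seq_lengths = response_mask.sum(dim=-1)
negative_approx_kl_seq = (negative_approx_kl * response_mask).sum(dim=-1) / seq_lengths.clamp(min=1)

# forward value comes from the detached sequence-level mean; the
# (log_prob - log_prob.detach()) term is zero in the forward pass but
# routes gradients through each token's own log-probability
log_ratio = negative_approx_kl_seq.detach().unsqueeze(-1) + log_prob - log_prob.detach()
ratio = torch.exp(log_ratio)

# forward: every token carries the same sequence-level ratio
assert torch.allclose(ratio, torch.exp(negative_approx_kl_seq).expand_as(ratio))

# backward: each token gets its own gradient, d ratio_t / d log_prob_t = ratio_t
ratio.sum().backward()
assert torch.allclose(log_prob.grad, ratio.detach())
```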
@register_policy_loss("gpg")
def compute_policy_loss_gpg(old_log_prob, log_prob, advantages, response_mask, loss_agg_mode="token-mean", config=None):
    """Adapted from
Additional review comments on this diff:

- @chujiezheng: You should change the calculation order to avoid precision error.
- Author: Thanks @chujiezheng! All resolved.
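The precision concern raised here is a general one when exponentiating log-ratios: exponentiating before bounding lets a large log-ratio overflow fp32, whereas clamping in log space first (as the merged diff does with `torch.clamp(..., min=-20.0, max=20.0)`) keeps the result finite. The pre-fix ordering is not shown in the thread, so this is only a generic sketch of the failure mode:

```python
import torch

log_ratio = torch.tensor([100.0, -100.0, 0.1])

# naive: exponentiate first -- exp(100) overflows fp32 to inf
naive = torch.exp(log_ratio)

# stable: clamp in log space first, then exponentiate, as the diff does
stable = torch.exp(torch.clamp(log_ratio, min=-20.0, max=20.0))

assert torch.isinf(naive).any()
assert torch.isfinite(stable).all()
```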