
[GRPO] Fix loss normalization #2881

Merged

edbeeching merged 3 commits into main from fix-grpo-loss-normalization on Feb 17, 2025

Conversation

edbeeching
Collaborator

What does this PR do?

The current GRPO implementation uses per-sequence normalization; this PR corrects it to global normalization.

Details:
In Causal Language Modelling, we typically use global normalization to scale the loss, so that each unmasked token's loss provides the same contribution to the total loss. Example from transformers codebase: https://github.com/huggingface/transformers/blob/fae0f3dde83b7a54441f7a5bb0fc45d354fe81ce/src/transformers/loss/loss_utils.py#L24-L29

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@edbeeching edbeeching merged commit 293b620 into main Feb 17, 2025
14 checks passed
@edbeeching edbeeching deleted the fix-grpo-loss-normalization branch February 17, 2025 12:26
@kashif kashif mentioned this pull request Feb 18, 2025
@BramVanroy
Contributor

As seen on Twitter, some discussion about this change: https://x.com/danielhanchen/status/1900844864134410695

@gameofdimension

should we make it configurable?

yxliu-TAMU pushed a commit to mincheolseong/ECEN743-GRPO-Project-Proposal that referenced this pull request Apr 20, 2025
* fix GRPO loss normalization

* fix sum dim

* fix loss= repeated

tyler-griggs added a commit to NovaSky-AI/SkyRL that referenced this pull request Jul 15, 2025
## What does this PR do?
Adds support for token-level loss (i.e., the `token_mean` loss reduction type)
as introduced by DAPO.

With `token_mean` loss reduction, all tokens in all sequences contribute
equally to loss.

The loss reduction type is configurable via
`trainer.algorithm.loss_reduction`, but the default is updated to be
`token_mean`, as opposed to our previous implementation
(`sequence_mean`). This loss reduction is what the community is
standardizing on as default (TRL's
[default](huggingface/trl#2881), verl's
[default](https://github.com/volcengine/verl/blob/517cc23c9dbb0da5c2cd2b012466790e29cb781a/verl/trainer/config/actor/actor.yaml#L63)).
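
As an illustration of how such a switch can be exposed, here is a small sketch of a reduction helper; the function name and signature are hypothetical, not SkyRL's or verl's actual API:

```python
import torch

def reduce_policy_loss(per_token_loss: torch.Tensor,
                       response_mask: torch.Tensor,
                       loss_reduction: str = "token_mean") -> torch.Tensor:
    """Hypothetical helper illustrating the two reduction modes discussed above.

    per_token_loss: (batch, seq_len) policy-gradient loss per token
    response_mask:  (batch, seq_len) 1 for response tokens, 0 elsewhere
    """
    if loss_reduction == "token_mean":
        # DAPO-style: every response token in the batch weighs equally.
        return (per_token_loss * response_mask).sum() / response_mask.sum()
    if loss_reduction == "sequence_mean":
        # Previous default: average within each sequence, then across sequences.
        per_sequence = (per_token_loss * response_mask).sum(dim=1) / response_mask.sum(dim=1).clamp(min=1)
        return per_sequence.mean()
    raise ValueError(f"Unknown loss_reduction: {loss_reduction!r}")
```

With `token_mean`, longer responses contribute proportionally more tokens to the gradient, whereas `sequence_mean` weighs every sequence equally regardless of length.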

Wandb report of comparing `token_mean` vs `sequence_mean`:
https://wandb.ai/sky-posttraining-uc-berkeley/gsm8k/reports/Token-level-loss-token_mean---VmlldzoxMzYwMDc4MQ

The only plot with a notable difference is `policy_loss`, which is much
larger for `token_mean` than it is for `sequence_mean`:
[Screenshot: policy_loss curves for token_mean vs sequence_mean](https://github.com/user-attachments/assets/40f94cb6-c5e5-47f6-9b09-a076811746a0)

However, this `policy_loss` matches the magnitude of the `pg_loss` we observe in verl:
[Screenshot: pg_loss curve from verl for comparison](https://github.com/user-attachments/assets/53714573-2b21-4e67-b30a-dd3648279438)

---------

Co-authored-by: Sumanth R Hegde <[email protected]>