[GRPO] Adds an option to scale the loss by a constant factor #3231
```diff
@@ -101,6 +101,8 @@ class GRPOConfig(TrainingArguments):
             speed, but may be numerically unstable for long training runs.
         num_iterations (`int`, *optional*, defaults to `1`):
             Number of iterations per batch (denoted as μ in the algorithm).
+        use_max_tokens_norm (`bool`, *optional*, defaults to `False`):
+            Whether to use the max tokens norm. If `True`, the loss is normalized by a constant, the maximum
+            possible number of tokens.
         epsilon (`float`, *optional*, defaults to `0.2`):
             Epsilon value for clipping.
         epsilon_high (`float` or `None`, *optional*, defaults to `None`):
```

Review comment: Suggestion to clarify what we mean by "maximum possible". Suggested change:
```diff
@@ -275,6 +277,13 @@ class GRPOConfig(TrainingArguments):
         default=1,
         metadata={"help": "Number of iterations per batch (denoted as μ in the algorithm)."},
     )
+    use_max_tokens_norm: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to use the max tokens norm. If `True`, the loss is normalized by a constant, the maximum "
+            "possible number of tokens."
+        },
+    )
     epsilon: float = field(
         default=0.2,
         metadata={"help": "Epsilon value for clipping."},
```

Review comment: Ditto here if you agree with the change above.
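Assuming the flag lands as shown in the diff, enabling it from the config might look like the following sketch (`output_dir` and the batch/length values here are illustrative, not taken from the PR):

```python
from trl import GRPOConfig  # assumes a TRL build that includes this PR

config = GRPOConfig(
    output_dir="grpo-out",            # illustrative path
    per_device_train_batch_size=8,    # illustrative value
    max_prompt_length=512,            # illustrative value
    max_completion_length=256,        # illustrative value
    use_max_tokens_norm=True,         # normalize the loss by 8 * (512 + 256) instead of mask.sum()
)
```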
```diff
@@ -409,6 +409,15 @@ def data_collator(features):  # No data collation is needed in GRPO
         self.use_vllm = args.use_vllm
         self.use_liger_loss = args.use_liger_loss

+        self.use_max_tokens_norm = args.use_max_tokens_norm
+        if self.use_max_tokens_norm:
+            if self.use_liger_loss:
+                raise ValueError("`use_max_tokens_norm` is not supported with `liger_loss`.")
+            # calculate a constant factor to normalize the loss
+            self.max_tokens_norm = args.per_device_train_batch_size * (
+                args.max_prompt_length + args.max_completion_length
+            )
+
         # Multi-step
         self.num_iterations = args.num_iterations  # = 𝜇 in the GRPO paper
         self.epsilon_low = args.epsilon
```
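For illustration, the constant computed in this hunk with assumed toy config values (plain Python, names mirror the diff):

```python
# Toy values standing in for GRPOConfig fields (assumed for illustration).
per_device_train_batch_size = 8
max_prompt_length = 512
max_completion_length = 256

# The normalizer is the largest token count a local batch can contain:
# every sequence padded out to the full prompt + completion budget.
max_tokens_norm = per_device_train_batch_size * (max_prompt_length + max_completion_length)
print(max_tokens_norm)  # 8 * 768 = 6144
```

Note the constant depends only on the config, never on the sampled completions, which is what makes the scaling length-independent.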
```diff
@@ -1072,7 +1081,11 @@ def _compute_loss(self, model, inputs):
         per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
         if self.beta != 0.0:
             per_token_loss = per_token_loss + self.beta * per_token_kl
-        loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
+
+        if self.use_max_tokens_norm:
+            loss = (per_token_loss * completion_mask).sum() / self.max_tokens_norm
+        else:
+            loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()

         # Log the metrics
         mode = "eval" if self.control.should_evaluate else "train"
```

Review comment: I'm not sure how easy it is to unit test this, but would it make sense to do so, so that we're sure the loss is being computed as your diagrams show? E.g. an integration test would be to check that specifying the config params gives the expected scaling for some dummy inputs.
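The scaling check suggested in the review comment is straightforward with dummy inputs. Here is a plain-Python sketch (toy numbers, nested lists standing in for tensors) of how the two branches scale the same masked per-token loss differently:

```python
# Per-token losses for a batch of 2 sequences, padded to length 4 (toy numbers).
per_token_loss = [[0.5, 0.5, 0.5, 0.5],
                  [0.5, 0.5, 0.0, 0.0]]
completion_mask = [[1, 1, 1, 1],
                   [1, 1, 0, 0]]  # second completion is only 2 tokens long

masked_sum = sum(l * m for row_l, row_m in zip(per_token_loss, completion_mask)
                 for l, m in zip(row_l, row_m))           # 3.0
active_tokens = sum(m for row in completion_mask for m in row)  # 6

# Default branch: normalize by the number of unmasked tokens (length-dependent).
loss_default = masked_sum / active_tokens                 # 0.5

# `use_max_tokens_norm` branch: divide by a constant (here batch_size * max_len),
# so shorter completions contribute proportionally less to the loss.
max_tokens_norm = 2 * 4
loss_constant = masked_sum / max_tokens_norm              # 0.375
```

An integration test along these lines would pin down that toggling the config flag changes only the denominator, exactly as the reviewer asks.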
Review comment: So this is the loss proposed in Dr GRPO, correct? If so, I think it should be explicitly mentioned in the doc.

Reply: I am not sure, actually; I thought that was our current implementation. I will take another look.

Reply: Currently we use a modified version of DAPO, where we normalize per local batch (and not per group). In the figure above, we use something between BNPO (hard to implement with gradient accumulation) and DAPO.
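To make the thread's comparison concrete, here is a rough plain-Python sketch of the two denominators under discussion on a dummy local batch (my reading of the thread, not TRL's exact code; all numbers illustrative):

```python
# Completion lengths for one local batch of 4 sampled completions (toy values).
lengths = [3, 1, 4, 2]
max_completion_length = 4  # assumed generation budget

# Pretend every unmasked token contributes a per-token loss of 1.0.
token_sums = [1.0 * n for n in lengths]

# DAPO-style, normalized per local batch (what the thread says TRL does today):
# total per-token loss divided by the tokens actually sampled in this batch.
dapo_local = sum(token_sums) / sum(lengths)

# Dr GRPO-style constant normalizer (what `use_max_tokens_norm` would add):
# the denominator is fixed by the budget, independent of sampled lengths.
constant_norm = sum(token_sums) / (len(lengths) * max_completion_length)
```

With the sampled-token denominator the result is invariant to completion length; with the constant denominator, batches of short completions yield a strictly smaller loss, which is the length-bias fix attributed to Dr GRPO above.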