9 changes: 9 additions & 0 deletions trl/trainer/grpo_config.py
@@ -101,6 +101,8 @@ class GRPOConfig(TrainingArguments):
speed, but may be numerically unstable for long training runs.
num_iterations (`int`, *optional*, defaults to `1`):
Number of iterations per batch (denoted as μ in the algorithm).
use_max_tokens_norm (`bool`, *optional*, defaults to `False`):
Member:
So this is the loss proposed in Dr GRPO, correct?
If so, I think it should be explicitly mentioned in the doc.

Collaborator (Author):
I am not sure actually, I thought that was our current implementation. I will take another look.

Member:
Currently we use a modified version of DAPO where we normalize per local batch (and not per group).

[Figure: comparison of loss-normalization schemes]

In the above figure, we use something between BNPO (hard to implement with grad accum) and DAPO.
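
For illustration, a minimal sketch of the reductions under discussion, using dummy tensors and made-up shapes (not TRL's actual code):

```python
import torch

# Dummy per-token losses for a device batch of 4 completions, up to 6 completion tokens each.
per_token_loss = torch.randn(4, 6).abs()
completion_mask = torch.ones(4, 6)
completion_mask[0, 4:] = 0  # pretend the first completion is shorter

# Sequence-level reduction (as in the original GRPO formulation):
# mean over each completion's tokens, then over completions.
seq_level = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1)).mean()

# Per-local-batch token reduction (current behaviour, DAPO-like within a device batch).
batch_token = (per_token_loss * completion_mask).sum() / completion_mask.sum()

# Max-tokens norm (this PR): divide by a fixed upper bound instead of the realized token count,
# e.g. per_device_train_batch_size * (max_prompt_length + max_completion_length) = 4 * (3 + 6).
max_tokens = (per_token_loss * completion_mask).sum() / (4 * (3 + 6))
```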

Whether to use the max tokens norm. If `True`, the loss is normalized by a consant, the maximum possible number of tokens
Member:
Suggestion to clarify what we mean by "maximum possible"

Suggested change
Whether to use the max tokens norm. If `True`, the loss is normalized by a consant, the maximum possible number of tokens
Whether to use the max tokens norm. If `True`, the loss is normalized by a constant factor that is determined by the total number of prompt and completion tokens in a batch.
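
For what it's worth, a tiny illustration of the constant the suggested wording refers to (the numbers are made up; the formula mirrors the computation added to the trainer below):

```python
# Illustrative numbers only.
per_device_train_batch_size = 8
max_prompt_length = 512
max_completion_length = 256

max_tokens_norm = per_device_train_batch_size * (max_prompt_length + max_completion_length)
# 8 * (512 + 256) = 6144 — the summed per-token loss is divided by this fixed value every step.
```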

epsilon (`float`, *optional*, defaults to `0.2`):
Epsilon value for clipping.
epsilon_high (`float` or `None`, *optional*, defaults to `None`):
@@ -275,6 +277,13 @@ class GRPOConfig(TrainingArguments):
default=1,
metadata={"help": "Number of iterations per batch (denoted as μ in the algorithm)."},
)
use_max_tokens_norm: bool = field(
default=False,
metadata={
"help": "Whether to use the max tokens norm. If `True`, the loss is normalized by a constant, the maximum "
Member:
Ditto here if you agree with the change above

"possible number of tokens."
},
)
epsilon: float = field(
default=0.2,
metadata={"help": "Epsilon value for clipping."},
15 changes: 14 additions & 1 deletion trl/trainer/grpo_trainer.py
@@ -409,6 +409,15 @@ def data_collator(features): # No data collation is needed in GRPO
self.use_vllm = args.use_vllm
self.use_liger_loss = args.use_liger_loss

self.use_max_tokens_norm = args.use_max_tokens_norm
if self.use_max_tokens_norm:
if self.use_liger_loss:
raise ValueError("`use_max_tokens_norm` is not supported with `liger_loss`.")
# calculate a constant factor to normalize the loss
self.max_tokens_norm = args.per_device_train_batch_size * (
args.max_prompt_length + args.max_completion_length
)

# Multi-step
self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper
self.epsilon_low = args.epsilon
@@ -1072,7 +1081,11 @@ def _compute_loss(self, model, inputs):
per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
if self.beta != 0.0:
per_token_loss = per_token_loss + self.beta * per_token_kl
loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()

if self.use_max_tokens_norm:
loss = (per_token_loss * completion_mask).sum() / self.max_tokens_norm
Member:
I'm not sure how easy it is to unit test this, but would it make sense to add one so that we're sure the loss is being computed as your diagrams show?

E.g. an integration test could check that specifying the config params gives the expected scaling for some dummy inputs.
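
Something along these lines, perhaps (a rough sketch with dummy tensors; the test name and numbers are made up, not an existing test):

```python
import torch

def test_max_tokens_norm_scaling():
    # Dummy per-token losses and a mask with 5 real completion tokens in total.
    per_token_loss = torch.full((2, 4), 0.5)
    completion_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.float)

    # per_device_train_batch_size * (max_prompt_length + max_completion_length)
    max_tokens_norm = 2 * (3 + 4)

    default_loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
    max_norm_loss = (per_token_loss * completion_mask).sum() / max_tokens_norm

    # The max-tokens norm should rescale the default loss by mask.sum() / max_tokens_norm.
    expected_ratio = completion_mask.sum() / max_tokens_norm
    torch.testing.assert_close(max_norm_loss, default_loss * expected_ratio)
```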

else:
loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()

# Log the metrics
mode = "eval" if self.control.should_evaluate else "train"