81 changes: 81 additions & 0 deletions tests/test_grpo_trainer.py
@@ -1147,3 +1147,84 @@ def test_training_num_generations_larger_than_batch_size(self):
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

@staticmethod
def _make_delta_trainer(tmp_dir, tokenizer, dataset):
"""Helper method to create a GRPOTrainer with specific delta clipping parameters."""
cfg = GRPOConfig(
output_dir=tmp_dir,
epsilon=0.20,
delta=2.0,
epsilon_high=0.20,
beta=0.0,
loss_type="bnpo",
max_completion_length=2,
report_to="none",
)
return GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
args=cfg,
reward_funcs=lambda x, **y: [1.0] * len(x),
train_dataset=dataset,
processing_class=tokenizer,
)

@staticmethod
def _delta_inputs(device):
"""Helper method to create standard inputs for delta clipping tests."""
return {
"prompt_ids": torch.tensor([[101]], device=device),
"prompt_mask": torch.tensor([[1]], device=device),
"completion_ids": torch.tensor([[2000, 2001]], device=device),
"completion_mask": torch.tensor([[1, 1]], device=device),
}

@parameterized.expand(
[
# name, advantage, old_prob, new_prob, expected_loss
("pos_ratio_in_clip", 2.0, 0.50, 0.55, -2.2),
("pos_ratio_above_clip", 2.0, 0.40, 0.60, -2.4),
("neg_ratio_in_clip", -2.0, 0.50, 0.45, 1.8),
("neg_ratio_below_clip", -2.0, 0.50, 0.35, 1.6),
("neg_ratio_above_delta", -2.0, 0.20, 0.50, 4.0),
("neg_between_clip_delta", -2.0, 0.40, 0.60, 3.0),
]
)
def test_two_sided_clipping_loss(self, name, advantage, old_prob, new_prob, expected_loss):
"""Test two-sided GRPO clipping logic with different scenarios.

Args:
name: Test case name
advantage: Advantage value for the scenario
old_prob: Old policy probability
new_prob: New policy probability
expected_loss: Expected loss value
"""
with tempfile.TemporaryDirectory() as tmp_dir:
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
tokenizer = AutoTokenizer.from_pretrained(model_id)

dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train[:1]")

trainer = self._make_delta_trainer(tmp_dir, tokenizer, dataset)

inputs = self._delta_inputs(trainer.accelerator.device)
inputs.update(
{
"advantages": torch.tensor([advantage], device=trainer.accelerator.device),
"old_per_token_logps": torch.log(torch.tensor([[old_prob]], device=trainer.accelerator.device)),
}
)

# Mock _get_per_token_logps to return predefined new log probabilities
with patch.object(trainer, "_get_per_token_logps") as mock_logps_func:
mock_logps_func.return_value = torch.log(torch.tensor([[new_prob]], device=trainer.accelerator.device))

# Compute loss and verify
loss = trainer.compute_loss(trainer.model, inputs)
self.assertAlmostEqual(
loss.item(),
expected_loss,
delta=1e-5,
msg=f"Scenario {name} failed: expected {expected_loss}, got {loss.item()}",
)
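For reference, the expected losses in the parameterized cases above can be reproduced by hand. Below is a minimal standalone sketch (not part of this PR; `two_sided_loss` is a hypothetical helper) that applies the two-sided rule with `epsilon_low = epsilon_high = 0.2` and `delta = 2.0`, the values set in `_make_delta_trainer`:

```python
import torch

def two_sided_loss(advantage, old_prob, new_prob, eps=0.2, delta=2.0):
    # Importance ratio r = pi_new / pi_old, i.e. exp(logp_new - logp_old).
    ratio = torch.tensor(new_prob / old_prob)
    # Unclipped term, capped above at delta (the two-sided bound).
    loss1 = torch.clamp(ratio, max=delta) * advantage
    # Standard PPO-style clipped term.
    loss2 = torch.clamp(ratio, 1 - eps, 1 + eps) * advantage
    return (-torch.min(loss1, loss2)).item()

# Spot-check a few of the parameterized scenarios:
assert abs(two_sided_loss(2.0, 0.50, 0.55) - (-2.2)) < 1e-6  # pos_ratio_in_clip
assert abs(two_sided_loss(-2.0, 0.20, 0.50) - 4.0) < 1e-6    # neg_ratio_above_delta
assert abs(two_sided_loss(-2.0, 0.40, 0.60) - 3.0) < 1e-6    # neg_between_clip_delta
```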
10 changes: 10 additions & 0 deletions trl/trainer/grpo_config.py
@@ -136,6 +136,8 @@ class GRPOConfig(TrainingArguments):
Number of iterations per batch (denoted as μ in the algorithm).
epsilon (`float`, *optional*, defaults to `0.2`):
Epsilon value for clipping.
delta (`float` or `None`, *optional*, defaults to `None`):
Delta value for the upper clipping bound in two-sided GRPO. If `None` (default), standard GRPO clipping is used. Recommended to be greater than `1 + epsilon` when enabled. This method was introduced in the [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291).
epsilon_high (`float` or `None`, *optional*, defaults to `None`):
Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the lower-bound
specified in argument `epsilon`. Paper [DAPO](https://huggingface.co/papers/2503.14476) recommends `0.28`.
@@ -389,6 +391,12 @@ class GRPOConfig(TrainingArguments):
default=0.2,
metadata={"help": "Epsilon value for clipping."},
)
delta: Optional[float] = field(
default=None,
metadata={
"help": "If set to a float value (e.g., 2.0), enables the upper clipping bound in two-sided GRPO loss. If None (default), the standard GRPO clipping is used. Recommended to be > 1 + epsilon when enabled."
},
)
epsilon_high: Optional[float] = field(
default=None,
metadata={
@@ -536,3 +544,5 @@ def __post_init__(self):
"current global eval batch size, the valid values for the number of generations are: "
f"{possible_values}."
)
if self.delta is not None and self.use_liger_loss:
raise ValueError("Liger loss does not support two-sided GRPO loss yet.")
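As a usage note (not part of the diff itself): the new bound is enabled purely through the config. A minimal sketch, with `output_dir` as a placeholder value:

```python
from trl import GRPOConfig

config = GRPOConfig(
    output_dir="grpo-two-sided",  # placeholder path
    epsilon=0.2,
    delta=2.0,              # enables the upper bound; None (default) keeps standard clipping
    use_liger_loss=False,   # the __post_init__ check above rejects delta combined with Liger loss
)
```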
9 changes: 8 additions & 1 deletion trl/trainer/grpo_trainer.py
@@ -1347,7 +1347,14 @@ def _compute_loss(self, model, inputs):
)
coef_1 = torch.exp(per_token_logps - old_per_token_logps)
coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
per_token_loss1 = coef_1 * advantages.unsqueeze(1)

if self.args.delta is not None:
# Use clamp instead of min to handle tensor-float comparison
per_token_loss1 = torch.clamp(coef_1, max=self.args.delta) * advantages.unsqueeze(1)
else:
# Original GRPO clipping (only lower bound implicitly applied by the final min)
per_token_loss1 = coef_1 * advantages.unsqueeze(1)

per_token_loss2 = coef_2 * advantages.unsqueeze(1)
per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
if self.beta != 0.0:
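In equation form, the hunk above computes the per-token objective (my rendering of the code, with $r_t$ the importance ratio and $\hat{A}_t$ the advantage):

$$\mathcal{L}_t = -\min\Big(\min(r_t, \delta)\,\hat{A}_t,\ \operatorname{clip}\big(r_t,\ 1-\epsilon_{\text{low}},\ 1+\epsilon_{\text{high}}\big)\,\hat{A}_t\Big), \qquad r_t = \exp\big(\log \pi_\theta - \log \pi_{\theta_\text{old}}\big)$$

When `delta` is `None`, the inner $\min(r_t, \delta)$ drops out and this reduces to the standard GRPO clipped objective; the cap only binds when $r_t > \delta$, bounding the loss magnitude for tokens with negative advantages that standard clipping leaves unbounded.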