
feat: Implement Two-Sided Clipping for GRPO Trainer #3434


Merged
merged 18 commits on May 13, 2025
Changes from 5 commits
219 changes: 208 additions & 11 deletions tests/test_grpo_trainer.py
@@ -1031,19 +1031,39 @@ def test_training_no_scale_rewards(self):
def test_training_with_mask_truncated_completions(self, mock_generate):
"""Test that training works with mask_truncated_completions=True parameter."""

# Initialize tokenizer locally for this test
model_id_for_tokenizer = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
tokenizer = AutoTokenizer.from_pretrained(model_id_for_tokenizer)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

# We mock the generate method because the model's random weights make it extremely unlikely to produce a
# sequence containing the EOS token within the allowed max_completion_length. As a result, all tokens are
# masked in the loss, the model doesn't update, and the final check (which verifies the update) fails.
def fake_generate(prompt_ids, **kwargs):
batch_size = prompt_ids.shape[0]
max_completion_length = kwargs.get(
"generation_config"
).max_new_tokens # Get max_new_tokens from GenerationConfig

# pad_token_id = 151643; eos_token_id = 151645
completions_ids = torch.tensor(
base_completions = torch.tensor(
[
[1, 2, 3, 4, 5, 6, 7, 8], # this one is truncated
[9, 10, 11, 151645, 151643, 151643, 151643, 151643], # this one contains eos
[12, 13, 14, 15, 16, 17, 18, 151645], # particular case, eos is generated just within the limit
[1] * max_completion_length, # Truncated example
[9] * (max_completion_length // 2)
+ [tokenizer.eos_token_id]
+ [tokenizer.pad_token_id]
* (max_completion_length - max_completion_length // 2 - 1), # EOS example
[12] * (max_completion_length - 1) + [tokenizer.eos_token_id], # EOS at the end example
],
device=prompt_ids.device,
dtype=torch.long,
)
# Repeat the base completions to match the required batch size
completions_ids = base_completions.repeat(
(batch_size + base_completions.shape[0] - 1) // base_completions.shape[0], 1
)[:batch_size] # Ensure correct batch size

return torch.cat([prompt_ids, completions_ids], dim=1)

mock_generate.side_effect = fake_generate
@@ -1054,9 +1074,9 @@ def fake_generate(prompt_ids, **kwargs):
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=8, # reduce the completion length to reduce memory usage
per_device_train_batch_size=3, # keep batch size consistent with mock data pattern if possible
num_generations=3, # kept for consistency with mock data pattern
max_completion_length=8, # should match mock data length
mask_truncated_completions=True, # Enable masking of truncated completions
report_to="none",
)
@@ -1065,15 +1085,13 @@ def fake_generate(prompt_ids, **kwargs):
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset,
processing_class=tokenizer, # Use local tokenizer
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

# Check that the params have changed
# Check if the model parameters have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
@@ -1147,3 +1165,182 @@ def test_training_num_generations_larger_than_batch_size(self):
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

def test_two_sided_clipping_loss(self):
"""
Tests the two-sided GRPO clipping logic with specific scenarios.
Uses a completion length of 2 to ensure logp and loss masking work.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

# Using a minimal dataset, actual content won't matter due to mocking
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train[:1]")

config = GRPOConfig(
output_dir=tmp_dir,
epsilon=0.2,
delta=2.0, # delta > 1 + epsilon (1.2)
epsilon_high=0.2, # Corresponds to self.epsilon_high in trainer
beta=0.0, # No KL divergence for this specific loss test
loss_type="bnpo", # Default, simplifies loss for single token
max_completion_length=2, # Use completion length 2
report_to="none",
)

trainer = GRPOTrainer(
model=model_id, # Use local model_id
args=config,
reward_funcs=lambda x, **y: [1.0] * len(x), # Dummy reward always returning 1.0
train_dataset=dataset,
processing_class=tokenizer, # Use local tokenizer
)

# Mock necessary components to isolate loss calculation
# Mock data for completion_length = 2. Logps/Advantages apply to first generated token (index 1).
mock_inputs = {
"prompt_ids": torch.tensor([[101]], device=trainer.accelerator.device),
"prompt_mask": torch.tensor([[1]], device=trainer.accelerator.device),
# Shape: (batch_size, seq_len=2)
"completion_ids": torch.tensor([[2000, 2001]], device=trainer.accelerator.device),
"completion_mask": torch.tensor(
[[1, 1]], device=trainer.accelerator.device
), # Mask for generated tokens
# Logps/Advantages shape: (batch_size, seq_len=1) - corresponding to the token at index 1
# --- Test Scenarios --- # Notes refer to the per_token_loss for the *single* token being considered.
# Scenario 1: Positive Advantage, ratio within clip bounds (1-eps, 1+eps) -> loss = -2.2
"scenario1_advantages": torch.tensor([2.0], device=trainer.accelerator.device), # Shape (1,)
"scenario1_old_logps": torch.log(
torch.tensor([[0.5]], device=trainer.accelerator.device)
), # Shape (1, 1)
"scenario1_new_logps": torch.log(
torch.tensor([[0.55]], device=trainer.accelerator.device)
), # Shape (1, 1), ratio = 1.1
# Scenario 2: Positive Advantage, ratio above 1+eps -> loss = -2.4
"scenario2_advantages": torch.tensor([2.0], device=trainer.accelerator.device),
"scenario2_old_logps": torch.log(torch.tensor([[0.4]], device=trainer.accelerator.device)),
"scenario2_new_logps": torch.log(
torch.tensor([[0.6]], device=trainer.accelerator.device)
), # ratio = 1.5
# Scenario 3: Negative Advantage, ratio within clip bounds (1-eps, 1+eps) -> loss = 1.8
"scenario3_advantages": torch.tensor([-2.0], device=trainer.accelerator.device),
"scenario3_old_logps": torch.log(torch.tensor([[0.5]], device=trainer.accelerator.device)),
"scenario3_new_logps": torch.log(
torch.tensor([[0.45]], device=trainer.accelerator.device)
), # ratio = 0.9
# Scenario 4: Negative Advantage, ratio below 1-eps -> loss = 1.6
"scenario4_advantages": torch.tensor([-2.0], device=trainer.accelerator.device),
"scenario4_old_logps": torch.log(torch.tensor([[0.5]], device=trainer.accelerator.device)),
"scenario4_new_logps": torch.log(
torch.tensor([[0.35]], device=trainer.accelerator.device)
), # ratio = 0.7
# Scenario 5: Negative Advantage, ratio above delta (2.0) -> loss = 4.0
"scenario5_advantages": torch.tensor([-2.0], device=trainer.accelerator.device),
"scenario5_old_logps": torch.log(torch.tensor([[0.2]], device=trainer.accelerator.device)),
"scenario5_new_logps": torch.log(
torch.tensor([[0.5]], device=trainer.accelerator.device)
), # ratio = 2.5
# Scenario 6: Negative Advantage, ratio between 1+eps and delta -> loss = 3.0
"scenario6_advantages": torch.tensor([-2.0], device=trainer.accelerator.device),
"scenario6_old_logps": torch.log(torch.tensor([[0.4]], device=trainer.accelerator.device)),
"scenario6_new_logps": torch.log(
torch.tensor([[0.6]], device=trainer.accelerator.device)
), # ratio = 1.5
}

# Mock _get_per_token_logps to return predefined values based on scenario
def mock_get_logps_side_effect(model, input_ids, attention_mask, logits_to_keep, batch_size=None):
# The actual return value is controlled by setting .return_value before each call
# This side_effect is just a placeholder structure
if model == trainer.model:
# This function is expected to return shape (batch_size, logits_to_keep)
# In our case, logits_to_keep = completion_ids.shape[1] - 1 = 2 - 1 = 1
# So, shape should be (1, 1)
return torch.zeros(
(input_ids.shape[0], logits_to_keep), device=trainer.accelerator.device
) # Placeholder
return torch.zeros_like(
input_ids[:, 1:], dtype=torch.float, device=input_ids.device
) # Match expected output shape

with patch.object(trainer, "_get_per_token_logps") as mock_logps_func: # Removed side_effect here
# --- Run Scenarios --- #
base_inputs = { # Inputs common to all scenarios
"prompt_ids": mock_inputs["prompt_ids"],
"prompt_mask": mock_inputs["prompt_mask"],
"completion_ids": mock_inputs["completion_ids"],
"completion_mask": mock_inputs["completion_mask"],
}

# Scenario 1
inputs_sc1 = {
**base_inputs,
"advantages": mock_inputs["scenario1_advantages"],
"old_per_token_logps": mock_inputs["scenario1_old_logps"],
}
mock_logps_func.return_value = mock_inputs["scenario1_new_logps"] # Set mock return for this call
loss1 = trainer.compute_loss(trainer.model, inputs_sc1)
self.assertAlmostEqual(loss1.item(), -2.2, delta=1e-5, msg="Scenario 1 Failed")

# Scenario 2
inputs_sc2 = {
**base_inputs,
"advantages": mock_inputs["scenario2_advantages"],
"old_per_token_logps": mock_inputs["scenario2_old_logps"],
}
mock_logps_func.return_value = mock_inputs["scenario2_new_logps"]
loss2 = trainer.compute_loss(trainer.model, inputs_sc2)
self.assertAlmostEqual(
loss2.item(), -2.4, delta=1e-5, msg="Scenario 2 Failed"
) # Loss = -min(3.0, 2.4) = -2.4

# Scenario 3
inputs_sc3 = {
**base_inputs,
"advantages": mock_inputs["scenario3_advantages"],
"old_per_token_logps": mock_inputs["scenario3_old_logps"],
}
mock_logps_func.return_value = mock_inputs["scenario3_new_logps"]
loss3 = trainer.compute_loss(trainer.model, inputs_sc3)
self.assertAlmostEqual(
loss3.item(), 1.8, delta=1e-5, msg="Scenario 3 Failed"
) # Loss = -min(-1.8, -1.8) = 1.8

# Scenario 4
inputs_sc4 = {
**base_inputs,
"advantages": mock_inputs["scenario4_advantages"],
"old_per_token_logps": mock_inputs["scenario4_old_logps"],
}
mock_logps_func.return_value = mock_inputs["scenario4_new_logps"]
loss4 = trainer.compute_loss(trainer.model, inputs_sc4)
self.assertAlmostEqual(
loss4.item(), 1.6, delta=1e-5, msg="Scenario 4 Failed"
) # Loss = -min(-1.4, -1.6) = 1.6

# Scenario 5
inputs_sc5 = {
**base_inputs,
"advantages": mock_inputs["scenario5_advantages"],
"old_per_token_logps": mock_inputs["scenario5_old_logps"],
}
mock_logps_func.return_value = mock_inputs["scenario5_new_logps"]
loss5 = trainer.compute_loss(trainer.model, inputs_sc5)
self.assertAlmostEqual(
loss5.item(), 4.0, delta=1e-5, msg="Scenario 5 Failed"
) # Loss = -min(-4.0, -2.4) = 4.0

# Scenario 6
inputs_sc6 = {
**base_inputs,
"advantages": mock_inputs["scenario6_advantages"],
"old_per_token_logps": mock_inputs["scenario6_old_logps"],
}
mock_logps_func.return_value = mock_inputs["scenario6_new_logps"]
loss6 = trainer.compute_loss(trainer.model, inputs_sc6)
self.assertAlmostEqual(
loss6.item(), 3.0, delta=1e-5, msg="Scenario 6 Failed"
) # Loss = -min(-3.0, -2.4) = 3.0
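
For traceability, the expected losses asserted above follow from the two-sided clipping formula loss = -min(min(r, delta) * A, clip(r, 1 - eps, 1 + eps) * A) with eps = 0.2 and delta = 2.0. A minimal standalone check (not part of the diff) that reproduces the six target values:

def expected_two_sided_loss(ratio, advantage, eps=0.2, delta=2.0):
    # Unclipped term, additionally bounded above by delta (the two-sided part).
    unclipped = min(ratio, delta) * advantage
    # Standard PPO-style clipped term.
    clipped = min(max(ratio, 1 - eps), 1 + eps) * advantage
    return -min(unclipped, clipped)

# (ratio, advantage) pairs for scenarios 1-6 above.
scenarios = [(1.1, 2.0), (1.5, 2.0), (0.9, -2.0), (0.7, -2.0), (2.5, -2.0), (1.5, -2.0)]
print([round(expected_two_sided_loss(r, a), 4) for r, a in scenarios])
# [-2.2, -2.4, 1.8, 1.6, 4.0, 3.0]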
10 changes: 10 additions & 0 deletions trl/trainer/grpo_config.py
@@ -136,6 +136,8 @@ class GRPOConfig(TrainingArguments):
Number of iterations per batch (denoted as μ in the algorithm).
epsilon (`float`, *optional*, defaults to `0.2`):
Epsilon value for clipping.
delta (`float` or `None`, *optional*, defaults to `None`):
If set, enables the upper clipping bound in two-sided GRPO loss. If `None` (default), standard GRPO clipping is used. Recommended to be greater than `1 + epsilon`.
epsilon_high (`float` or `None`, *optional*, defaults to `None`):
Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the lower-bound
specified in argument `epsilon`. Paper [DAPO](https://huggingface.co/papers/2503.14476) recommends `0.28`.
@@ -389,6 +391,12 @@ class GRPOConfig(TrainingArguments):
default=0.2,
metadata={"help": "Epsilon value for clipping."},
)
delta: Optional[float] = field(
default=None,
metadata={
"help": "If set to a float value (e.g., 2.0), enables the upper clipping bound in two-sided GRPO loss. If None (default), the standard GRPO clipping is used. Recommended to be > 1 + epsilon when enabled."
},
)
epsilon_high: Optional[float] = field(
default=None,
metadata={
@@ -536,3 +544,5 @@ def __post_init__(self):
"current global eval batch size, the valid values for the number of generations are: "
f"{possible_values}."
)
if self.delta is not None and self.use_liger_loss:
raise ValueError("Liger loss does not support two-sided GRPO loss yet.")
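
From the user side, the new option is just another `GRPOConfig` field. A minimal usage sketch (field values are illustrative placeholders, not recommendations from the PR):

from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir="grpo-two-sided",  # placeholder output directory
    epsilon=0.2,                  # clipping range lower bound (1 - epsilon)
    epsilon_high=0.28,            # clipping range upper bound (1 + epsilon_high), DAPO-style
    delta=2.0,                    # extra upper bound on the unclipped ratio; keep it > 1 + epsilon
    use_liger_loss=False,         # two-sided clipping is rejected when Liger loss is enabled (see check above)
)
# delta=None (the default) falls back to the standard one-sided GRPO clipping.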
10 changes: 9 additions & 1 deletion trl/trainer/grpo_trainer.py
@@ -1347,7 +1347,15 @@ def _compute_loss(self, model, inputs):
)
coef_1 = torch.exp(per_token_logps - old_per_token_logps)
coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
per_token_loss1 = coef_1 * advantages.unsqueeze(1)

# Conditionally apply the upper clipping bound (delta) if specified
if self.args.delta is not None:
# Use clamp instead of min to handle tensor-float comparison
per_token_loss1 = torch.clamp(coef_1, max=self.args.delta) * advantages.unsqueeze(1)
else:
# Original GRPO clipping (only lower bound implicitly applied by the final min)
per_token_loss1 = coef_1 * advantages.unsqueeze(1)

per_token_loss2 = coef_2 * advantages.unsqueeze(1)
per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
if self.beta != 0.0:
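
In isolation, the modified clipping logic above amounts to roughly the following sketch (scalar arguments stand in for self.epsilon_low, self.epsilon_high, and self.args.delta; this is not a verbatim extract of the trainer):

import torch

def two_sided_per_token_loss(per_token_logps, old_per_token_logps, advantages,
                             epsilon_low=0.2, epsilon_high=0.2, delta=2.0):
    coef_1 = torch.exp(per_token_logps - old_per_token_logps)        # importance ratio r
    coef_2 = torch.clamp(coef_1, 1 - epsilon_low, 1 + epsilon_high)  # clipped ratio
    if delta is not None:
        # Two-sided variant: additionally bound the unclipped term by delta.
        per_token_loss1 = torch.clamp(coef_1, max=delta) * advantages.unsqueeze(1)
    else:
        per_token_loss1 = coef_1 * advantages.unsqueeze(1)
    per_token_loss2 = coef_2 * advantages.unsqueeze(1)
    return -torch.min(per_token_loss1, per_token_loss2)

# Scenario 5 from the test: ratio 2.5, advantage -2.0 -> per-token loss 4.0
new_logps = torch.log(torch.tensor([[0.5]]))
old_logps = torch.log(torch.tensor([[0.2]]))
print(two_sided_per_token_loss(new_logps, old_logps, torch.tensor([-2.0])))  # -> 4.0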