2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
@@ -9,6 +9,8 @@
- sections:
- local: dataset_formats
title: Dataset Formats
- local: paper_index
title: Paper Index
- local: how_to_train
title: Training FAQ
- local: logging
10 changes: 5 additions & 5 deletions docs/source/grpo_trainer.md
@@ -164,15 +164,15 @@
- `frac_reward_zero_std`: The fraction of samples in the generation batch with a reward standard deviation of zero, implying there is no reward diversity for that prompt (all completions are correct or all are incorrect).
- `entropy`: Average entropy of token predictions across generated completions. (If `mask_truncated_completions=True`, tokens of masked sequences are excluded.)
- `kl`: The average KL divergence between the model and the reference model, calculated over generated completions. Logged only if `beta` is nonzero.
- `clip_ratio/region_mean`: The ratio of token probabilities where the GRPO objective is clipped to stay within the trust region:
- `clip_ratio/region_mean`: The ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities where the GRPO objective is clipped to stay within the trust region:
$$
\text{clip}\left( r_{i,t}(\theta), 1 - \epsilon_\mathrm{low}, 1 + \epsilon_\mathrm{high} \right)\,, \qquad r_{i,t}(\theta) = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}\,.
$$
A higher value means more tokens are clipped, which constrains how much the policy $\pi_\theta$ can change.
- `clip_ratio/low_mean`: The average ratio of token probabilities that were clipped on the lower bound of the trust region: \\(r_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}\\)
- `clip_ratio/low_min`: The minimum ratio of token probabilities that were clipped on the lower bound of the trust region: \\(r_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}\\)
- `clip_ratio/high_mean`: The average ratio of token probabilities that were clipped on the upper bound of the trust region: \\(r_{i,t}(\theta) > 1 + \epsilon_\mathrm{high}\\)
- `clip_ratio/high_max`: The maximum ratio of token probabilities that were clipped on the upper bound of the trust region: \\(r_{i,t}(\theta) > 1 + \epsilon_\mathrm{high}\\).
- `clip_ratio/low_mean`: The average ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities that were clipped on the lower bound of the trust region: \\(r_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}\\)
- `clip_ratio/low_min`: The minimum ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities that were clipped on the lower bound of the trust region: \\(r_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}\\)
- `clip_ratio/high_mean`: The average ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities that were clipped on the upper bound of the trust region: \\(r_{i,t}(\theta) > 1 + \epsilon_\mathrm{high}\\)
- `clip_ratio/high_max`: The maximum ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities that were clipped on the upper bound of the trust region: \\(r_{i,t}(\theta) > 1 + \epsilon_\mathrm{high}\\).
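
A minimal sketch of how these clipping fractions can be computed (assuming `(B, T)` tensors `per_token_logps`, `old_per_token_logps`, and `completion_mask`, and a `(B,)` tensor `advantages`; the tensor names and epsilon values are illustrative, not the trainer's exact internals):

```python
import torch

def clip_ratio_metrics(per_token_logps, old_per_token_logps, completion_mask, advantages,
                       epsilon_low=0.2, epsilon_high=0.2):
    mask = completion_mask.float()
    ratio = torch.exp(per_token_logps - old_per_token_logps)  # r_{i,t}(theta)
    # Clipping only binds when it lowers the objective, hence the sign check on the advantage
    is_low_clipped = (ratio < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)
    is_high_clipped = (ratio > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
    is_region_clipped = is_low_clipped | is_high_clipped
    num_tokens = mask.sum().clamp(min=1.0)

    def masked_mean(flags):
        return (flags.float() * mask).sum() / num_tokens

    return masked_mean(is_low_clipped), masked_mean(is_high_clipped), masked_mean(is_region_clipped)
```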

## Customization

26 changes: 26 additions & 0 deletions docs/source/paper_index.md
@@ -0,0 +1,26 @@
# Paper Index

<Tip warning={true}>

Section under construction. Feel free to contribute!

</Tip>

## Group Sequence Policy Optimization

**📜 Paper**: https://huggingface.co/papers/2507.18071

GSPO is a GRPO variant that computes importance sampling weights at the sequence level instead of per-token. To reproduce the paper's setting, use this configuration:

```python
from trl import GRPOConfig

training_args = GRPOConfig(
importance_sampling_level="sequence",
loss_type="grpo",
steps_per_generation=...,
beta=0.04, # not explicitly specified in the paper, but they likely used the same value as in the GRPO paper
)
```
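
Concretely, with `importance_sampling_level="sequence"` the per-token importance ratio \\(r_{i,t}(\theta)\\) is replaced by a single length-normalized sequence-level ratio, following the definition in the paper:

$$
s_i(\theta) = \left( \frac{\pi_\theta(o_i \mid q)}{\pi_{\theta_{\text{old}}}(o_i \mid q)} \right)^{1/|o_i|} = \exp\left( \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \log \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})} \right)
$$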

The original paper does not report the exact training hyperparameters. Note that switching to sequence-level importance sampling only has an effect when training is slightly off-policy, for example when `steps_per_generation > gradient_accumulation_steps` or `num_iterations > 1`. In a strictly on-policy setup the importance ratios are all equal to 1, so the modification is effectively a no-op; a slightly off-policy configuration is sketched below.
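
A minimal sketch of such a configuration (the specific values are illustrative, not taken from the paper):

```python
from trl import GRPOConfig

training_args = GRPOConfig(
    importance_sampling_level="sequence",
    num_iterations=2,                # each generation batch is reused for 2 optimization passes
    steps_per_generation=8,          # larger than gradient_accumulation_steps, so generations are reused
    gradient_accumulation_steps=4,
    beta=0.04,                       # illustrative; see the note above
)
```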
32 changes: 32 additions & 0 deletions tests/test_grpo_trainer.py
@@ -1809,6 +1809,38 @@ def test_training_vlm_and_prompt_truncation(self):
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

def test_training_sequence_importance_sampling(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=8, # reduce the completion length to reduce memory usage
num_iterations=2, # with more than one optimization pass per generation, the importance sampling ratios are not trivially equal to 1
importance_sampling_level="sequence",
report_to="none",
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset,
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")


if __name__ == "__main__":
unittest.main()
16 changes: 16 additions & 0 deletions trl/trainer/grpo_config.py
@@ -155,6 +155,12 @@ class GRPOConfig(TrainingArguments):
epsilon_high (`float` or `None`, *optional*, defaults to `None`):
Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the lower-bound
specified in argument `epsilon`. Paper [DAPO](https://huggingface.co/papers/2503.14476) recommends `0.28`.
importance_sampling_level (`str`, *optional*, defaults to `"token"`):
Controls whether importance sampling ratios are computed at the `"token"` or `"sequence"` level. `"token"`
keeps the raw per-token log-probability ratios (one weight per token). `"sequence"` averages the
log-probability ratios across valid tokens to produce a single ratio per sequence. The
[GSPO paper](https://huggingface.co/papers/2507.18071) shows that sequence-level sampling often yields more
stable training and better alignment with sequence-level rewards.
reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
weighted equally with weight `1.0`.
@@ -458,6 +464,16 @@ class GRPOConfig(TrainingArguments):
"lower-bound specified in argument `epsilon`. Paper DAPO recommends `0.28`."
},
)
importance_sampling_level: str = field(
default="token",
metadata={
"help": "Controls whether importance sampling ratios are computed at the `'token'` or `'sequence'` level. "
"`'token'` keeps the raw per-token log-probability ratios (one weight per token). `'sequence'` averages "
"the log-probability ratios across valid tokens to produce a single ratio per sequence. The GSPO paper "
"shows that sequence-level sampling often yields more stable training and better alignment with "
"sequence-level rewards."
},
)
reward_weights: Optional[list[float]] = field(
default=None,
metadata={
34 changes: 29 additions & 5 deletions trl/trainer/grpo_trainer.py
@@ -670,12 +670,18 @@ def __init__(
self.use_liger_loss = args.use_liger_loss
self.loss_type = args.loss_type
self.scale_rewards = args.scale_rewards
self.importance_sampling_level = args.importance_sampling_level
self.mask_truncated_completions = args.mask_truncated_completions
self.top_entropy_quantile = args.top_entropy_quantile
if self.use_liger_loss and self.top_entropy_quantile < 1.0:
raise NotImplementedError(
"Liger Kernels don't currently support masking token positions based on entropy."
)
if self.use_liger_loss and self.importance_sampling_level != "token":
raise NotImplementedError(
"Liger Kernels currently only support token-level importance sampling. Please set "
"`importance_sampling_level` to 'token'."
)

# Datasets
self.shuffle_dataset = args.shuffle_dataset
@@ -1783,7 +1789,22 @@ def _compute_loss(self, model, inputs):
# (see _generate_and_score_completions) and use per_token_logps.detach() instead.
old_per_token_logps = inputs.get("old_per_token_logps")
old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps
coef_1 = torch.exp(per_token_logps - old_per_token_logps)

log_ratio = per_token_logps - old_per_token_logps
if self.importance_sampling_level == "token":
log_importance_weights = log_ratio
elif self.importance_sampling_level == "sequence":
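# GSPO-style sequence-level ratio: average the per-token log-ratios over valid completion tokens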
log_importance_weights = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
log_importance_weights = log_importance_weights.unsqueeze(-1)
else:
raise ValueError(
f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' "
"and 'sequence'."
)
# From here, the shape of log_importance_weights (and of every subsequent tensor: coef_1, coef_2, etc.) depends on
# importance_sampling_level: (B, T) for "token" level, (B, 1) for "sequence" level

coef_1 = torch.exp(log_importance_weights)
coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)

# Two-sided clipping
@@ -1813,7 +1834,10 @@ def _compute_loss(self, model, inputs):
completion_token_count = completion_mask.sum().clamp(min=1.0)

def masked_batch_mean(x):
return (x * completion_mask).sum() / completion_token_count
if x.shape[1] == 1: # when importance_sampling_level == "sequence"
return x.mean()
else:
return (x * completion_mask).sum() / completion_token_count

if self.beta != 0.0:
mean_kl = masked_batch_mean(per_token_kl)
@@ -1827,9 +1851,9 @@ def masked_batch_mean(x):
is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0)
is_region_clipped = is_low_clipped | is_high_clipped

low_clip = masked_batch_mean(is_low_clipped)
high_clip = masked_batch_mean(is_high_clipped)
clip_ratio = masked_batch_mean(is_region_clipped)
low_clip = masked_batch_mean(is_low_clipped.float())
high_clip = masked_batch_mean(is_high_clipped.float())
clip_ratio = masked_batch_mean(is_region_clipped.float())

gathered_low_clip = self.accelerator.gather(low_clip)
self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item())