🏗️ Refactor top-entropy in GRPO #3727

Merged
merged 12 commits on Jul 19, 2025
82 changes: 48 additions & 34 deletions tests/test_grpo_trainer.py
@@ -24,7 +24,7 @@
from transformers.utils import is_peft_available

from trl import GRPOConfig, GRPOTrainer
from trl.trainer.grpo_trainer import RepeatSampler, shuffle_tensor_dict, split_tensor_dict
from trl.trainer.grpo_trainer import RepeatSampler, get_high_entropy_mask, shuffle_tensor_dict, split_tensor_dict

from .testing_utils import require_vllm

@@ -213,6 +213,52 @@ def test_sampler_with_mini_repeat_count_and_batch_size_3(self):
assert sampled[24:28] == sampled[28:32] == sampled[32:36]


class GetHighEntropyMaskTester(unittest.TestCase):
def test_compute_entropy_mask_0(self):
# We have a total of 12 tokens, of which 10 are non-pad.
# For a quantile threshold of 0.8 (i.e. top_entropy_quantile=0.2), we expect the top 20%, i.e. the 2 non-pad
# tokens with the highest entropy, to be unmasked.
# In our example these are the tokens with entropies 0.9 and 1.0; since 1.1 and 1.2 are pad tokens, they are
# excluded from the entropy threshold calculation.
entropies = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [0.7, 0.8, 0.9, 1.0, 1.1, 1.2]])
mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 0]])
entropy_mask = get_high_entropy_mask(entropies, mask, threshold=0.8)
expected_mask = torch.tensor([[0, 0, 0, 0, 0, 0], [0, 0, 1, 1, 0, 0]], dtype=torch.bool)
torch.testing.assert_close(entropy_mask, expected_mask)

def test_compute_entropy_mask_1(self):
# Another example with a different set of entropies and a different mask.
entropies = torch.tensor([[0.1, 0.2, 0.3, 1.4, 0.5, 0.14], [0.5, 0.6, 0.7, 0.8, 0.9, 1.0]])
mask = torch.tensor([[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 0, 0]])
entropy_mask = get_high_entropy_mask(entropies, mask, threshold=0.8)
expected_mask = torch.tensor([[0, 0, 0, 1, 0, 0], [0, 0, 0, 1, 0, 0]], dtype=torch.bool)
torch.testing.assert_close(entropy_mask, expected_mask)

def test_compute_entropy_mask_lower_threshold(self):
# For a threshold of 0.5 we expect the top half of the non-pad tokens to be unmasked.
entropies = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [0.7, 0.8, 0.9, 1.0, 1.1, 1.2]])
mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 0]])
entropy_mask = get_high_entropy_mask(entropies, mask, threshold=0.5)
expected_mask = torch.tensor([[0, 0, 0, 0, 0, 1], [1, 1, 1, 1, 0, 0]], dtype=torch.bool)
torch.testing.assert_close(entropy_mask, expected_mask)

def test_compute_entropy_mask_all_tokens(self):
# For a threshold of 0.0 we expect all non-pad tokens to be unmasked.
entropies = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [0.7, 0.8, 0.9, 1.0, 1.1, 1.2]])
mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 0]])
entropy_mask = get_high_entropy_mask(entropies, mask, threshold=0.0)
expected_mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 0]], dtype=torch.bool)
torch.testing.assert_close(entropy_mask, expected_mask)

def test_compute_entropy_mask_no_tokens(self):
# For a threshold of 1.0 we expect only the single highest-entropy non-pad token to be unmasked; all other
# tokens are masked out.
entropies = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [0.7, 0.8, 0.9, 1.0, 1.1, 1.2]])
mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 0]])
entropy_mask = get_high_entropy_mask(entropies, mask, threshold=1.0)
expected_mask = torch.tensor([[0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0]], dtype=torch.bool)
torch.testing.assert_close(entropy_mask, expected_mask)


class GRPOTrainerTester(unittest.TestCase):
def test_init_minimal(self):
# Test that GRPOTrainer can be instantiated with only model, reward_model and train_dataset
@@ -850,7 +896,7 @@ def test_training_with_entropy_filter(self):
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=8, # reduce the completion length to reduce memory usage
report_to="none",
token_entropy_percentile_threshold=0.8,
top_entropy_quantile=0.2,
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
@@ -1297,35 +1343,3 @@ def reward_func(completions, **kwargs):
train_dataset=dataset,
)
trainer.train()

def test_compute_entropy_mask(self):
"""Test the _compute_entropy_mask method."""
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=GRPOConfig(token_entropy_percentile_threshold=0.8),
)

# Create dummy entropies and completion mask
entropies = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [0.7, 0.8, 0.9, 1.0, 1.1, 1.2]])
completion_mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 0]])

entropy_mask = trainer._compute_entropy_mask(entropies, completion_mask)

self.assertEqual(entropy_mask.shape, entropies.shape)

# We have a total of 12 tokens out of which 10 are non-pad,
# for a token_entropy_percentile_threshold of 0.8,
# we expect the top 20% i.e 2 non-pad tokens corresponding to the highest entropy to be unmasked.
# In our example these will be the tokens corresponding to the entropies 0.9 and 1.0
# since 1.1 and 1.2 are pad tokens they are excluded from the entropy threshold calculation.
expected_mask = torch.tensor([[0, 0, 0, 0, 0, 0], [0, 0, 1, 1, 0, 0]], dtype=torch.bool)
self.assertTrue(torch.equal(entropy_mask, expected_mask))

entropies = torch.tensor([[0.1, 0.2, 0.3, 1.4, 0.5, 0.14], [0.5, 0.6, 0.7, 0.8, 0.9, 1.0]])
completion_mask = torch.tensor([[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 0, 0]])

expected_mask = torch.tensor([[0, 0, 0, 1, 0, 0], [0, 0, 0, 1, 0, 0]], dtype=torch.bool)
entropy_mask = trainer._compute_entropy_mask(entropies, completion_mask)

self.assertTrue(torch.equal(entropy_mask, expected_mask))
24 changes: 13 additions & 11 deletions trl/trainer/grpo_config.py
@@ -191,12 +191,12 @@ class GRPOConfig(TrainingArguments):
τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how
frequently the current policy is synchronized with the reference policy. To use this parameter, you must
set `sync_ref_model=True`.
token_entropy_percentile_threshold (`float`, *optional*, defaults to `0.0`):
τ parameter from the [Beyond the 80/20 Rule](https://huggingface/papers/2506.01939) paper, which finds that
masking out the bottom τ percentile of tokens based on the entropy of the probability distribution at a
given sequence position, in the policy loss term yields better results. The range of the parameter is
[0.0-1.0] a value of 0.0 means that none the tokens are filtered out and 1.0 means that all the tokens are
masked. Recommended value is `0.8`.
top_entropy_quantile (`float`, *optional*, defaults to `1.0`):
ρ parameter from [Beyond the 80/20 Rule](https://huggingface.co/papers/2506.01939). Keeps in the policy
loss term only the top-ρ quantile of tokens by entropy of the probability distribution at each sequence
position, improving results. Range: `[0.0-1.0]`. A value of `0.0` masks all but the highest entropy token;
`1.0` keeps all tokens. The paper recommends a value of `0.2`.
If used with `mask_truncated_completions=True`, only tokens from non-truncated completions are considered.
use_liger_loss (`bool`, *optional*, defaults to `False`):
Whether to use the Liger GRPO loss.

@@ -520,12 +520,14 @@ class GRPOConfig(TrainingArguments):
"synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`."
},
)
token_entropy_percentile_threshold: float = field(
default=0.0,
top_entropy_quantile: float = field(
default=1.0,
metadata={
"help": "Percentile threshold for filtering out tokens in the policy loss based on entropy."
"Positions in the completion with entropy below this percentile are masked out."
"0.8 is the recommended value if you'd like to enable entropy based masking."
"help": "ρ parameter from Beyond the 80/20 Rule. Keeps in the policy loss term only the top-ρ quantile of "
"tokens by entropy of the probability distribution at each sequence position, improving results. Range: "
"[0.0-1.0]. A value of `1.0` masks all but the highest entropy token; `0.0` keeps all tokens. The paper "
"recommends a value of `0.2`. If used with `mask_truncated_completions=True`, only tokens from "
"non-truncated completions are considered."
},
)
use_liger_loss: bool = field(
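With the renamed option in place, enabling entropy-based token filtering from user code looks roughly like the following minimal sketch, which mirrors the updated test_training_with_entropy_filter; the dataset choice and output_dir are assumptions for illustration, not part of this PR:

from datasets import load_dataset  # assumed prompt dataset for illustration
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")  # assumption: any prompt dataset supported by GRPOTrainer works

training_args = GRPOConfig(
    output_dir="grpo-top-entropy",   # assumed output path
    num_generations=3,               # small values as in the test, to limit memory usage
    max_completion_length=8,
    top_entropy_quantile=0.2,        # keep only the top 20% highest-entropy completion tokens in the loss
    report_to="none",
)
trainer = GRPOTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
    args=training_args,
    train_dataset=dataset,
)
trainer.train()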
69 changes: 40 additions & 29 deletions trl/trainer/grpo_trainer.py
@@ -290,6 +290,30 @@ def identity(x):
return x


def get_high_entropy_mask(entropies: torch.Tensor, mask: torch.Tensor, threshold: float) -> torch.Tensor:
"""
Returns a binary mask identifying tokens whose entropy exceeds a given quantile threshold.

Args:
entropies (`torch.Tensor`):
Tensor of shape (batch_size, seq_len) with per-token entropy values.
mask (`torch.Tensor`):
Binary mask of the same shape as `entropies`, where `1` indicates valid tokens and `0` padding.
threshold (`float`):
Quantile threshold between `0.0` and `1.0` to select high-entropy tokens.

Returns:
`torch.Tensor`:
Boolean mask of shape (batch_size, seq_len), where `True` marks non-pad tokens whose entropy is at or above
the entropy value at the given quantile, and `False` otherwise (padding positions are always `False`).
"""
non_pad_entropies = entropies[mask.bool()].float()
entropy_threshold = torch.quantile(non_pad_entropies, threshold)
masked_entropies = entropies * mask.float()
entropy_mask = masked_entropies >= entropy_threshold
return entropy_mask & mask.bool() # ensure padding tokens are always masked out
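
For illustration, the new helper can be exercised standalone with the same numbers as the unit tests above (a sketch only; note that callers pass 1 - top_entropy_quantile as the threshold):

import torch
from trl.trainer.grpo_trainer import get_high_entropy_mask

entropies = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [0.7, 0.8, 0.9, 1.0, 1.1, 1.2]])
mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 0]])  # the last two positions are padding

# threshold = 1 - top_entropy_quantile, so 0.8 keeps the top 20% of the 10 non-pad tokens
entropy_mask = get_high_entropy_mask(entropies, mask, threshold=0.8)
print(entropy_mask)
# tensor([[False, False, False, False, False, False],
#         [False, False,  True,  True, False, False]])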


class GRPOTrainer(Trainer):
"""
Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
@@ -523,8 +547,8 @@ def __init__(
self.loss_type = args.loss_type
self.scale_rewards = args.scale_rewards
self.mask_truncated_completions = args.mask_truncated_completions
self.token_entropy_percentile_threshold = args.token_entropy_percentile_threshold
if self.use_liger_loss and self.token_entropy_percentile_threshold > 0.0:
self.top_entropy_quantile = args.top_entropy_quantile
if self.use_liger_loss and self.top_entropy_quantile < 1.0:
raise NotImplementedError(
"Liger Kernels don't currently support masking token positions based on entropy."
)
@@ -908,7 +932,7 @@ def _get_per_token_logps_and_entropies(

logps = torch.cat(all_logps, dim=0)
entropies = torch.cat(all_entropies, dim=0) if compute_entropy else None
return {"logps": logps, "entropies": entropies}
return logps, entropies

def _sync_fsdp_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None):
"""Memory-efficient post-order traversal of FSDP modules to extract full parameters and sync with vLLM."""
@@ -1288,23 +1312,23 @@ def _generate_and_score_completions(
# old_per_token_logps == per_token_logps, so we can skip its computation here, and use
# per_token_logps.detach() instead.
if self.num_iterations > 1 or self.args.steps_per_generation > self.args.gradient_accumulation_steps:
old_per_token_logps = self._get_per_token_logps_and_entropies(
old_per_token_logps, _ = self._get_per_token_logps_and_entropies(
self.model, prompt_completion_ids, attention_mask, logits_to_keep, batch_size
)["logps"]
)
else:
old_per_token_logps = None

# Compute the per-token log probabilities for the reference model
if self.beta != 0.0:
if self.ref_model is not None:
ref_per_token_logps = self._get_per_token_logps_and_entropies(
ref_per_token_logps, _ = self._get_per_token_logps_and_entropies(
self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
)["logps"]
)
else:
with self.accelerator.unwrap_model(self.model).disable_adapter():
ref_per_token_logps = self._get_per_token_logps_and_entropies(
ref_per_token_logps, _ = self._get_per_token_logps_and_entropies(
self.model, prompt_completion_ids, attention_mask, logits_to_keep
)["logps"]
)
else:
ref_per_token_logps = None

@@ -1439,15 +1463,6 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
else:
return self._compute_loss(model, inputs)

def _compute_entropy_mask(self, entropies, completion_mask):
# compute the entropy threshold across all tokens in the batch
non_pad_entropies = entropies[completion_mask.bool()]
# disregard pad tokens when computing the entropy threshold
entropy_threshold = torch.quantile(non_pad_entropies.float(), self.token_entropy_percentile_threshold)
entropies = entropies * completion_mask.float() # mask out the padding tokens
entropy_mask = entropies >= entropy_threshold
return entropy_mask

def _compute_loss(self, model, inputs):
# Compute the per-token log probabilities for the model
prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
@@ -1456,18 +1471,14 @@ def _compute_loss(self, model, inputs):
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens

# Compute the entropy at each position in the completion
if self.token_entropy_percentile_threshold > 0.0:
logps_and_entropies = self._get_per_token_logps_and_entropies(
model, input_ids, attention_mask, logits_to_keep, compute_entropy=True
)
per_token_logps = logps_and_entropies["logps"]
entropies = logps_and_entropies["entropies"]
entropy_mask = self._compute_entropy_mask(entropies, completion_mask)
# Compute the per_token_logps and, if necessary, the entropy at each position in the completion
per_token_logps, entropies = self._get_per_token_logps_and_entropies(
model, input_ids, attention_mask, logits_to_keep, compute_entropy=self.top_entropy_quantile < 1.0
)

if self.top_entropy_quantile < 1.0:
entropy_mask = get_high_entropy_mask(entropies, completion_mask, 1 - self.top_entropy_quantile)
else:
per_token_logps = self._get_per_token_logps_and_entropies(
model, input_ids, attention_mask, logits_to_keep
)["logps"]
entropy_mask = None

# Compute the KL divergence between the model and the reference model
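For reference, the mapping between the config value and the quantile threshold used in _compute_loss above, as a small worked sketch of the arithmetic only:

top_entropy_quantile = 0.2            # ρ from GRPOConfig; the paper's recommended value
threshold = 1 - top_entropy_quantile  # 0.8, the value _compute_loss passes to get_high_entropy_mask
# With 10 non-pad completion tokens, the 0.8 entropy quantile keeps roughly the 2 highest-entropy
# tokens (the top 20%) in the policy loss; top_entropy_quantile = 1.0 keeps every token.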