
Commit 116ec49

🏗️ Refactor top-entropy in GRPO (#3727)
1 parent 1b17fa7 commit 116ec49

3 files changed: +111 additions, −74 deletions


tests/test_grpo_trainer.py

Lines changed: 56 additions & 34 deletions
@@ -24,7 +24,7 @@
 from transformers.utils import is_peft_available

 from trl import GRPOConfig, GRPOTrainer
-from trl.trainer.grpo_trainer import RepeatSampler, shuffle_tensor_dict, split_tensor_dict
+from trl.trainer.grpo_trainer import RepeatSampler, get_high_entropy_mask, shuffle_tensor_dict, split_tensor_dict

 from .testing_utils import require_vllm

@@ -216,6 +216,60 @@ def test_sampler_with_mini_repeat_count_and_batch_size_3(self):
         assert sampled[24:28] == sampled[28:32] == sampled[32:36]


+class GetHighEntropyMaskTester(unittest.TestCase):
+    def test_compute_entropy_mask_0(self):
+        # We have a total of 12 tokens, out of which 10 are non-pad.
+        # For a quantile threshold of 0.8 (i.e. top_entropy_quantile=0.2) we expect the top 20%, i.e. the 2 non-pad
+        # tokens with the highest entropy, to be unmasked.
+        # In our example these are the tokens with entropies 0.9 and 1.0; since 1.1 and 1.2 belong to pad tokens,
+        # they are excluded from the entropy threshold calculation.
+        entropies = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [0.7, 0.8, 0.9, 1.0, 1.1, 1.2]])
+        mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 0]])
+        entropy_mask = get_high_entropy_mask(entropies, mask, threshold=0.8)
+        expected_mask = torch.tensor([[0, 0, 0, 0, 0, 0], [0, 0, 1, 1, 0, 0]], dtype=torch.bool)
+        torch.testing.assert_close(entropy_mask, expected_mask)
+
+    def test_compute_entropy_mask_1(self):
+        # Another example with a different set of entropies and a different mask.
+        entropies = torch.tensor([[0.1, 0.2, 0.3, 1.4, 0.5, 0.14], [0.5, 0.6, 0.7, 0.8, 0.9, 1.0]])
+        mask = torch.tensor([[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 0, 0]])
+        entropy_mask = get_high_entropy_mask(entropies, mask, threshold=0.8)
+        expected_mask = torch.tensor([[0, 0, 0, 1, 0, 0], [0, 0, 0, 1, 0, 0]], dtype=torch.bool)
+        torch.testing.assert_close(entropy_mask, expected_mask)
+
+    def test_compute_entropy_mask_lower_threshold(self):
+        # For a threshold of 0.5 we expect the top half of the non-pad tokens to be unmasked.
+        entropies = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [0.7, 0.8, 0.9, 1.0, 1.1, 1.2]])
+        mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 0]])
+        entropy_mask = get_high_entropy_mask(entropies, mask, threshold=0.5)
+        expected_mask = torch.tensor([[0, 0, 0, 0, 0, 1], [1, 1, 1, 1, 0, 0]], dtype=torch.bool)
+        torch.testing.assert_close(entropy_mask, expected_mask)
+
+    def test_compute_entropy_threshold_0(self):
+        # If the threshold is 0.0 then we expect the mask to be all ones for non-pad tokens.
+        entropies = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [0.7, 0.8, 0.9, 1.0, 1.1, 1.2]])
+        mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 0]])
+        entropy_mask = get_high_entropy_mask(entropies, mask, threshold=0.0)
+        expected_mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 0]], dtype=torch.bool)
+        torch.testing.assert_close(entropy_mask, expected_mask)
+
+    def test_compute_entropy_threshold_1(self):
+        # If the threshold is 1.0 then we expect the mask to be all zeros except for the highest-entropy token.
+        entropies = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [0.7, 0.8, 0.9, 1.0, 1.1, 1.2]])
+        mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 0]])
+        entropy_mask = get_high_entropy_mask(entropies, mask, threshold=1.0)
+        expected_mask = torch.tensor([[0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0]], dtype=torch.bool)
+        torch.testing.assert_close(entropy_mask, expected_mask)
+
+    def test_compute_entropy_all_masked(self):
+        # If there are no non-pad tokens we expect the mask to be all zeros.
+        entropies = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [0.7, 0.8, 0.9, 1.0, 1.1, 1.2]])
+        mask = torch.tensor([[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]])
+        entropy_mask = get_high_entropy_mask(entropies, mask, threshold=0.5)
+        expected_mask = torch.tensor([[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]], dtype=torch.bool)
+        torch.testing.assert_close(entropy_mask, expected_mask)
+
+
 class GRPOTrainerTester(unittest.TestCase):
     def test_init_minimal(self):
         # Test that GRPOTrainer can be instantiated with only model, reward_model and train_dataset
@@ -853,7 +907,7 @@ def test_training_with_entropy_filter(self):
             num_generations=3,  # reduce the number of generations to reduce memory usage
             max_completion_length=8,  # reduce the completion length to reduce memory usage
             report_to="none",
-            token_entropy_percentile_threshold=0.8,
+            top_entropy_quantile=0.2,
         )
         trainer = GRPOTrainer(
             model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
@@ -1301,38 +1355,6 @@ def reward_func(completions, **kwargs):
             )
             trainer.train()

-    def test_compute_entropy_mask(self):
-        """Test the _compute_entropy_mask method."""
-        trainer = GRPOTrainer(
-            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
-            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
-            args=GRPOConfig(token_entropy_percentile_threshold=0.8),
-        )
-
-        # Create dummy entropies and completion mask
-        entropies = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [0.7, 0.8, 0.9, 1.0, 1.1, 1.2]])
-        completion_mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 0]])
-
-        entropy_mask = trainer._compute_entropy_mask(entropies, completion_mask)
-
-        self.assertEqual(entropy_mask.shape, entropies.shape)
-
-        # We have a total of 12 tokens out of which 10 are non-pad,
-        # for a token_entropy_percentile_threshold of 0.8,
-        # we expect the top 20% i.e 2 non-pad tokens corresponding to the highest entropy to be unmasked.
-        # In our example these will be the tokens corresponding to the entropies 0.9 and 1.0
-        # since 1.1 and 1.2 are pad tokens they are excluded from the entropy threshold calculation.
-        expected_mask = torch.tensor([[0, 0, 0, 0, 0, 0], [0, 0, 1, 1, 0, 0]], dtype=torch.bool)
-        self.assertTrue(torch.equal(entropy_mask, expected_mask))
-
-        entropies = torch.tensor([[0.1, 0.2, 0.3, 1.4, 0.5, 0.14], [0.5, 0.6, 0.7, 0.8, 0.9, 1.0]])
-        completion_mask = torch.tensor([[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 0, 0]])
-
-        expected_mask = torch.tensor([[0, 0, 0, 1, 0, 0], [0, 0, 0, 1, 0, 0]], dtype=torch.bool)
-        entropy_mask = trainer._compute_entropy_mask(entropies, completion_mask)
-
-        self.assertTrue(torch.equal(entropy_mask, expected_mask))
-
     def test_prepare_input_called_with_correct_data(self):
         dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
         with tempfile.TemporaryDirectory() as tmp_dir:
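
The new GetHighEntropyMaskTester cases pin down the quantile behaviour. As a quick sanity check of why the first test expects exactly the 0.9 and 1.0 tokens to survive, here is a minimal sketch using plain PyTorch (nothing TRL-specific assumed): torch.quantile with its default linear interpolation places the 0.8 quantile of the ten non-pad entropies at 0.82, so only values at or above 0.82 pass.

import torch

# The ten non-pad entropies from test_compute_entropy_mask_0 (the 1.1 and 1.2 pad tokens are excluded).
non_pad = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])

# 0.8 quantile with linear interpolation: 0.8 + 0.2 * (0.9 - 0.8) = 0.82
print(torch.quantile(non_pad, 0.8))  # tensor(0.8200)

# Only 0.9 and 1.0 clear this threshold, which is why the expected mask unmasks exactly those two positions.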

trl/trainer/grpo_config.py

Lines changed: 13 additions & 11 deletions
@@ -193,12 +193,12 @@ class GRPOConfig(TrainingArguments):
             τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how
             frequently the current policy is synchronized with the reference policy. To use this parameter, you must
             set `sync_ref_model=True`.
-        token_entropy_percentile_threshold (`float`, *optional*, defaults to `0.0`):
-            τ parameter from the [Beyond the 80/20 Rule](https://huggingface/papers/2506.01939) paper, which finds that
-            masking out the bottom τ percentile of tokens based on the entropy of the probability distribution at a
-            given sequence position, in the policy loss term yields better results. The range of the parameter is
-            [0.0-1.0] a value of 0.0 means that none the tokens are filtered out and 1.0 means that all the tokens are
-            masked. Recommended value is `0.8`.
+        top_entropy_quantile (`float`, *optional*, defaults to `1.0`):
+            ρ parameter from [Beyond the 80/20 Rule](https://huggingface.co/papers/2506.01939). Keeps in the policy
+            loss term only the top-ρ quantile of tokens by entropy of the probability distribution at each sequence
+            position, improving results. Range: `[0.0-1.0]`. A value of `0.0` masks all but the highest entropy token;
+            `1.0` keeps all tokens. The paper recommends a value of `0.2`.
+            If used with `mask_truncated_completions=True`, only tokens from non-truncated completions are considered.
         use_liger_loss (`bool`, *optional*, defaults to `False`):
             Whether to use the Liger GRPO loss.
@@ -520,12 +520,14 @@ class GRPOConfig(TrainingArguments):
             "synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`."
         },
     )
-    token_entropy_percentile_threshold: float = field(
-        default=0.0,
+    top_entropy_quantile: float = field(
+        default=1.0,
         metadata={
-            "help": "Percentile threshold for filtering out tokens in the policy loss based on entropy."
-            "Positions in the completion with entropy below this percentile are masked out."
-            "0.8 is the recommended value if you'd like to enable entropy based masking."
+            "help": "ρ parameter from Beyond the 80/20 Rule. Keeps in the policy loss term only the top-ρ quantile of "
+            "tokens by entropy of the probability distribution at each sequence position, improving results. Range: "
+            "[0.0-1.0]. A value of `0.0` masks all but the highest entropy token; `1.0` keeps all tokens. The paper "
+            "recommends a value of `0.2`. If used with `mask_truncated_completions=True`, only tokens from "
+            "non-truncated completions are considered."
         },
     )
     use_liger_loss: bool = field(
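
For users migrating from the removed token_entropy_percentile_threshold flag, a hedged sketch of the renamed configuration, modeled on the updated test_training_with_entropy_filter above; the tiny checkpoints, dataset, output directory, and batch size are illustrative testing values, not recommendations.

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

args = GRPOConfig(
    output_dir="grpo-entropy-filter",  # illustrative output directory
    per_device_train_batch_size=3,     # kept divisible by num_generations, as in the test suite
    num_generations=3,
    max_completion_length=8,
    top_entropy_quantile=0.2,          # keep only the top 20% highest-entropy completion tokens in the loss
    report_to="none",
)
trainer = GRPOTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
    args=args,
    train_dataset=dataset,
)
trainer.train()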

trl/trainer/grpo_trainer.py

Lines changed: 42 additions & 29 deletions
@@ -290,6 +290,32 @@ def identity(x):
     return x


+def get_high_entropy_mask(entropies: torch.Tensor, mask: torch.Tensor, threshold: float) -> torch.Tensor:
+    """
+    Returns a binary mask identifying tokens whose entropy exceeds a given quantile threshold.
+
+    Args:
+        entropies (`torch.Tensor`):
+            Tensor of shape (batch_size, seq_len) with per-token entropy values.
+        mask (`torch.Tensor`):
+            Binary mask of the same shape as `entropies`, where `1` indicates valid tokens and `0` padding.
+        threshold (`float`):
+            Quantile threshold between `0.0` and `1.0` to select high-entropy tokens.
+
+    Returns:
+        `torch.Tensor`:
+            Boolean mask of shape (batch_size, seq_len), where `True` indicates tokens with entropy >= threshold and
+            `False` otherwise.
+    """
+    non_pad_entropies = entropies[mask.bool()].float()
+    if non_pad_entropies.numel() == 0:
+        return torch.zeros_like(entropies, dtype=torch.bool)
+    entropy_threshold = torch.quantile(non_pad_entropies, threshold)
+    masked_entropies = entropies * mask.float()
+    entropy_mask = masked_entropies >= entropy_threshold
+    return entropy_mask & mask.bool()  # ensure padding tokens are always masked out
+
+
 class GRPOTrainer(Trainer):
     """
     Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
@@ -523,8 +549,8 @@ def __init__(
         self.loss_type = args.loss_type
         self.scale_rewards = args.scale_rewards
         self.mask_truncated_completions = args.mask_truncated_completions
-        self.token_entropy_percentile_threshold = args.token_entropy_percentile_threshold
-        if self.use_liger_loss and self.token_entropy_percentile_threshold > 0.0:
+        self.top_entropy_quantile = args.top_entropy_quantile
+        if self.use_liger_loss and self.top_entropy_quantile < 1.0:
             raise NotImplementedError(
                 "Liger Kernels don't currently support masking token positions based on entropy."
             )
@@ -906,7 +932,7 @@ def _get_per_token_logps_and_entropies(

         logps = torch.cat(all_logps, dim=0)
         entropies = torch.cat(all_entropies, dim=0) if compute_entropy else None
-        return {"logps": logps, "entropies": entropies}
+        return logps, entropies

     def _fix_param_name_to_vllm(self, name, extra_prefixes: Optional[list[str]] = None):
         extra_prefixes = extra_prefixes or []
@@ -1296,23 +1322,23 @@ def _generate_and_score_completions(
         # old_per_token_logps to None.
         generate_every = self.args.steps_per_generation * self.num_iterations  # generation frequency
         if self.args.gradient_accumulation_steps % generate_every != 0:
-            old_per_token_logps = self._get_per_token_logps_and_entropies(
+            old_per_token_logps, _ = self._get_per_token_logps_and_entropies(
                 self.model, prompt_completion_ids, attention_mask, logits_to_keep, batch_size
-            )["logps"]
+            )
         else:
             old_per_token_logps = None

         # Compute the per-token log probabilities for the reference model
         if self.beta != 0.0:
             if self.ref_model is not None:
-                ref_per_token_logps = self._get_per_token_logps_and_entropies(
+                ref_per_token_logps, _ = self._get_per_token_logps_and_entropies(
                     self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
-                )["logps"]
+                )
             else:
                 with self.accelerator.unwrap_model(self.model).disable_adapter():
-                    ref_per_token_logps = self._get_per_token_logps_and_entropies(
+                    ref_per_token_logps, _ = self._get_per_token_logps_and_entropies(
                         self.model, prompt_completion_ids, attention_mask, logits_to_keep
-                    )["logps"]
+                    )
         else:
             ref_per_token_logps = None

@@ -1447,15 +1473,6 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
         else:
             return self._compute_loss(model, inputs)

-    def _compute_entropy_mask(self, entropies, completion_mask):
-        # compute the entropy threshold across all tokens in the batch
-        non_pad_entropies = entropies[completion_mask.bool()]
-        # disregard pad tokens when computing the entropy threshold
-        entropy_threshold = torch.quantile(non_pad_entropies.float(), self.token_entropy_percentile_threshold)
-        entropies = entropies * completion_mask.float()  # mask out the padding tokens
-        entropy_mask = entropies >= entropy_threshold
-        return entropy_mask
-
     def _compute_loss(self, model, inputs):
         # Compute the per-token log probabilities for the model
         prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
@@ -1464,18 +1481,14 @@ def _compute_loss(self, model, inputs):
         attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
         logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens

-        # Compute the entropy at each position in the completion
-        if self.token_entropy_percentile_threshold > 0.0:
-            logps_and_entropies = self._get_per_token_logps_and_entropies(
-                model, input_ids, attention_mask, logits_to_keep, compute_entropy=True
-            )
-            per_token_logps = logps_and_entropies["logps"]
-            entropies = logps_and_entropies["entropies"]
-            entropy_mask = self._compute_entropy_mask(entropies, completion_mask)
+        # Compute the per_token_logps and the entropy (if necessary) at each position in the completion
+        per_token_logps, entropies = self._get_per_token_logps_and_entropies(
+            model, input_ids, attention_mask, logits_to_keep, compute_entropy=self.top_entropy_quantile < 1.0
+        )
+
+        if self.top_entropy_quantile < 1.0:
+            entropy_mask = get_high_entropy_mask(entropies, completion_mask, 1 - self.top_entropy_quantile)
         else:
-            per_token_logps = self._get_per_token_logps_and_entropies(
-                model, input_ids, attention_mask, logits_to_keep
-            )["logps"]
             entropy_mask = None

         # Compute the KL divergence between the model and the reference model
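
A short usage sketch of the new helper makes the default semantics concrete: the trainer passes 1 - top_entropy_quantile as threshold, so the default top_entropy_quantile=1.0 keeps every non-pad token and 0.0 keeps only the highest-entropy one. The tensors below are the second rows of the edge-case tests added in this commit.

import torch
from trl.trainer.grpo_trainer import get_high_entropy_mask

entropies = torch.tensor([[0.7, 0.8, 0.9, 1.0, 1.1, 1.2]])
mask = torch.tensor([[1, 1, 1, 1, 0, 0]])  # the last two positions are padding

# top_entropy_quantile=1.0 (the default) -> threshold 0.0 -> all non-pad tokens kept
print(get_high_entropy_mask(entropies, mask, threshold=0.0))
# tensor([[ True,  True,  True,  True, False, False]])

# top_entropy_quantile=0.0 -> threshold 1.0 -> only the highest-entropy non-pad token kept
print(get_high_entropy_mask(entropies, mask, threshold=1.0))
# tensor([[False, False, False,  True, False, False]])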
