24 | 24 | from transformers.utils import is_peft_available
25 | 25 |
26 | 26 | from trl import GRPOConfig, GRPOTrainer
27 |    | -from trl.trainer.grpo_trainer import RepeatSampler, shuffle_tensor_dict, split_tensor_dict
   | 27 | +from trl.trainer.grpo_trainer import RepeatSampler, get_high_entropy_mask, shuffle_tensor_dict, split_tensor_dict
28 | 28 |
29 | 29 | from .testing_utils import require_vllm
30 | 30 |
@@ -216,6 +216,60 @@ def test_sampler_with_mini_repeat_count_and_batch_size_3(self):
216 | 216 |         assert sampled[24:28] == sampled[28:32] == sampled[32:36]
217 | 217 |
218 | 218 |
    | 219 | +class GetHighEntropyMaskTester(unittest.TestCase):
    | 220 | +    def test_compute_entropy_mask_0(self):
    | 221 | +        # We have a total of 12 tokens, of which 10 are non-pad.
    | 222 | +        # With a threshold of 0.8, we expect the top 20%, i.e. the 2 non-pad tokens with the
    | 223 | +        # highest entropy, to be unmasked.
    | 224 | +        # In our example these are the tokens with entropies 0.9 and 1.0; since 1.1 and 1.2 are
    | 225 | +        # pad tokens, they are excluded from the entropy threshold calculation.
    | 226 | +        entropies = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [0.7, 0.8, 0.9, 1.0, 1.1, 1.2]])
    | 227 | +        mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 0]])
    | 228 | +        entropy_mask = get_high_entropy_mask(entropies, mask, threshold=0.8)
    | 229 | +        expected_mask = torch.tensor([[0, 0, 0, 0, 0, 0], [0, 0, 1, 1, 0, 0]], dtype=torch.bool)
    | 230 | +        torch.testing.assert_close(entropy_mask, expected_mask)
    | 231 | +
    | 232 | +    def test_compute_entropy_mask_1(self):
    | 233 | +        # Another example with a different set of entropies and a different mask.
    | 234 | +        entropies = torch.tensor([[0.1, 0.2, 0.3, 1.4, 0.5, 0.14], [0.5, 0.6, 0.7, 0.8, 0.9, 1.0]])
    | 235 | +        mask = torch.tensor([[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 0, 0]])
    | 236 | +        entropy_mask = get_high_entropy_mask(entropies, mask, threshold=0.8)
    | 237 | +        expected_mask = torch.tensor([[0, 0, 0, 1, 0, 0], [0, 0, 0, 1, 0, 0]], dtype=torch.bool)
    | 238 | +        torch.testing.assert_close(entropy_mask, expected_mask)
    | 239 | +
    | 240 | +    def test_compute_entropy_mask_lower_threshold(self):
    | 241 | +        # For a threshold of 0.5, we expect the top half of the non-pad tokens to be unmasked.
    | 242 | +        entropies = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [0.7, 0.8, 0.9, 1.0, 1.1, 1.2]])
    | 243 | +        mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 0]])
    | 244 | +        entropy_mask = get_high_entropy_mask(entropies, mask, threshold=0.5)
    | 245 | +        expected_mask = torch.tensor([[0, 0, 0, 0, 0, 1], [1, 1, 1, 1, 0, 0]], dtype=torch.bool)
    | 246 | +        torch.testing.assert_close(entropy_mask, expected_mask)
    | 247 | +
    | 248 | +    def test_compute_entropy_threshold_0(self):
    | 249 | +        # If the threshold is 0.0, we expect the mask to be all ones for non-pad tokens.
    | 250 | +        entropies = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [0.7, 0.8, 0.9, 1.0, 1.1, 1.2]])
    | 251 | +        mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 0]])
    | 252 | +        entropy_mask = get_high_entropy_mask(entropies, mask, threshold=0.0)
    | 253 | +        expected_mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 0]], dtype=torch.bool)
    | 254 | +        torch.testing.assert_close(entropy_mask, expected_mask)
    | 255 | +
    | 256 | +    def test_compute_entropy_threshold_1(self):
    | 257 | +        # If the threshold is 1.0, only the single highest-entropy non-pad token stays unmasked.
    | 258 | +        entropies = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [0.7, 0.8, 0.9, 1.0, 1.1, 1.2]])
    | 259 | +        mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 0]])
    | 260 | +        entropy_mask = get_high_entropy_mask(entropies, mask, threshold=1.0)
    | 261 | +        expected_mask = torch.tensor([[0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0]], dtype=torch.bool)
    | 262 | +        torch.testing.assert_close(entropy_mask, expected_mask)
    | 263 | +
    | 264 | +    def test_compute_entropy_all_masked(self):
    | 265 | +        # If there are no non-pad tokens, we expect the mask to be all zeros.
    | 266 | +        entropies = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [0.7, 0.8, 0.9, 1.0, 1.1, 1.2]])
    | 267 | +        mask = torch.tensor([[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]])
    | 268 | +        entropy_mask = get_high_entropy_mask(entropies, mask, threshold=0.5)
    | 269 | +        expected_mask = torch.tensor([[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]], dtype=torch.bool)
    | 270 | +        torch.testing.assert_close(entropy_mask, expected_mask)
    | 271 | +
    | 272 | +
219 | 273 | class GRPOTrainerTester(unittest.TestCase):
220 | 274 |     def test_init_minimal(self):
221 | 275 |         # Test that GRPOTrainer can be instantiated with only model, reward_model and train_dataset
@@ -853,7 +907,7 @@ def test_training_with_entropy_filter(self):
853 | 907 |             num_generations=3,  # reduce the number of generations to reduce memory usage
854 | 908 |             max_completion_length=8,  # reduce the completion length to reduce memory usage
855 | 909 |             report_to="none",
856 |     | -            token_entropy_percentile_threshold=0.8,
    | 910 | +            top_entropy_quantile=0.2,
857 | 911 |         )
858 | 912 |         trainer = GRPOTrainer(
859 | 913 |             model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
@@ -1301,38 +1355,6 @@ def reward_func(completions, **kwargs):
1301 | 1355 |         )
1302 | 1356 |         trainer.train()
1303 | 1357 |
1304 |      | -    def test_compute_entropy_mask(self):
1305 |      | -        """Test the _compute_entropy_mask method."""
1306 |      | -        trainer = GRPOTrainer(
1307 |      | -            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
1308 |      | -            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
1309 |      | -            args=GRPOConfig(token_entropy_percentile_threshold=0.8),
1310 |      | -        )
1311 |      | -
1312 |      | -        # Create dummy entropies and completion mask
1313 |      | -        entropies = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [0.7, 0.8, 0.9, 1.0, 1.1, 1.2]])
1314 |      | -        completion_mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 0]])
1315 |      | -
1316 |      | -        entropy_mask = trainer._compute_entropy_mask(entropies, completion_mask)
1317 |      | -
1318 |      | -        self.assertEqual(entropy_mask.shape, entropies.shape)
1319 |      | -
1320 |      | -        # We have a total of 12 tokens out of which 10 are non-pad,
1321 |      | -        # for a token_entropy_percentile_threshold of 0.8,
1322 |      | -        # we expect the top 20% i.e 2 non-pad tokens corresponding to the highest entropy to be unmasked.
1323 |      | -        # In our example these will be the tokens corresponding to the entropies 0.9 and 1.0
1324 |      | -        # since 1.1 and 1.2 are pad tokens they are excluded from the entropy threshold calculation.
1325 |      | -        expected_mask = torch.tensor([[0, 0, 0, 0, 0, 0], [0, 0, 1, 1, 0, 0]], dtype=torch.bool)
1326 |      | -        self.assertTrue(torch.equal(entropy_mask, expected_mask))
1327 |      | -
1328 |      | -        entropies = torch.tensor([[0.1, 0.2, 0.3, 1.4, 0.5, 0.14], [0.5, 0.6, 0.7, 0.8, 0.9, 1.0]])
1329 |      | -        completion_mask = torch.tensor([[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 0, 0]])
1330 |      | -
1331 |      | -        expected_mask = torch.tensor([[0, 0, 0, 1, 0, 0], [0, 0, 0, 1, 0, 0]], dtype=torch.bool)
1332 |      | -        entropy_mask = trainer._compute_entropy_mask(entropies, completion_mask)
1333 |      | -
1334 |      | -        self.assertTrue(torch.equal(entropy_mask, expected_mask))
1335 |      | -
1336 | 1358 |     def test_prepare_input_called_with_correct_data(self):
1337 | 1359 |         dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
1338 | 1360 |         with tempfile.TemporaryDirectory() as tmp_dir:
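Note: the unit tests above fully pin down the expected behavior of get_high_entropy_mask: the quantile cutoff is computed over the non-pad entropies only, pad positions never count toward the cutoff and are never selected, and thresholds of 0.0 and 1.0 degenerate to "all non-pad tokens" and "only the highest-entropy token". A minimal sketch that satisfies all six tests follows; it is inferred from the tests, not necessarily the actual trl implementation, which may for example gather entropies across processes before taking the quantile.

import torch

def get_high_entropy_mask(entropies: torch.Tensor, mask: torch.Tensor, threshold: float) -> torch.Tensor:
    # Only real (non-pad) tokens may influence the cutoff or be selected.
    non_pad = mask.bool()
    if not non_pad.any():
        # No non-pad tokens at all: nothing can be selected.
        return torch.zeros_like(non_pad)
    # The cutoff is the `threshold` quantile of the non-pad entropies,
    # e.g. threshold=0.8 keeps roughly the top 20% of tokens.
    cutoff = torch.quantile(entropies[non_pad].float(), threshold)
    # Unmask tokens at or above the cutoff; `>=` makes threshold=1.0 keep the maximum.
    return (entropies >= cutoff) & non_pad

Read this way, the config change in test_training_with_entropy_filter is consistent with the unit tests: top_entropy_quantile=0.2 ("keep the top 20%") would correspond to threshold=0.8 here, though the exact mapping between the config field and the threshold argument is an assumption, not something this diff shows.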