
🎲 [GRPO] Make training dataset shuffle optional #3334

Merged · 7 commits · Apr 21, 2025
171 changes: 103 additions & 68 deletions tests/test_grpo_trainer.py
@@ -24,7 +24,7 @@
from transformers.utils import is_peft_available

from trl import GRPOConfig, GRPOTrainer
-from trl.trainer.grpo_trainer import RepeatRandomSampler
+from trl.trainer.grpo_trainer import RepeatSampler

from .testing_utils import require_vllm

@@ -33,10 +33,10 @@
from peft import LoraConfig, PeftModel


-class RepeatRandomSamplerTester(unittest.TestCase):
-    def test_sampler(self):
+class RepeatSamplerTester(unittest.TestCase):
+    def test_sampler_shuffle(self):
        dataset = ["a", "b", "c", "d", "e", "f", "g"]
-        sampler = RepeatRandomSampler(dataset, mini_repeat_count=2)
+        sampler = RepeatSampler(dataset, mini_repeat_count=2, shuffle=True)
        # Should output something like [4, 4, 3, 3, 0, 0, 1, 1, 2, 2, 6, 6, 5, 5]
        sampled = list(sampler)
        # Check that the length is doubled
@@ -46,93 +46,128 @@ def test_sampler(self):
        # Check that each element is repeated twice
        assert all(sampled[i] == sampled[i + 1] for i in range(0, len(sampled), 2))

+    def test_sampler_noshuffle(self):
+        dataset = ["a", "b", "c", "d", "e", "f", "g"]
+        sampler = RepeatSampler(dataset, mini_repeat_count=2, shuffle=False)
+        sampled = list(sampler)
+        expected = [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6]
+        self.assertEqual(sampled, expected)

    def test_sampler_no_repeat(self):
        dataset = ["a", "b", "c", "d", "e", "f", "g"]
-        sampler = RepeatRandomSampler(dataset, mini_repeat_count=1)
-        # Should output something like [4, 3, 0, 1, 2, 6, 5]
+        sampler = RepeatSampler(dataset, mini_repeat_count=1, shuffle=False)
        sampled = list(sampler)
-        # Check that the length is the same
-        assert len(sampled) == len(dataset)
-        # Check that all indexes are present
-        assert set(sampled) == set(range(len(dataset)))
+        expected = [0, 1, 2, 3, 4, 5, 6]
+        self.assertEqual(sampled, expected)

    def test_sampler_with_batch_size(self):
        dataset = ["a", "b", "c", "d", "e", "f", "g", "h"]
-        sampler = RepeatRandomSampler(dataset, mini_repeat_count=1, batch_size=2, repeat_count=2)
-        # Should output something like [4, 3, 4, 3, 0, 1, 0, 1, 2, 6, 2, 6, 5, 7, 5, 7]
+        sampler = RepeatSampler(dataset, mini_repeat_count=1, batch_size=2, repeat_count=2, shuffle=False)
        sampled = list(sampler)
-        # Check that the length is doubled
-        assert len(sampled) == 2 * len(dataset)
-        # Check that all indexes are present
-        assert set(sampled) == set(range(len(dataset)))
-        # Check that each element is repeated as expected
-        assert all(sampled[i : i + 1] == sampled[i + 2 : i + 3] for i in range(0, len(sampled), 4))
+        expected = [0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7, 6, 7]
+        self.assertEqual(sampled, expected)

    def test_sampler_with_batch_size_and_drop(self):
        dataset = ["a", "b", "c", "d", "e", "f", "g"]
-        sampler = RepeatRandomSampler(dataset, mini_repeat_count=1, batch_size=2, repeat_count=2)
-        # Should output something like [4, 3, 4, 3, 0, 1, 0, 1, 2, 6, 2, 6]
+        sampler = RepeatSampler(dataset, mini_repeat_count=1, batch_size=2, repeat_count=2, shuffle=False)
        sampled = list(sampler)
-        # Check that the length is doubled
-        assert len(sampled) == 2 * (
-            len(dataset) - 1
-        )  # one element is dropped, because it's not enough to form a batch
-        # Check that the sampled indexes are a subset of the dataset indexes
-        assert set(sampled).issubset(set(range(len(dataset))))
-        # Check that each element is repeated as expected
-        assert all(sampled[i : i + 1] == sampled[i + 2 : i + 3] for i in range(0, len(sampled), 4))
+        expected = [0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 4, 5]
+        self.assertEqual(sampled, expected)

    def test_sampler_with_mini_repeat_count_and_batch_size_1(self):
        dataset = ["a", "b", "c", "d", "e", "f", "g"]
-        sampler = RepeatRandomSampler(dataset, mini_repeat_count=2, batch_size=3, repeat_count=2)
-        # Should output something like [4, 4, 3, 3, 0, 0, 4, 4, 3, 3, 0, 0,
-        #                                1, 1, 2, 2, 6, 6, 1, 1, 2, 2, 6, 6]
+        sampler = RepeatSampler(dataset, mini_repeat_count=2, batch_size=3, repeat_count=2, shuffle=False)
        sampled = list(sampler)
-        # Check that the length is quadrupled
-        assert len(sampled) == 4 * (len(dataset) - 1)  # 1 element is dropped, because it's not enough to form a batch
-        # Check that the sampled indexes are a subset of the dataset indexes
-        assert set(sampled).issubset(set(range(len(dataset))))
-        # Check that each element is repeated as expected
-        assert all(sampled[i] == sampled[i + 1] for i in range(0, len(sampled), 2))
-        # Check that the batch is repeated as expected
-        assert sampled[0:6] == sampled[6:12]
-        assert sampled[12:18] == sampled[18:24]
+        expected = [0, 0, 1, 1, 2, 2, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 3, 3, 4, 4, 5, 5]
+        self.assertEqual(sampled, expected)

    def test_sampler_with_mini_repeat_count_and_batch_size_2(self):
        dataset = ["a", "b", "c", "d", "e", "f", "g"]
-        sampler = RepeatRandomSampler(dataset, mini_repeat_count=3, batch_size=2, repeat_count=2)
-        # Should output something like [4, 4, 4, 3, 3, 3, 4, 4, 4, 3, 3, 3,
-        #                                0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1,
-        #                                2, 2, 2, 6, 6, 6, 2, 2, 2, 6, 6, 6]
+        sampler = RepeatSampler(dataset, mini_repeat_count=3, batch_size=2, repeat_count=2, shuffle=False)
        sampled = list(sampler)
-        # Check that the length is sextupled
-        assert len(sampled) == 6 * (len(dataset) - 1)  # 1 element is dropped, because it's not enough to form a batch
-        # Check that the sampled indexes are a subset of the dataset indexes
-        assert set(sampled).issubset(set(range(len(dataset))))
-        # Check that each element is repeated as expected
-        assert all(sampled[i] == sampled[i + 1] == sampled[i + 2] for i in range(0, len(sampled), 3))
-        # Check that the batch is repeated as expected
-        assert sampled[0:6] == sampled[6:12]
-        assert sampled[12:18] == sampled[18:24]
-        assert sampled[24:30] == sampled[30:36]
+        expected = [0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1,
+                    2, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3,
+                    4, 4, 4, 5, 5, 5, 4, 4, 4, 5, 5, 5]
+        self.assertEqual(sampled, expected)

    def test_sampler_with_mini_repeat_count_and_batch_size_3(self):
        dataset = ["a", "b", "c", "d", "e", "f", "g"]
-        sampler = RepeatRandomSampler(dataset, mini_repeat_count=2, batch_size=2, repeat_count=3)
-        # Should output something like [4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3,
-        #                                0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1,
-        #                                2, 2, 6, 6, 2, 2, 6, 6, 2, 2, 6, 6]
+        sampler = RepeatSampler(dataset, mini_repeat_count=2, batch_size=2, repeat_count=3, shuffle=False)
        sampled = list(sampler)
-        # Check that the length is sextupled
-        assert len(sampled) == 6 * (len(dataset) - 1)  # 1 element is dropped, because it's not enough to form a batch
-        # Check that the sampled indexes are a subset of the dataset indexes
-        assert set(sampled).issubset(set(range(len(dataset))))
-        # Check that each element is repeated as expected
-        assert all(sampled[i] == sampled[i + 1] for i in range(0, len(sampled), 2))
-        # Check that the batch is repeated as expected
-        assert sampled[0:4] == sampled[4:8] == sampled[8:12]
-        assert sampled[12:16] == sampled[16:20] == sampled[20:24]
-        assert sampled[24:28] == sampled[28:32] == sampled[32:36]
+        expected = [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1,
+                    2, 2, 3, 3, 2, 2, 3, 3, 2, 2, 3, 3,
+                    4, 4, 5, 5, 4, 4, 5, 5, 4, 4, 5, 5]
+        self.assertEqual(sampled, expected)


class GRPOTrainerTester(unittest.TestCase):
6 changes: 6 additions & 0 deletions trl/trainer/grpo_config.py
@@ -59,6 +59,8 @@ class GRPOConfig(TrainingArguments):
            improving generation speed. However, disabling this option allows training models that exceed the VRAM
            capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible
            with vLLM generation.
+        shuffle_dataset (`bool`, *optional*, defaults to `True`):
+            Whether to shuffle the training dataset.

        > Parameters that control generation

@@ -222,6 +224,10 @@ class GRPOConfig(TrainingArguments):
"is not compatible with vLLM generation."
},
)
shuffle_dataset: Optional[bool] = field(
default=True,
metadata={"help": "Whether to shuffle the training dataset."},
)

    # Parameters that control generation
    temperature: float = field(
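From user code, opting out of shuffling is just this new config flag. A minimal sketch, assuming a toy length-based reward and the `trl-lib/tldr` prompt dataset (the output directory name is hypothetical):

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

def reward_len(completions, **kwargs):
    # Toy reward: prefer completions close to 50 characters.
    return [-abs(50 - len(c)) for c in completions]

training_args = GRPOConfig(
    output_dir="grpo-no-shuffle",  # hypothetical
    shuffle_dataset=False,         # the flag added by this PR
)
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```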
28 changes: 20 additions & 8 deletions trl/trainer/grpo_trainer.py
@@ -78,7 +78,7 @@
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]


-class RepeatRandomSampler(Sampler):
+class RepeatSampler(Sampler):
"""
Sampler that repeats the indices of a dataset in a structured manner.

@@ -91,6 +91,8 @@ class RepeatRandomSampler(Sampler):
            Number of unique indices per batch.
        repeat_count (`int`, *optional*, defaults to `1`):
            Number of times to repeat the full sampling process.
+        shuffle (`bool`, *optional*, defaults to `True`):
+            Whether to shuffle the dataset.
        seed (`int` or `None`, *optional*, defaults to `None`):
            Random seed for reproducibility (only affects this sampler).

@@ -132,21 +134,28 @@ def __init__(
        mini_repeat_count: int,
        batch_size: int = 1,
        repeat_count: int = 1,
+        shuffle: bool = True,
        seed: Optional[int] = None,
    ):
        self.data_source = data_source
        self.mini_repeat_count = mini_repeat_count
        self.batch_size = batch_size
        self.repeat_count = repeat_count
        self.num_samples = len(data_source)
+        self.shuffle = shuffle
        self.seed = seed
-        self.generator = torch.Generator()  # Create a local random generator
-        if seed is not None:
-            self.generator.manual_seed(seed)
+
+        if shuffle:
+            self.generator = torch.Generator()  # Create a local random generator
+            if seed is not None:
+                self.generator.manual_seed(seed)

    def __iter__(self):
-        # E.g., [2, 4, 3, 1, 0, 6, 5] (num_samples = 7)
-        indexes = torch.randperm(self.num_samples, generator=self.generator).tolist()
+        if self.shuffle:
+            # E.g., [2, 4, 3, 1, 0, 6, 5] (num_samples = 7)
+            indexes = torch.randperm(self.num_samples, generator=self.generator).tolist()
+        else:
+            indexes = list(range(self.num_samples))

        # [2, 4, 3, 1, 0, 6, 5]
        # -> [[2, 4, 3], [1, 0, 6], [5]] (batch_size = 3)
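To make the resulting iteration order concrete, here is a minimal sketch of the two modes, mirroring the unit tests above (the import path is the one introduced by this PR):

```python
from trl.trainer.grpo_trainer import RepeatSampler

data = ["a", "b", "c", "d", "e", "f", "g"]

# shuffle=False walks the dataset in order; index 6 is dropped because it
# cannot fill a complete batch of 2 unique indices.
sampler = RepeatSampler(data, mini_repeat_count=2, batch_size=2, repeat_count=2, shuffle=False)
print(list(sampler))
# [0, 0, 1, 1, 0, 0, 1, 1, 2, 2, 3, 3, 2, 2, 3, 3, 4, 4, 5, 5, 4, 4, 5, 5]

# shuffle=True (the default) applies the same structure to a seeded random permutation.
sampler = RepeatSampler(data, mini_repeat_count=2, batch_size=2, repeat_count=2, shuffle=True, seed=42)
```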
@@ -481,6 +490,8 @@ def data_collator(features):  # No data collation is needed in GRPO
        self.mask_truncated_completions = args.mask_truncated_completions

        # Datasets
+        self.shuffle_dataset = args.shuffle_dataset
+
        if (
            isinstance(train_dataset, IterableDataset)
            or isinstance(eval_dataset, IterableDataset)
@@ -727,17 +738,18 @@ def _get_train_sampler(self) -> Sampler:
            * self.accelerator.num_processes
            * self.args.gradient_accumulation_steps
        )
-        return RepeatRandomSampler(
+        return RepeatSampler(
            data_source=self.train_dataset,
            mini_repeat_count=self.num_generations,
            batch_size=effective_batch_size // self.num_generations,
            repeat_count=self.num_iterations * self.args.gradient_accumulation_steps,
+            shuffle=self.shuffle_dataset,
            seed=self.args.seed,
        )
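As a quick sanity check on the arithmetic feeding the sampler, a sketch with hypothetical values (the concrete numbers are illustrative, not from the PR):

```python
# Hypothetical run configuration, for illustration only.
per_device_train_batch_size = 4
num_processes = 2                # self.accelerator.num_processes
gradient_accumulation_steps = 2  # self.args.gradient_accumulation_steps
num_generations = 8              # completions generated per prompt
num_iterations = 1               # optimization steps per generation batch

effective_batch_size = per_device_train_batch_size * num_processes * gradient_accumulation_steps
assert effective_batch_size == 16
batch_size = effective_batch_size // num_generations         # 2 unique prompts per generation batch
repeat_count = num_iterations * gradient_accumulation_steps  # each batch is yielded twice
```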

    def _get_eval_sampler(self, eval_dataset) -> Sampler:
        # See _get_train_sampler for an explanation of the sampler.
-        return RepeatRandomSampler(
+        return RepeatSampler(
            data_source=eval_dataset,
            mini_repeat_count=self.num_generations,
            seed=self.args.seed,