Skip to content

Commit 497cdc1

Browse files
📢 Improve GRPO trainer error message for invalid num_generations (huggingface#3199)
Co-authored-by: Quentin Gallouédec <[email protected]>
1 parent 4bb0032 commit 497cdc1

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

trl/trainer/grpo_trainer.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -464,6 +464,11 @@ def data_collator(features): # No data collation is needed in GRPO
464 464
num_processes = self.accelerator.num_processes
465 465
global_batch_size = args.per_device_train_batch_size * num_processes
466 466
possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
467+
if self.num_generations < 2:
468+
raise ValueError(
469+
f"GRPO requires at least 2 generations per prompt to calculate the advantages. "
470+
f"You provided {self.num_generations}, which is less than the minimum required."
471+
)
467 472
if self.num_generations not in possible_values:
468 473
raise ValueError(
469 474
f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "

0 commit comments

Comments
 (0)