File tree Expand file tree Collapse file tree 1 file changed +5
-0
lines changed Expand file tree Collapse file tree 1 file changed +5
-0
lines changed Original file line number Diff line number Diff line change @@ -464,6 +464,11 @@ def data_collator(features): # No data collation is needed in GRPO
464
464
num_processes = self .accelerator .num_processes
465
465
global_batch_size = args .per_device_train_batch_size * num_processes
466
466
possible_values = [n_gen for n_gen in range (2 , global_batch_size + 1 ) if (global_batch_size ) % n_gen == 0 ]
467
+ if self .num_generations < 2 :
468
+ raise ValueError (
469
+ f"GRPO requires at least 2 generations per prompt to calculate the advantages. "
470
+ f"You provided { self .num_generations } , which is less than the minimum required."
471
+ )
467
472
if self .num_generations not in possible_values :
468
473
raise ValueError (
469
474
f"The global train batch size ({ num_processes } x { args .per_device_train_batch_size } ) must be evenly "
You can’t perform that action at this time.
0 commit comments