@@ -1038,3 +1038,34 @@ def test_training_with_mask_truncated_completions_all_masked(self):
10381038 for n , param in previous_trainable_params .items ():
10391039 new_param = trainer .model .get_parameter (n )
10401040 self .assertTrue (torch .equal (param , new_param ), f"Parameter { n } has changed." )
1041+
1042+ def test_training_num_generations_larger_than_batch_size (self ):
1043+ dataset = load_dataset ("trl-internal-testing/zen" , "standard_prompt_only" , split = "train" )
1044+
1045+ with tempfile .TemporaryDirectory () as tmp_dir :
1046+ training_args = GRPOConfig (
1047+ output_dir = tmp_dir ,
1048+ learning_rate = 0.1 , # increase the learning rate to speed up the test
1049+ per_device_train_batch_size = 3 , # reduce the batch size to reduce memory usage
1050+ max_completion_length = 8 , # reduce the completion length to reduce memory usage
1051+ num_generations = 6 , # the number of generations is larger than the batch size, but
1052+ gradient_accumulation_steps = 2 , # gradient accumulation should allow that
1053+ report_to = "none" ,
1054+ )
1055+ trainer = GRPOTrainer (
1056+ model = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" ,
1057+ reward_funcs = "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5" ,
1058+ args = training_args ,
1059+ train_dataset = dataset ,
1060+ )
1061+
1062+ previous_trainable_params = {n : param .clone () for n , param in trainer .model .named_parameters ()}
1063+
1064+ trainer .train ()
1065+
1066+ self .assertIsNotNone (trainer .state .log_history [- 1 ]["train_loss" ])
1067+
1068+ # Check that the params have changed
1069+ for n , param in previous_trainable_params .items ():
1070+ new_param = trainer .model .get_parameter (n )
1071+ self .assertFalse (torch .equal (param , new_param ), f"Parameter { n } has not changed." )
0 commit comments