
Commit 6a02c69

Authored by Robert Veres (nopepper) and Quentin Gallouédec (qgallouedec)

🎲 Add support for additional generation kwargs in GRPO Trainer (#2989)

* Add support for additional generation kwargs in GRPO Trainer
  - Extend GRPOConfig to support additional generation kwargs
  - Update GRPOTrainer to incorporate additional generation parameters
  - Add tests for training with additional generation kwargs, for both standard and vLLM modes
* Add missing vllm_gpu_memory_utilization=0.5
* 🔧 Refactor GRPO generation parameters and configuration
  - Restructure GRPOConfig to separate generation parameters
  - Add support for top_p, top_k, min_p, repetition_penalty, and length_penalty
  - Remove additional_generation_kwargs in favor of explicit parameters
  - Update GRPOTrainer to use the new generation parameter configuration
* Update tests
* Remove length_penalty and fix tests
* Update defaults and docs
  - Change the temperature type from Optional[float] to float
  - Set the default top_p to 1.0 instead of None
  - Simplify parameter descriptions by removing the redundant "if set to None" text
  - Keep type hints and default values consistent across the generation parameters
* Remove the Optional type hint for the temperature parameter in GRPO
* Remove length_penalty from the sampling_kwargs dict in GRPOTrainer
* Some refactoring
* Support top_k=None
* Change values in the tests to make them pass

Co-authored-by: Robert Veres <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>

1 parent a1c58aa · commit 6a02c69
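Taken together, the commit replaces the earlier catch-all additional_generation_kwargs idea with explicit generation fields on GRPOConfig. A minimal usage sketch based on the tests added in this commit (the tiny models and the zen dataset are TRL-internal testing artifacts; swap in your own model, reward function, and dataset):

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

# Generation behaviour is now configured through explicit GRPOConfig fields
# rather than a generic kwargs dict.
training_args = GRPOConfig(
    output_dir="grpo-demo",          # illustrative output path
    per_device_train_batch_size=3,   # small values, mirroring the tests
    num_generations=3,
    max_completion_length=32,
    report_to="none",
    temperature=0.9,                 # default; higher values give more random completions
    top_p=0.9,                       # nucleus sampling: keep tokens within 90% cumulative probability
    top_k=10,                        # keep only the 10 most likely tokens (None disables top-k)
    min_p=0.01,                      # drop tokens below 1% of the top token's probability
    repetition_penalty=1.1,          # > 1.0 discourages repeating tokens
)

trainer = GRPOTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
    args=training_args,
    train_dataset=load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train"),
)
trainer.train()
```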

File tree: 3 files changed, +142 −8 lines

tests/test_grpo_trainer.py
Lines changed: 78 additions & 0 deletions

@@ -780,6 +780,7 @@ def test_training_vllm_guided_decoding(self):
                 report_to="none",
                 use_vllm=True,
                 vllm_device="cuda:0",  # will raise a warning, but allows this test to work with only one GPU
+                vllm_gpu_memory_utilization=0.5,  # reduce since we use the same device for training and vllm
                 vllm_guided_decoding_regex=r"<reasoning>\n.*\n</reasoning>\n<answer>\n.*\n</answer>",
             )
             trainer = GRPOTrainer(
@@ -799,3 +800,80 @@ def test_training_vllm_guided_decoding(self):
             for n, param in previous_trainable_params.items():
                 new_param = trainer.model.get_parameter(n)
                 self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+
+    def test_training_with_additional_generation_kwargs(self):
+        """Test that training works with additional generation kwargs."""
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = GRPOConfig(
+                output_dir=tmp_dir,
+                learning_rate=0.1,  # increase the learning rate to speed up the test
+                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
+                num_generations=3,  # reduce the number of generations to reduce memory usage
+                max_completion_length=32,  # reduce the completion length to reduce memory usage
+                report_to="none",
+                top_p=0.9,
+                top_k=10,
+                min_p=0.01,
+                repetition_penalty=1.1,
+            )
+
+            trainer = GRPOTrainer(
+                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+                args=training_args,
+                train_dataset=dataset,
+            )
+
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            trainer.train()
+
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+            # Check that the params have changed
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+
+    @unittest.skipIf(not is_vllm_available(), "vLLM is not available")
+    @require_torch_accelerator
+    def test_training_vllm_with_additional_generation_kwargs(self):
+        """Test that training works with vLLM and additional generation kwargs."""
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = GRPOConfig(
+                output_dir=tmp_dir,
+                learning_rate=0.1,  # increase the learning rate to speed up the test
+                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
+                num_generations=3,  # reduce the number of generations to reduce memory usage
+                max_completion_length=32,  # reduce the completion length to reduce memory usage
+                report_to="none",
+                use_vllm=True,
+                vllm_device="cuda:0",  # will raise a warning, but allows this test to work with only one GPU
+                vllm_gpu_memory_utilization=0.5,  # reduce since we use the same device for training and vllm
+                top_p=0.9,
+                top_k=10,
+                min_p=0.01,
+                repetition_penalty=1.1,
+            )
+
+            trainer = GRPOTrainer(
+                model="Qwen/Qwen2.5-0.5B-Instruct",  # tiny model is too small for vLLM
+                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+                args=training_args,
+                train_dataset=dataset,
+            )
+
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            trainer.train()
+
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+            # Check that the params have changed
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
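To exercise only these new tests locally, one would typically run something like `pytest tests/test_grpo_trainer.py -k additional_generation_kwargs`; the vLLM variant is skipped automatically when vLLM or a torch accelerator is unavailable.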

trl/trainer/grpo_config.py
Lines changed: 54 additions & 6 deletions

@@ -47,8 +47,6 @@ class GRPOConfig(TrainingArguments):
         num_generations (`int` or `None`, *optional*, defaults to `8`):
             Number of generations per prompt to sample. The global batch size (num_processes * per_device_batch_size)
             must be divisible by this value.
-        temperature (`float`, *optional*, defaults to `0.9`):
-            Temperature for sampling. The higher the temperature, the more random the completions.
         max_completion_length (`int` or `None`, *optional*, defaults to `256`):
             Maximum length of the generated completion.
         ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
@@ -57,6 +55,24 @@ class GRPOConfig(TrainingArguments):
             capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible
             with vLLM generation.

+        > Parameters that control generation
+
+        temperature (`float`, defaults to `0.9`):
+            Temperature for sampling. The higher the temperature, the more random the completions.
+        top_p (`float`, *optional*, defaults to `1.0`):
+            Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. Set to
+            `1.0` to consider all tokens.
+        top_k (`int` or `None`, *optional*, defaults to `50`):
+            Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, top-k-filtering is
+            disabled.
+        min_p (`float` or `None`, *optional*, defaults to `None`):
+            Minimum token probability, which will be scaled by the probability of the most likely token. It must be a
+            value between `0.0` and `1.0`. Typical values are in the `0.01-0.2` range.
+        repetition_penalty (`float`, *optional*, defaults to `1.0`):
+            Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far.
+            Values > `1.0` encourage the model to use new tokens, while values < `1.0` encourage the model to repeat
+            tokens.
+
         > Parameters that control generation acceleration powered by vLLM

         use_vllm (`bool`, *optional*, defaults to `False`):
@@ -115,6 +131,7 @@ class GRPOConfig(TrainingArguments):
             set `sync_ref_model=True`.

         > Parameters that control the logging
+
         log_completions (`bool`, *optional*, defaults to `False`):
             Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is
             installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`.
@@ -152,10 +169,6 @@ class GRPOConfig(TrainingArguments):
             "must be divisible by this value."
         },
     )
-    temperature: Optional[float] = field(
-        default=0.9,
-        metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."},
-    )
    max_completion_length: Optional[int] = field(
        default=256,
        metadata={"help": "Maximum length of the generated completion."},
@@ -170,6 +183,41 @@ class GRPOConfig(TrainingArguments):
        },
    )

+    # Parameters that control generation
+    temperature: float = field(
+        default=0.9,
+        metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."},
+    )
+    top_p: float = field(
+        default=1.0,
+        metadata={
+            "help": "Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. "
+            "Set to 1.0 to consider all tokens."
+        },
+    )
+    top_k: Optional[int] = field(
+        default=50,
+        metadata={
+            "help": "Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, "
+            "top-k-filtering is disabled."
+        },
+    )
+    min_p: Optional[float] = field(
+        default=None,
+        metadata={
+            "help": "Minimum token probability, which will be scaled by the probability of the most likely token. It "
+            "must be a value between 0.0 and 1.0. Typical values are in the 0.01-0.2 range."
+        },
+    )
+    repetition_penalty: float = field(
+        default=1.0,
+        metadata={
+            "help": "Float that penalizes new tokens based on whether they appear in the prompt and the generated "
+            "text so far. Values > 1.0 encourage the model to use new tokens, while values < 1.0 encourage the model "
+            "to repeat tokens."
+        },
+    )
+
    # Parameters that control generation acceleration powered by vLLM
    use_vllm: Optional[bool] = field(
        default=False,
trl/trainer/grpo_trainer.py
Lines changed: 10 additions & 2 deletions

@@ -519,10 +519,14 @@ def new_group_context():

            # Sampling parameters
            self.sampling_params = SamplingParams(
-                temperature=args.temperature,
                max_tokens=self.max_completion_length,
                guided_decoding=guided_decoding,
                n=args.num_generations,
+                temperature=args.temperature,
+                top_p=args.top_p,
+                top_k=-1 if args.top_k is None else args.top_k,
+                min_p=0.0 if args.min_p is None else args.min_p,
+                repetition_penalty=args.repetition_penalty,
            )

            self._last_loaded_step = 0  # tag to avoid useless loading during grad accumulation
@@ -535,8 +539,12 @@ def new_group_context():
            self.generation_config = GenerationConfig(
                max_new_tokens=self.max_completion_length,
                do_sample=True,
-                temperature=args.temperature,
                pad_token_id=processing_class.pad_token_id,
+                temperature=args.temperature,
+                top_p=args.top_p,
+                top_k=args.top_k,
+                min_p=args.min_p,
+                repetition_penalty=args.repetition_penalty,
            )

            # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
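One subtle point in the diff above: the same GRPOConfig values feed two backends with different conventions. The trainer passes top_k and min_p through to transformers' GenerationConfig unchanged (including None), but substitutes the values that mean "disabled" for vLLM's SamplingParams, which does not accept None for these fields. A hypothetical standalone helper (not in the PR) that captures that translation:

```python
from typing import Optional

def to_vllm_sampling_overrides(top_k: Optional[int], min_p: Optional[float]) -> dict:
    """Mirror the None-handling done in GRPOTrainer for vLLM:
    top_k=-1 means "no top-k filtering" and min_p=0.0 means "no min-p filtering"."""
    return {
        "top_k": -1 if top_k is None else top_k,
        "min_p": 0.0 if min_p is None else min_p,
    }

assert to_vllm_sampling_overrides(None, None) == {"top_k": -1, "min_p": 0.0}
assert to_vllm_sampling_overrides(10, 0.01) == {"top_k": 10, "min_p": 0.01}
```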
