
Commit 04f6597

Aladoro and qgallouedec authored
🌡️ Fix temperature inconsistency in GRPO trainer (#3029)
* fix temperature inconsistency in GRPO trainer
* adding 1e-7 isn't necessary
* comment
---------
Co-authored-by: Quentin Gallouédec <[email protected]>
1 parent e3244d2 commit 04f6597

File tree: 1 file changed (+5, −1 lines)


trl/trainer/grpo_trainer.py

Lines changed: 5 additions & 1 deletion
@@ -379,6 +379,7 @@ def data_collator(features): # No data collation is needed in GRPO
         self.max_prompt_length = args.max_prompt_length
         self.max_completion_length = args.max_completion_length  # = |o_i| in the GRPO paper
         self.num_generations = args.num_generations  # = G in the GRPO paper
+        self.temperature = args.temperature
         self.use_vllm = args.use_vllm

         # Multi-step
@@ -658,7 +659,10 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep)
         # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
         # See https://github.com/huggingface/trl/issues/2770
         logits = logits[:, -logits_to_keep:]
-        return selective_log_softmax(logits, input_ids)  # compute logprobs for the input tokens
+        # Divide logits by sampling temperature.
+        # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
+        logits = logits / self.temperature
+        return selective_log_softmax(logits, input_ids)  # compute logprobs for the input tokens

     @profiling_decorator
     def _move_model_to_vllm(self):
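For context, a minimal, self-contained sketch of why this change matters (not part of the commit; the local selective_log_softmax helper and the toy tensors below are illustrative stand-ins for TRL's utility and real model outputs): when completions are sampled with temperature T, the per-token log-probabilities used for training should be computed from logits / T, otherwise they describe a different distribution than the one that actually generated the completions.

import torch
import torch.nn.functional as F

def selective_log_softmax(logits, index):
    # Log-softmax over the vocabulary, then gather the log-prob of each target token.
    logps = F.log_softmax(logits, dim=-1)
    return torch.gather(logps, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)

# Hypothetical shapes: batch of 1, 3 completion tokens, vocabulary of 5.
temperature = 0.7
logits = torch.randn(1, 3, 5)
input_ids = torch.randint(0, 5, (1, 3))

# Without scaling, the log-probs correspond to a temperature of 1.0, which does
# not match the distribution the completions were sampled from.
logps_unscaled = selective_log_softmax(logits, input_ids)

# With the fix: divide logits by the sampling temperature before the log-softmax.
logps_scaled = selective_log_softmax(logits / temperature, input_ids)
print(logps_unscaled, logps_scaled)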
