
Commit 24523b0

Aladoro authored and qgallouedec committed
🌡️ Fix temperature inconsistency in GRPO trainer (huggingface#3029)
* fix temperature inconsistency in GRPO trainer

* adding 1e-7 isn't necessary

* comment

---------

Co-authored-by: Quentin Gallouédec <[email protected]>
1 parent a97d380 commit 24523b0

File tree

1 file changed (+5 −1 lines changed)


trl/trainer/grpo_trainer.py

Lines changed: 5 additions & 1 deletion
@@ -847,6 +847,7 @@ def data_collator(features): # No data collation is needed in GRPO
 self.max_prompt_length = args.max_prompt_length
 self.max_completion_length = args.max_completion_length  # = |o_i| in the GRPO paper
 self.num_generations = args.num_generations  # = G in the GRPO paper
+self.temperature = args.temperature
 self.use_vllm = args.use_vllm
 self.use_sglang = getattr(args, "use_sglang", False)  # Add backend selection flag

@@ -1214,7 +1215,10 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep)
 # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
 # See https://github.com/huggingface/trl/issues/2770
 logits = logits[:, -logits_to_keep:]
-return selective_log_softmax(logits, input_ids)  # compute logprobs for the input tokens
+# Divide logits by sampling temperature.
+# See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
+logits = logits / self.temperature
+return selective_log_softmax(logits, input_ids)  # compute logprobs for the input tokens

 def _update_sglang_engine_weights(self):
     """Update the SGLang engine weights from the current model."""
