Skip to content

Commit 4db63af

Browse files
committed
Fix GRPO unsqueeze advantages
1 parent ecb2811 commit 4db63af

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

trl/trainer/grpo_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1803,7 +1803,7 @@ def _compute_loss(self, model, inputs):
18031803
# In the base GRPO implementation, advantages are expected to have shape (B,). To support subclasses that
18041804
# provide advantages with shape (B, T) (e.g., MiniLLM), we *conditionally* unsqueeze the tensor.
18051805
if advantages.dim() == 1:
1806-
advantages = advantages
1806+
advantages = advantages.unsqueeze(1)
18071807
# When num_iterations == 1 and steps_per_generation <= gradient_accumulation_steps,
18081808
# old_per_token_logps == per_token_logps. In this case we can skip its computation
18091809
# (see _generate_and_score_completions) and instead use per_token_logps.detach().

0 commit comments

Comments
 (0)