Skip to content

Commit 013d360

Browse files
🔹 Fix: Miscalculated mask shape in comments (#2925)
1 parent e5ae703 commit 013d360

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

trl/trainer/grpo_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -739,7 +739,7 @@ def _generate_and_score_completions(
739739
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
740740

741741
# Concatenate prompt_mask with completion_mask for logit computation
742-
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B*G, P+C)
742+
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)
743743

744744
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
745745

0 commit comments

Comments
 (0)