✂️ [BUG when vLLM and prompt_truncation are used]: Strip out pad tokens in truncated prompt text #3698

Merged: 9 commits, Jul 7, 2025
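
For context, here is a minimal sketch of the failure mode this change addresses, assuming left padding and reusing the tiny test model from the diff purely for illustration: after max_prompt_length truncation, decoding the truncated ids with skip_special_tokens=False keeps the pad tokens, so the prompt text later handed to vLLM still starts with pad-token strings.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
if tok.pad_token is None:
    tok.pad_token = tok.eos_token  # assumption: fall back to EOS as pad
tok.padding_side = "left"

prompts = ["Hi", "A much longer prompt that forces the short one to be padded"]
batch = tok(prompts, padding=True, return_tensors="pt")

# Left truncation, as done when max_prompt_length is set (8 is an arbitrary value here)
prompt_ids = batch["input_ids"][:, -8:]

# skip_special_tokens=False keeps pad tokens, so the short prompt's decoded text
# starts with repeated pad-token strings that vLLM would treat as literal input
texts = tok.batch_decode(prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)
print(texts[0])  # e.g. "<|endoftext|><|endoftext|>...Hi"
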
tests/test_grpo_trainer.py: 43 changes (42 additions, 1 deletion)
@@ -969,7 +969,6 @@ def test_training_with_additional_generation_kwargs(self):
min_p=0.01,
repetition_penalty=1.1,
)

trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
@@ -1297,3 +1296,45 @@ def reward_func(completions, **kwargs):
train_dataset=dataset,
)
trainer.train()

def test_generate_and_score_completions_prompt_mask(self):
"""
Test that _generate_and_score_completions correctly masks pad tokens in prompt_mask.
"""

def reward_func(completions, **kwargs):
return [1.0] * len(completions)

dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train[:2]")

trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs=reward_func,
args=GRPOConfig(
output_dir=tempfile.mkdtemp(),
per_device_train_batch_size=2,
num_generations=2,
max_steps=1,
max_completion_length=4,
report_to="none",
logging_strategy="no",
),
train_dataset=dataset,
)

# Create input with pad tokens already in the prompt text (duplicated for num_generations)
pad_token = trainer.processing_class.pad_token
inputs = [{"prompt": f"{pad_token}{pad_token}Hello"}] * trainer.num_generations

# Call the method directly
with torch.no_grad():
outputs = trainer._generate_and_score_completions(inputs)

# Check that all pad token positions in prompt_mask are 0
prompt_ids = outputs["prompt_ids"]
prompt_mask = outputs["prompt_mask"]
pad_positions = prompt_ids == trainer.processing_class.pad_token_id

self.assertTrue(
torch.all(prompt_mask[pad_positions] == 0), "All pad token positions should have prompt_mask=0"
)
trl/trainer/grpo_trainer.py: 15 changes (15 additions, 0 deletions)
@@ -1090,13 +1090,28 @@ def _generate_and_score_completions(
prompt_inputs = super()._prepare_inputs(prompt_inputs)
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

# mask all pad tokens regardless of origin
if self.processing_class.pad_token_id is not None:
is_pad_token = prompt_ids == self.processing_class.pad_token_id
prompt_mask = prompt_mask & (~is_pad_token)

if self.max_prompt_length is not None:
prompt_ids = prompt_ids[:, -self.max_prompt_length :]
prompt_mask = prompt_mask[:, -self.max_prompt_length :]
prompts_text = self.processing_class.batch_decode(
prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
)

# For vLLM backends clean the text by removing leading pad tokens
if self.use_vllm and self.processing_class.pad_token is not None:

def remove_leading_pad_tokens(text, pad_token):
while text.startswith(pad_token):
text = text[len(pad_token) :]
return text

prompts_text = [remove_leading_pad_tokens(text, self.processing_class.pad_token) for text in prompts_text]

# Generate completions using either vLLM or regular generation
if self.use_vllm:
# First, update the vLLM weights if needed
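
As a footnote on the attention-mask half of the change, here is a minimal sketch of what the new masking lines do (not the trainer code itself; the pad id and tensor values below are made up for illustration):

import torch

pad_token_id = 0  # assumed pad id, illustration only

# Suppose the raw prompt text itself contained two pad-token strings
# (as in the new test), so the tokenizer emitted pad ids with attention mask 1
prompt_ids = torch.tensor([[0, 0, 5, 6]])
prompt_mask = torch.tensor([[1, 1, 1, 1]])

# Same logic as the patch: zero the mask at every pad position, regardless of origin
is_pad_token = prompt_ids == pad_token_id
prompt_mask = prompt_mask & ~is_pad_token

print(prompt_mask)  # tensor([[0, 0, 1, 1]])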