
✂️ [BUG when vllm and prompt_truncation are used]: Strip out pad tokens in truncated prompt text #3698

Merged
9 commits merged into huggingface:main
Jul 7, 2025

Conversation

pramodith
Collaborator

@pramodith pramodith commented Jul 5, 2025

What does this PR do?

This PR fixes #3686. As mentioned in the discussion, when pad tokens are part of the original text, the tokenizer doesn't set the attention mask corresponding to those tokens to 0, leading to erroneous logit scores and thereby erroneous generations.

from transformers import AutoProcessor

tokenizer = AutoProcessor.from_pretrained("Qwen/Qwen3-0.6B")
print(tokenizer.pad_token)
>>> <|endoftext|>

text_with_pad = "<|endoftext|><|endoftext|><|im_start|>system\nYou are a music teacher<|im_end|>\n<|im_start|>user\nDo re mi fa so la<|im_end|>\n<|im_start|>assistant\n'"
text_without_pad = "<|im_start|>system\nYou are a music teacher<|im_end|>\n<|im_start|>user\nDo re mi fa so la<|im_end|>\n<|im_start|>assistant\n'"

inputs = tokenizer([text_with_pad, text_without_pad], padding=True, padding_side="left")

assert inputs["input_ids"][0] == inputs["input_ids"][1], "Input ids don't match!"
assert inputs["attention_mask"][0] == inputs["attention_mask"][1], "Attention masks don't match!"
>>> AssertionError: Attention masks don't match!

print(f"Attention masks of original text with pad/eos tokens\n {inputs['attention_mask'][0]}")
print(f"Attention masks of original text without pad/eos tokens\n {inputs['attention_mask'][1]}")
>>> Attention masks of original text with pad/eos tokens
 [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Attention masks of original text without pad/eos tokens
 [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]

This PR addresses the issue by stripping all leading pad tokens from the prompts_text list after prompt truncation and before the prompts are sent to vLLM.
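
A minimal standalone sketch of the stripping step described above (illustrative prompts and pad token; the exact change lives in the GRPO trainer and is discussed further down in this thread):

pad_token = "<|endoftext|>"  # in the trainer this comes from self.processing_class.pad_token

def remove_leading_pad_tokens(text, pad_token):
    # Drop only the pad tokens at the start of the decoded prompt; pad-like strings elsewhere are left untouched.
    while text.startswith(pad_token):
        text = text[len(pad_token):]
    return text

prompts_text = ["<|endoftext|><|endoftext|>Hello", "Hello how are you?"]
print([remove_leading_pad_tokens(text, pad_token) for text in prompts_text])
>>> ['Hello', 'Hello how are you?']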

Fixes #3686

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@kashif
Collaborator

kashif commented Jul 7, 2025

Thanks @pramodith. Perhaps it's more robust if we create a mask for positions that contain pad tokens and then set the attention mask to 0 for all pad-token positions, regardless of whether they were originally in the text or added by the tokenizer for padding:

if self.processing_class.pad_token_id is not None:
    is_pad_token = prompt_ids == self.processing_class.pad_token_id
    prompt_mask = prompt_mask & (~is_pad_token)

what do you think?
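
For illustration, the suggested masking on toy tensors (the ids here are made up; in the trainer, prompt_ids and prompt_mask come from the tokenizer and the pad id from self.processing_class.pad_token_id):

import torch

pad_token_id = 151643  # illustrative stand-in for the tokenizer's pad token id
prompt_ids = torch.tensor([[151643, 151643, 9906, 1917]])  # two pad tokens typed into the prompt itself
prompt_mask = torch.ones_like(prompt_ids)  # the tokenizer attends to every position, including the pads

is_pad_token = prompt_ids == pad_token_id
prompt_mask = prompt_mask & (~is_pad_token)
print(prompt_mask)
>>> tensor([[0, 0, 1, 1]])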

@pramodith
Collaborator Author

pramodith commented Jul 7, 2025

Thanks @pramodith. Perhaps it's more robust if we create a mask for positions that contain pad tokens and then set the attention mask to 0 for all pad-token positions, regardless of whether they were originally in the text or added by the tokenizer for padding:

if self.processing_class.pad_token_id is not None:
    is_pad_token = prompt_ids == self.processing_class.pad_token_id
    prompt_mask = prompt_mask & (~is_pad_token)

what do you think?

@kashif I think that'd be the ideal approach, but this issue only pops up when users use vLLM. With the HF backend we already pass the attention_mask and prompt_ids correctly, but for the vLLM instance/server we have to pass the prompt text (or prompt ids), and I don't think vLLM lets you pass an attention_mask as input.

all_prompts_text = gather_object(prompts_text)
if self.accelerator.is_main_process:
    # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
    # num_generations outputs for each one. This is faster than generating outputs for each duplicate
    # prompt individually.
    ordered_set_of_prompts = all_prompts_text[:: self.num_generations]
    with profiling_context(self, "vLLM.generate"):
        completion_ids = self.vllm_client.generate(
            prompts=ordered_set_of_prompts,
            n=self.num_generations,

if self.vllm_tensor_parallel_size > 1:
    # Gather prompts from all ranks in the TP group and flatten.
    # Each rank starts with its own prompts; after gathering, all ranks see the full group set.
    orig_size = len(prompts_text)
    gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)]
    torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group)
    all_prompts_text = [p for sublist in gathered_prompts for p in sublist]
else:
    all_prompts_text = prompts_text
with profiling_context(self, "vLLM.generate"):
    all_outputs = self.llm.generate(all_prompts_text, sampling_params=sampling_params, use_tqdm=False)

@kashif
Collaborator

kashif commented Jul 7, 2025

@pramodith Right, so the full logic would be:

        # mask all pad tokens regardless of origin
        if self.processing_class.pad_token_id is not None:
            is_pad_token = prompt_ids == self.processing_class.pad_token_id
            prompt_mask = prompt_mask & (~is_pad_token)

        if self.max_prompt_length is not None:
            prompt_ids = prompt_ids[:, -self.max_prompt_length :]
            prompt_mask = prompt_mask[:, -self.max_prompt_length :]
            prompts_text = self.processing_class.batch_decode(
                prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
            )

        if self.use_vllm and self.processing_class.pad_token is not None:

            def remove_leading_pad_tokens(text, pad_token):
                while text.startswith(pad_token):
                    text = text[len(pad_token) :]
                return text

            prompts_text = [remove_leading_pad_tokens(text, self.processing_class.pad_token) for text in prompts_text]
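
A toy walk-through of the combined logic above, using made-up ids and a made-up pad id rather than the trainer's real tensors; it shows that masking plus left-truncation fixes the attention mask, while the decoded text still needs its leading pad tokens stripped before being sent to vLLM:

import torch

pad_id = 0  # illustrative pad token id
prompt_ids = torch.tensor([[0, 0, 11, 12, 13]])  # prompt that already starts with two pad tokens
prompt_mask = torch.ones_like(prompt_ids)

# 1) mask all pad tokens regardless of origin
prompt_mask = prompt_mask & ~(prompt_ids == pad_id)

# 2) left-truncate to max_prompt_length
max_prompt_length = 4
prompt_ids = prompt_ids[:, -max_prompt_length:]
prompt_mask = prompt_mask[:, -max_prompt_length:]

print(prompt_ids)   # tensor([[ 0, 11, 12, 13]]) -> decoding this still yields a leading pad token
print(prompt_mask)  # tensor([[0, 1, 1, 1]])
# 3) hence the extra step of stripping leading pad tokens from prompts_text when use_vllm is set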

@pramodith
Collaborator Author

Cool, I'll make those changes by the end of the day.

@kashif
Collaborator

kashif commented Jul 7, 2025

Here is a test that fails if I do not mask out the padding tokens; feel free to add this to your PR as well:

    def test_generate_and_score_completions_prompt_mask(self):
        """
        Test that _generate_and_score_completions correctly masks pad tokens in prompt_mask.
        """

        def reward_func(completions, **kwargs):
            return [1.0] * len(completions)

        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train[:2]")

        trainer = GRPOTrainer(
            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
            reward_funcs=reward_func,
            args=GRPOConfig(
                output_dir=tempfile.mkdtemp(),
                per_device_train_batch_size=2,
                num_generations=2,
                max_steps=1,
                max_completion_length=4,
                report_to="none",
                logging_strategy="no",
            ),
            train_dataset=dataset,
        )

        # Create input with pad tokens already in the prompt text (duplicated for num_generations)
        pad_token = trainer.processing_class.pad_token
        inputs = [{"prompt": f"{pad_token}{pad_token}Hello"}] * trainer.num_generations

        # Call the method directly
        with torch.no_grad():
            outputs = trainer._generate_and_score_completions(inputs)

        # Check that all pad token positions in prompt_mask are 0
        prompt_ids = outputs["prompt_ids"]
        prompt_mask = outputs["prompt_mask"]
        pad_positions = prompt_ids == trainer.processing_class.pad_token_id

        self.assertTrue(
            torch.all(prompt_mask[pad_positions] == 0), "All pad token positions should have prompt_mask=0"
        )

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec
Member

qgallouedec commented Jul 7, 2025

import tempfile
import torch
from datasets import load_dataset
from trl import GRPOTrainer, GRPOConfig


def reward_func(completions, **kwargs):
    return [1.0] * len(completions)


dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train[:2]")

trainer = GRPOTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    reward_funcs=reward_func,
    args=GRPOConfig(
        output_dir=tempfile.mkdtemp(),
        per_device_train_batch_size=2,
        num_generations=2,
        max_steps=1,
        max_prompt_length=8,
        report_to="none",
        logging_strategy="no",
    ),
    train_dataset=dataset,
)

# Create input that requires padding
inputs = [
    {"prompt": "Hello"},
    {"prompt": "Hello how are you?"},
]

# Call the method directly
with torch.no_grad():
    outputs = trainer._generate_and_score_completions(inputs)

# Conversational
inputs = [
    {"prompt": [{"role": "user", "content": "Hello"}]},
    {"prompt": [{"role": "user", "content": "Hello how are you?"}]},
]

# Call the method directly
with torch.no_grad():
    outputs = trainer._generate_and_score_completions(inputs)

printing prompts_text:

before:

['<|endoftext|><|endoftext|><|endoftext|><|endoftext|>Hello',
 'Hello how are you?']
['<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n',
 '<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello how are you?<|im_end|>\n<|im_start|>assistant\n']

after:

['Hello',
 'Hello how are you?']

['<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n',
 '<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello how are you?<|im_end|>\n<|im_start|>assistant\n']

@qgallouedec
Member

qgallouedec commented Jul 7, 2025

Forwarding from our internal slack:

Ok I see:
The original text from the dataset doesn’t contain any [PAD] tokens. But [PAD] can show up in the decoded prompt if max_length is set — and that decoded prompt is what’s used for generation.

Here’s the difference:

  • If max_length is None:
    We tokenize and pad the input, but the original (clean) text is still used for generation → no [PAD] in the prompt.

  • If max_length is set:
    We tokenize → pad → truncate to max_length → decode → and now the decoded prompt may include [PAD] tokens.
    Since this decoded text is used for generation, those [PAD] tokens can leak into the generation input.

So when max_length is used, we need to manually strip [PAD] tokens from the decoded text before passing it into generation.
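
A standalone sketch of that tokenize → pad → truncate → decode path (using the Qwen/Qwen3-0.6B tokenizer from the earlier repro and an arbitrary max_prompt_length; this mirrors the trainer's behavior rather than calling it):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

# Two prompts of different lengths: the shorter one gets left-padded.
batch = tokenizer(["Hello", "Hello how are you?"], padding=True, padding_side="left", return_tensors="pt")

# Left-truncate to max_prompt_length, then decode without skipping special tokens.
max_prompt_length = 8
prompt_ids = batch["input_ids"][:, -max_prompt_length:]
print(tokenizer.batch_decode(prompt_ids, skip_special_tokens=False))
# -> the first decoded prompt starts with <|endoftext|> pad tokens, matching the "before" output above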

@qgallouedec qgallouedec changed the title [BUG when vllm and prompt_truncation are used]: Strip out pad tokens in truncated prompt text ✂️ [BUG when vllm and prompt_truncation are used]: Strip out pad tokens in truncated prompt text Jul 7, 2025
@qgallouedec qgallouedec merged commit b674989 into huggingface:main Jul 7, 2025
9 of 10 checks passed
qgallouedec added a commit that referenced this pull request Jul 8, 2025
…ns in truncated prompt text (#3698)

Co-authored-by: Shirin Yamani <[email protected]>
Co-authored-by: Kashif Rasul <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
marcandrelarochelle pushed a commit to marcandrelarochelle/trl that referenced this pull request Jul 29, 2025
…ns in truncated prompt text (huggingface#3698)

Co-authored-by: Shirin Yamani <[email protected]>
Co-authored-by: Kashif Rasul <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>

Successfully merging this pull request may close these issues.

question about prompts_text processing