✂️ [BUG when vllm and prompt_truncation are used]: Strip out pad tokens in truncated prompt text #3698
Conversation
thanks @pramodith, perhaps it's more robust if we create a mask for positions that contain pad tokens and then set the attention mask to 0 for all pad-token positions, regardless of whether they were originally in the text or added by the tokenizer for padding:

```python
if self.processing_class.pad_token_id is not None:
    is_pad_token = prompt_ids == self.processing_class.pad_token_id
    prompt_mask = prompt_mask & (~is_pad_token)
```

what do you think?
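To see what that masking does, here is a minimal standalone sketch on toy tensors; the tensor values and `pad_token_id = 0` are made up for illustration and are not taken from the trainer:

```python
import torch

pad_token_id = 0  # hypothetical pad token id, for illustration only

# Toy left-padded batch: the first prompt also contains a literal pad token inside its text.
prompt_ids = torch.tensor([[0, 0, 5, 0, 7],
                           [3, 4, 5, 6, 7]])
prompt_mask = torch.tensor([[0, 0, 1, 1, 1],   # the tokenizer only masks the padding it added
                            [1, 1, 1, 1, 1]], dtype=torch.bool)

# Mask every pad-token position, regardless of whether it came from padding or from the text
is_pad_token = prompt_ids == pad_token_id
prompt_mask = prompt_mask & (~is_pad_token)

print(prompt_mask.int())  # tensor([[0, 0, 1, 0, 1], [1, 1, 1, 1, 1]])
```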
@kashif I think that'd be the ideal approach, but this issue only pops up when users use vLLM, since when using the HF backend we do correctly pass in the attention mask:

trl/trl/trainer/grpo_trainer.py, lines 1109 to 1118 in d98d539

trl/trl/trainer/grpo_trainer.py, lines 1160 to 1171 in d98d539
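For reference, the HF-backend path linked above amounts to handing the mask to `generate`, so padded positions never influence the logits. A minimal sketch of that pattern, using an illustrative checkpoint rather than the exact lines linked above:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2.5-0.5B"  # illustrative checkpoint, not the one from the linked code
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name)

batch = tokenizer(["Hello", "Hello how are you?"], return_tensors="pt", padding=True)

# Because attention_mask is passed explicitly, the model ignores the padded positions.
with torch.no_grad():
    generated = model.generate(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        max_new_tokens=4,
    )
```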
@pramodith right, so the full logic would be:

```python
# mask all pad tokens regardless of origin
if self.processing_class.pad_token_id is not None:
    is_pad_token = prompt_ids == self.processing_class.pad_token_id
    prompt_mask = prompt_mask & (~is_pad_token)

if self.max_prompt_length is not None:
    prompt_ids = prompt_ids[:, -self.max_prompt_length :]
    prompt_mask = prompt_mask[:, -self.max_prompt_length :]
    prompts_text = self.processing_class.batch_decode(
        prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
    )

    if self.use_vllm and self.processing_class.pad_token is not None:

        def remove_leading_pad_tokens(text, pad_token):
            while text.startswith(pad_token):
                text = text[len(pad_token) :]
            return text

        prompts_text = [
            remove_leading_pad_tokens(text, self.processing_class.pad_token)
            for text in prompts_text
        ]
```
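A quick sanity check of the stripping helper above; the pad-token string is chosen purely for illustration:

```python
def remove_leading_pad_tokens(text, pad_token):
    while text.startswith(pad_token):
        text = text[len(pad_token) :]
    return text


pad = "<|endoftext|>"  # e.g. a Qwen-style pad token, illustrative only
assert remove_leading_pad_tokens(f"{pad}{pad}Hello", pad) == "Hello"
assert remove_leading_pad_tokens(f"Hello {pad}", pad) == f"Hello {pad}"  # only leading pads are stripped
```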
Cool, I'll make those changes by the end of the day.
Here is a test that fails if I do not mask out the padding tokens; feel free to add this to your PR as well:

```python
def test_generate_and_score_completions_prompt_mask(self):
    """
    Test that _generate_and_score_completions correctly masks pad tokens in prompt_mask.
    """

    def reward_func(completions, **kwargs):
        return [1.0] * len(completions)

    dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train[:2]")
    trainer = GRPOTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        reward_funcs=reward_func,
        args=GRPOConfig(
            output_dir=tempfile.mkdtemp(),
            per_device_train_batch_size=2,
            num_generations=2,
            max_steps=1,
            max_completion_length=4,
            report_to="none",
            logging_strategy="no",
        ),
        train_dataset=dataset,
    )

    # Create input with pad tokens already in the prompt text (duplicated for num_generations)
    pad_token = trainer.processing_class.pad_token
    inputs = [{"prompt": f"{pad_token}{pad_token}Hello"}] * trainer.num_generations

    # Call the method directly
    with torch.no_grad():
        outputs = trainer._generate_and_score_completions(inputs)

    # Check that all pad token positions in prompt_mask are 0
    prompt_ids = outputs["prompt_ids"]
    prompt_mask = outputs["prompt_mask"]
    pad_positions = prompt_ids == trainer.processing_class.pad_token_id
    self.assertTrue(
        torch.all(prompt_mask[pad_positions] == 0), "All pad token positions should have prompt_mask=0"
    )
```
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```python
import tempfile

import torch
from datasets import load_dataset

from trl import GRPOTrainer, GRPOConfig


def reward_func(completions, **kwargs):
    return [1.0] * len(completions)


dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train[:2]")
trainer = GRPOTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    reward_funcs=reward_func,
    args=GRPOConfig(
        output_dir=tempfile.mkdtemp(),
        per_device_train_batch_size=2,
        num_generations=2,
        max_steps=1,
        max_prompt_length=8,
        report_to="none",
        logging_strategy="no",
    ),
    train_dataset=dataset,
)

# Create input that requires padding
inputs = [
    {"prompt": "Hello"},
    {"prompt": "Hello how are you?"},
]

# Call the method directly
with torch.no_grad():
    outputs = trainer._generate_and_score_completions(inputs)

# Conversational
inputs = [
    {"prompt": [{"role": "user", "content": "Hello"}]},
    {"prompt": [{"role": "user", "content": "Hello how are you?"}]},
]

# Call the method directly
with torch.no_grad():
    outputs = trainer._generate_and_score_completions(inputs)
```

printing before:

after:
Forwarding from our internal Slack: Ok, I see. Here's the difference:

So when
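The comment above is cut off, but the difference it points at is presumably the decode/re-encode round trip that only the vLLM path performs. A minimal sketch of why that round trip matters, with an illustrative tokenizer:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", padding_side="left")  # illustrative
batch = tokenizer(["Hello", "Hello how are you?"], return_tensors="pt", padding=True)

# Decoding without skipping special tokens keeps the left padding in the text...
texts = tokenizer.batch_decode(batch["input_ids"], skip_special_tokens=False)
print(texts[0])  # e.g. "<|endoftext|><|endoftext|>...Hello"

# ...so when this string is handed to vLLM as raw text, the pad tokens are re-tokenized as if
# they were part of the prompt, and there is no attention mask to cancel them out.
```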
…ns in truncated prompt text (#3698) Co-authored-by: Shirin Yamani <[email protected]> Co-authored-by: Kashif Rasul <[email protected]> Co-authored-by: Quentin Gallouédec <[email protected]>
What does this PR do?
This PR fixes #3686. As mentioned in the discussion, when pad tokens are part of the original text, the tokenizer doesn't set the attention mask corresponding to those tokens to 0, leading to erroneous logit scores and thereby erroneous generations. This PR addresses the issue by stripping out all of the leading pad tokens in the `prompts_text` list after prompt truncation happens and before the prompts are sent to vLLM.
Fixes #3686
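A minimal sketch of the failure mode described above, with an illustrative tokenizer: when the pad-token string already appears in the raw prompt, the tokenizer reports attention_mask = 1 for those positions, because from its point of view they are ordinary text.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")  # illustrative checkpoint
pad = tokenizer.pad_token or tokenizer.eos_token  # e.g. "<|endoftext|>"

enc = tokenizer(f"{pad}{pad}Hello")
print(enc["input_ids"])       # the two pad tokens are encoded like any other tokens
print(enc["attention_mask"])  # all ones: nothing marks those positions as padding
```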
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.