Ensure Chat Template Safe Prompt Truncation #3646
Conversation
Do you plan to split the …
Cool, I'll have a new PR with just …
@LeonEricsson this one is ready for review now. I've mentioned some of the tradeoffs this solution entails in the description of the PR.
I tested running with the following conversational dataset:
from datasets import Dataset, load_dataset

SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

def extract_hash_answer(text: str) -> str | None:
    # Assumed helper: pulls the final answer after the "####" marker in GSM8K solutions.
    return text.split("####")[1].strip() if "####" in text else None

def get_gsm8k_questions(split="train") -> Dataset:
    data = load_dataset("openai/gsm8k", "main")[split]
    data = data.map(
        lambda x: {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": x["question"]},
            ],
            "answer": extract_hash_answer(x["answer"]),
        }
    )
    return data
and the `prompts_text` returned by `_get_prompt_inputs` had the wrong system prompt: they contained the tokenizer's default system prompt instead.
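For anyone reproducing this, here is a minimal standalone sketch of the mechanism (the Qwen model name is just an example, not necessarily what was tested): when the system message is lost before the template is re-applied, some chat templates substitute their own default system prompt.

```python
from transformers import AutoTokenizer

# Example only: Qwen-style chat templates inject a default system prompt
# when the conversation contains no system message of its own.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

messages = [{"role": "user", "content": "What is 7 * 6?"}]
rendered = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(rendered)  # starts with the template's default system message, not SYSTEM_PROMPT
```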
trl/trainer/grpo_config.py (outdated)
    "prompt is conversational, we only truncate the message tokens starting from the top of the conversation
    "and do not account for any tokens introduced
"prompt is conversational, we only truncate the message tokens starting from the top of the conversation | |
"and do not account for any tokens introduced | |
"prompt is conversational, we only truncate the message tokens starting from the top of the conversation " | |
"and do not account for any tokens introduced " |
    prompt_inputs = super()._prepare_inputs(prompt_inputs)
    prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

    if self.max_prompt_length is not None and prompt_mask.sum(-1).max() > self.max_prompt_length:
Could you motivate the need for `prompt_mask.sum(-1).max() > self.max_prompt_length`? Is it to avoid an unnecessary decode if we don't need to truncate?
Yes exactly, gets rid of any redundant ops.
Got it. Given we've already padded, won't `prompt_mask.sum(-1).max()` always equal `prompt_ids.shape[-1]` or `prompt_mask.shape[-1]` (we're doing 'longest' padding, i.e. pad to the longest sequence in the batch)?
Yeah that's fair. I'll change that to just use the shape.
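A quick standalone illustration of the point (plain PyTorch, not the trainer's code): with 'longest' padding the longest row contains no padding, so the maximum mask sum always equals the padded width and the check can use the tensor shape directly.

```python
import torch

# "longest" padding: the batch is padded to the longest prompt, which itself
# contains no padding tokens.
prompt_mask = torch.tensor([
    [1, 1, 1, 1, 1],  # longest prompt fills the row
    [0, 0, 1, 1, 1],  # shorter prompt, left-padded
])
assert prompt_mask.sum(-1).max().item() == prompt_mask.shape[-1]
# So comparing prompt_ids.shape[-1] against max_prompt_length is an equivalent, cheaper check.
```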
trl/trainer/grpo_trainer.py (outdated)
    prompt_inputs = self.processing_class.apply_chat_template(
        truncated_messages, return_dict=True, add_generation_prompt=True
    )
    prompt_inputs = super()._prepare_inputs(prompt_inputs)
    prompt_ids = prompt_inputs["input_ids"]
    prompt_mask = prompt_inputs["attention_mask"]
    prompt_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=False)
this should be outside/after the for-loop, right?
Yep, good catch.
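A rough sketch of the fix under discussion (the variable names and the example tokenizer are placeholders, not the PR's actual code): re-apply the chat template per conversation inside the loop, but run the batch decode only once afterwards.

```python
from transformers import AutoTokenizer

processing_class = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # example tokenizer
truncated_conversations = [  # placeholder data standing in for the truncated chats
    [{"role": "user", "content": "What is 7 * 6?"}],
    [{"role": "user", "content": "Name a prime number."}],
]

# Re-apply the chat template for each conversation inside the loop...
truncated_ids = [
    processing_class.apply_chat_template(msgs, add_generation_prompt=True)
    for msgs in truncated_conversations
]
# ...but decode the whole batch once, after the loop, instead of per iteration.
prompts_text = processing_class.batch_decode(truncated_ids, skip_special_tokens=False)
print(prompts_text)
```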
What was the … So I'm assuming that the … But I suppose that would make this approach inconsistent with how the non-chat prompt truncation is handled. Let me know what you think the best approach is.
Yes, good catch, I believe this was indeed the case.
tbh, I'm not entirely sure what the correct approach is here. I feel like user content should indeed be truncated first. However, I agree that we're moving into territory where handling conversational versus non-conversational formats significantly diverges.
Do we want to rope in someone else for a third opinion? I don't think there's a right approach/answer here; prompt truncation is always going to be less than ideal in terms of how it affects training, and one can make a case for either approach.
@qgallouedec can you please give a third opinion on this PR?
Sorry for the late reply, I'm slowly catching up after my vacation.
And are there any results showing that this approach gives significantly better results? I've never seen a paper or lib using this approach; do you know of any?
That's a fair question, it's perhaps necessary to quantify the impact of this. I can create a Colab notebook to test the results of a model trained on a dataset where 20% of the prompts are truncated. We can then compare the results on a test set when preserving and not preserving the chat template. Are there any datasets you'd recommend for this, and what model should I be training? Is Qwen3-0.6B good enough, or do we want to train a larger model?
We're left truncating, so cutting the system prompt first and foremost. We should see this impact training, given the importance of the system prompt in common GRPO applications.
How about open-r1/DAPO-Math-17k-Processed with Qwen3 1.7B?
What does this PR do?
The GRPOTrainer supports truncating the prompt based on the `max_prompt_length` configuration. Currently, prompt truncation happens via https://github.com/pramodith/trl/blob/79ec242aefedc108de9edbea62be6d95070fde03/trl/trainer/grpo_trainer.py#L1083-L1087
However, if the prompt is in ChatML format, a naive truncation will break the template, leading to poor training results.
This PR addresses this problem by doing the following:
- Truncating at the level of the individual message `content`s rather than the token sequence produced by `apply_chat_template`.
- Dropping `content` tokens starting from the top of the conversation until the remaining `content`s add up to `max_prompt_length`, then re-applying `apply_chat_template` to the truncated messages.
Since this approach only prunes/truncates tokens in the "content" field of a chat template message, we guarantee the preservation of the chat template.
However, the downside of this approach is that `max_prompt_length` isn't strictly adhered to, since we don't account for any tokens that `apply_chat_template` introduces, such as `<|im_start|>role`, `\n`, and other additions like a default system prompt. I think this is a fair trade-off, though. I did initially try another approach that also accounted for the tokens introduced by the chat template, but there were far too many edge cases to consider.
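To make the approach concrete, here is a minimal self-contained sketch of the idea, not the PR's actual implementation: the helper name, the example model, and the exact bookkeeping are placeholders. It tokenizes each message's `content`, drops tokens from the top of the conversation until the contents fit within `max_prompt_length`, and then re-applies `apply_chat_template`, so template tokens are never cut (and, as noted above, also never counted).

```python
from transformers import AutoTokenizer

def truncate_messages(messages, tokenizer, max_prompt_length):
    """Hypothetical helper: left-truncate message contents so their tokens fit the budget."""
    content_ids = [tokenizer(m["content"], add_special_tokens=False)["input_ids"] for m in messages]
    excess = sum(len(ids) for ids in content_ids) - max_prompt_length
    truncated = []
    for msg, ids in zip(messages, content_ids):
        if excess > 0:
            drop = min(excess, len(ids))
            excess -= drop
            ids = ids[drop:]  # remove tokens from the start, i.e. from the top of the conversation
        truncated.append({"role": msg["role"], "content": tokenizer.decode(ids)})
    return truncated

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # example model only
messages = [
    {"role": "system", "content": "Respond in <reasoning>/<answer> format."},
    {"role": "user", "content": "Natalia sold clips to 48 of her friends. How many clips did she sell?"},
]
truncated = truncate_messages(messages, tokenizer, max_prompt_length=16)
# Re-applying the template keeps the <|im_start|>/<|im_end|> markers intact, but the
# template's own tokens are not counted toward max_prompt_length (the trade-off above).
print(tokenizer.apply_chat_template(truncated, tokenize=False, add_generation_prompt=True))
```

Note that decoding and re-encoding the kept tokens through the template can shift token boundaries slightly, which is another reason the limit ends up approximate rather than exact.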
Who can review?
@LeonEricsson @qgallouedec
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.