Ensure Chat Template Safe Prompt Truncation #3646


Conversation

pramodith
Collaborator

@pramodith pramodith commented Jun 25, 2025

What does this PR do?

The GRPOTrainer supports truncating the prompt based on the max_prompt_length configuration. Currently, prompt truncation happens via:

https://github.com/pramodith/trl/blob/79ec242aefedc108de9edbea62be6d95070fde03/trl/trainer/grpo_trainer.py#L1083-L1087

However, if the prompt is in ChatML format, naive truncation will break the template, leading to poor training results.

This PR addresses this problem by doing the following:

  1. Computes the length of each turn in the conversation, excluding any special tokens or tokens added by apply_chat_template.
  2. Creates a budget of the number of tokens that need to be truncated.
  3. Iterates through each turn and determines whether the turn should be excluded entirely, truncated, or included without truncation.
  4. Retains the last k turns whose contents add up to at most max_prompt_length.

Since this approach only prunes/truncates tokens in the "content" field of a chat template message, we guarantee the preservation of the chat template.

However, the downside of this approach is that max_prompt_length isn't strictly adhered to, since we don't account for any tokens that apply_chat_template would introduce, such as <|im_start|>role, \n, or a default system prompt.

I think that this is a fair trade-off, though. I initially tried another approach that also accounted for the tokens introduced by the chat template, but there were far too many edge cases to consider.
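
To make the mechanism concrete, here is a minimal sketch of the per-turn budgeting idea. This is illustrative only: the function name truncate_conversation, the Qwen checkpoint, and the exact structure are assumptions for the example, not the code in this PR.

from transformers import AutoTokenizer

def truncate_conversation(messages, tokenizer, max_prompt_length):
    # 1. Length of each turn's content, ignoring special/template tokens.
    turn_ids = [tokenizer(m["content"], add_special_tokens=False)["input_ids"] for m in messages]

    # 2. Budget of tokens that must be removed.
    to_remove = max(0, sum(len(ids) for ids in turn_ids) - max_prompt_length)

    # 3. Walk from the oldest turn: drop it, truncate it, or keep it whole.
    truncated = []
    for msg, ids in zip(messages, turn_ids):
        if to_remove >= len(ids):  # turn is cut entirely
            to_remove -= len(ids)
            continue
        if to_remove > 0:  # turn is partially cut, keeping its tail
            msg = {**msg, "content": tokenizer.decode(ids[to_remove:])}
            to_remove = 0
        truncated.append(msg)  # 4. later turns are kept untouched
    return truncated

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
messages = [
    {"role": "system", "content": "You are a terse math assistant."},
    {"role": "user", "content": "What is 2 + 2? " * 50},
]
short = truncate_conversation(messages, tokenizer, max_prompt_length=64)
# Re-applying the template keeps the <|im_start|>/<|im_end|> markup intact, but the
# final token count can exceed 64 by the template's own overhead (roles, newlines,
# and a default system prompt if the template adds one).
text = tokenizer.apply_chat_template(short, tokenize=False, add_generation_prompt=True)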

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

@LeonEricsson @qgallouedec
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@pramodith pramodith marked this pull request as draft June 25, 2025 11:15
@LeonEricsson
Collaborator

Do you plan to split the skip_special_tokens=False fix into a separate PR? I think it makes sense to separate them – this way, we can get the fix merged as soon as possible.

@pramodith
Collaborator Author

Cool, I'll have a new PR with just skip_special_tokens=False in the next hour.

@pramodith pramodith changed the title Prompt Decoding after truncation should not skip special tokens Ensure Chat Template Safe Prompt Truncation Jun 27, 2025
@pramodith pramodith marked this pull request as ready for review June 27, 2025 11:17
@pramodith
Collaborator Author

@LeonEricsson this one is ready for review now. I've mentioned some of the tradeoffs this solution entails in the description of the PR.

Collaborator

@LeonEricsson LeonEricsson left a comment

I tested running with the following conversational dataset:

from datasets import Dataset, load_dataset

# extract_hash_answer is defined elsewhere in the script (omitted here)
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>

"""
def get_gsm8k_questions(split="train") -> Dataset:
    data = load_dataset("openai/gsm8k", "main")[split]  # type: ignore
    data = data.map(
        lambda x: {  # type: ignore
            "prompt": [{"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": x["question"]}],
            "answer": extract_hash_answer(x["answer"]),
        }
    )  # type: ignore
    return data  # type: ignore

and the prompts_text returned by _get_prompt_inputs had the wrong system prompt. They contained the tokenizer's default system prompt instead.

Comment on lines 266 to 267
"prompt is conversational, we only truncate the message tokens starting from the top of the conversation
"and do not account for any tokens introduced
Collaborator

Suggested change
"prompt is conversational, we only truncate the message tokens starting from the top of the conversation
"and do not account for any tokens introduced
"prompt is conversational, we only truncate the message tokens starting from the top of the conversation "
"and do not account for any tokens introduced "

prompt_inputs = super()._prepare_inputs(prompt_inputs)
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

if self.max_prompt_length is not None and prompt_mask.sum(-1).max() > self.max_prompt_length:
Collaborator

could you motivate the need for prompt_mask.sum(-1).max() > self.max_prompt_length? is it to avoid unnecessary decode if we don't need to truncate?

Collaborator Author

Yes exactly, gets rid of any redundant ops.

Collaborator

@LeonEricsson LeonEricsson Jul 1, 2025

Got it. Given we've already padded, won't prompt_mask.sum(-1).max() always equal prompt_ids.shape[-1] or prompt_mask.shape[-1] (we're doing 'longest' padding - pad to the longest sequence in the batch)?

Collaborator Author

Yeah that's fair. I'll change that to just use the shape.
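
For context, a quick way to convince yourself of the equivalence under "longest" padding (an illustrative snippet, not trainer code; the checkpoint is just an example):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

batch = tokenizer(
    ["a short prompt", "a much longer prompt with many more tokens in it"],
    padding="longest",
    return_tensors="pt",
)
prompt_ids, prompt_mask = batch["input_ids"], batch["attention_mask"]

# With "longest" padding the longest sequence has no pad tokens, so the largest
# row-sum of the attention mask is exactly the padded width of the batch.
assert prompt_mask.sum(-1).max() == prompt_ids.shape[-1]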

Comment on lines 1120 to 1126
prompt_inputs = self.processing_class.apply_chat_template(
truncated_messages, return_dict=True, add_generation_prompt=True
)
prompt_inputs = super()._prepare_inputs(prompt_inputs)
prompt_ids = prompt_inputs["input_ids"]
prompt_mask = prompt_inputs["attention_mask"]
prompt_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=False)
Collaborator

this should be outside/after the for-loop, right?

Collaborator Author

Yep, good catch.

@pramodith
Collaborator Author

pramodith commented Jul 1, 2025

I tested running with the following conversational dataset:

SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>

"""
def get_gsm8k_questions(split="train") -> Dataset:
    data = load_dataset("openai/gsm8k", "main")[split]  # type: ignore
    data = data.map(
        lambda x: {  # type: ignore
            "prompt": [{"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": x["question"]}],
            "answer": extract_hash_answer(x["answer"]),
        }
    )  # type: ignore
    return data  # type: ignore

and the prompts_text returned by _get_prompt_inputs had the wrong system prompt. They contained the tokenizer's default system prompt instead.

What was the max_prompt_length configured to? Some chat templates automatically introduce the default system prompt if the system message is missing in the messages for a given batch. This happens when running the apply_chat_template function.

So I'm assuming that the max_prompt_length set for this example eliminated the custom system prompt. Now that I think about it, this means that any custom system prompt will always be pruned first. It might make more sense for us to start pruning from the first user message instead of from the system prompt, and only prune the system prompt once every turn barring the last user message has been pruned.

But I suppose that would make this approach inconsistent with how the non-chat prompt truncation is handled. Let me know what you think the best approach is.
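
For reference, an easy way to see that behaviour (the Qwen checkpoint is just an example; whether a default system prompt is injected depends on each model's template):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

with_system = [
    {"role": "system", "content": "Custom system prompt."},
    {"role": "user", "content": "Hi"},
]
without_system = [{"role": "user", "content": "Hi"}]

# When the system turn is present, the custom prompt is used verbatim.
print(tokenizer.apply_chat_template(with_system, tokenize=False, add_generation_prompt=True))
# When the system turn has been pruned away, this particular template silently
# substitutes its own default system message, matching what was observed above.
print(tokenizer.apply_chat_template(without_system, tokenize=False, add_generation_prompt=True))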

@LeonEricsson
Collaborator

LeonEricsson commented Jul 2, 2025

What was the max_prompt_length configured to? Some chat templates automatically introduce the default system prompt if the system message is missing in the messages for a given batch. This happens when running the apply_chat_template function.

yes good catch, i believe this was indeed the case.

So I'm assuming that the max_prompt_length set for this example eliminated the custom system prompt. Now that I think about it, this means that any custom system prompt will always be pruned first. It might make more sense for us to start pruning from the first user message instead of from the system prompt, and only prune the system prompt once every turn barring the last user message has been pruned.

But I suppose that would make this approach inconsistent with how the non-chat prompt truncation is handled. Let me know what you think the best approach is.

tbh, I'm not entirely sure what the correct approach is here. I feel like user content should indeed be truncated first. However, I agree that we're moving into territory where handling conversational versus non-conversational formats significantly diverges.

@pramodith
Collaborator Author

tbh, I'm not entirely sure what the correct approach is here. I feel like user content should indeed be truncated first. However, I agree that we're moving into territory where handling conversational versus non-conversational formats significantly diverges.

Do we want to rope in someone else for a third opinion? I don't think there's a right approach/answer here; prompt truncation is always going to be less than ideal in terms of how it affects training, and one can make a case for either approach.

@pramodith
Collaborator Author

@qgallouedec can you please give a third opinion on this PR?

@qgallouedec
Member

Sorry for the late reply. I'm slowly catching up after my vacation.
I agree with @LeonEricsson: I understand the intuition of not breaking the chat template, but do we have solid results to validate this claim?

a naive truncation will break the template leading to poor training results.

And are there any results to show that this approach gives significantly better results? I've never seen a paper or library using this approach; do you know of any?

@pramodith
Collaborator Author

And are there any results to show that this approach gives significantly better results? I've never seen a paper or library using this approach; do you know of any?

That's a fair question; it's perhaps necessary to quantify the impact of this. I can create a Colab notebook to test the results of a model trained on a dataset where 20% of the prompts are truncated. We can then compare the results on a test set when preserving and not preserving the chat template.

Are there any datasets that you'd recommend for this, and what model should I be training? Is Qwen3-0.6B good enough, or do we want to train a larger model?

@LeonEricsson
Collaborator

We're left truncating, so cutting the system prompt first and foremost. We should see this impact training, given the importance of the system prompt in common GRPO applications.

Are there any datasets that you'd recommend for this, and what model should I be training? Is Qwen3-0.6B good enough, or do we want to train a larger model?

How about open-r1/DAPO-Math-17k-Processed with Qwen3 1.7B?

@pramodith pramodith closed this Aug 6, 2025