Ensure Chat Template Safe Prompt Truncation #3646

Closed
202 changes: 202 additions & 0 deletions tests/test_grpo_trainer.py
@@ -1271,3 +1271,205 @@ def test_training_with_generation_kwargs(self):
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

def test_training_with_prompt_truncation(self):
"""
Test that training works with prompt truncation.
This is a regression test for a bug where the trainer would not handle prompt truncation correctly.
"""
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_prompt_length=128, # reduce the prompt length to test truncation
max_completion_length=8, # reduce the completion length to reduce memory usage
report_to="none",
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset,
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")


class TruncatePromptTester(unittest.TestCase):
def setUp(self):
self.dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
self.reward_model_id = "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5"

def create_trainer(self, max_steps=None, tmp_dir=None, **kwargs):
args_kwargs = dict(
output_dir=tmp_dir,
learning_rate=0.1,
per_device_train_batch_size=8,
report_to="none",
bf16=False,
**kwargs,
)
if max_steps is not None:
args_kwargs["max_steps"] = max_steps
training_args = GRPOConfig(**args_kwargs)
return GRPOTrainer(
model=self.model_id,
reward_funcs=self.reward_model_id,
args=training_args,
train_dataset=self.dataset,
)

def test_truncate_prompt_standard(self):
with tempfile.TemporaryDirectory() as tmp_dir:
trainer = self.create_trainer(tmp_dir=tmp_dir, max_prompt_length=4)
# Simulate non-conversational: prompt_ids shape (2, 6)
prompts = [{"prompt": "a b c d e f"}, {"prompt": "g h i j k l"}]
prompt_inputs = trainer.processing_class([x["prompt"] for x in prompts], return_tensors="pt")
prompt_ids = prompt_inputs["input_ids"]
# is_conversational returns False for these prompts
truncated_ids, truncated_mask, truncated_text = trainer._get_prompt_inputs(prompts)
self.assertEqual(truncated_ids.shape, (2, 4))
self.assertEqual(truncated_mask.shape, (2, 4))
# Should keep only last 4 tokens
self.assertTrue(torch.equal(truncated_ids, torch.stack([prompt_ids[0][-4:], prompt_ids[1][-4:]])))
# Decoded text should match the truncated ids
self.assertEqual(truncated_text, [" c d e f", " i j k l"])

def test_truncate_prompt_no_truncation(self):
with tempfile.TemporaryDirectory() as tmp_dir:
trainer = self.create_trainer(tmp_dir=tmp_dir, max_prompt_length=10)
prompts = [{"prompt": "a b c"}, {"prompt": "d e f"}]
prompt_inputs = trainer.processing_class([x["prompt"] for x in prompts], return_tensors="pt")
prompt_ids = prompt_inputs["input_ids"]
prompt_mask = prompt_inputs["attention_mask"]
truncated_ids, truncated_mask, truncated_text = trainer._get_prompt_inputs(prompts)
# Should be unchanged
self.assertTrue(torch.equal(truncated_ids, prompt_ids))
self.assertTrue(torch.equal(truncated_mask, prompt_mask))
self.assertEqual(truncated_text, [x["prompt"] for x in prompts])

def test_truncate_prompt_conversational(self):
with tempfile.TemporaryDirectory() as tmp_dir:
trainer = self.create_trainer(tmp_dir=tmp_dir, max_prompt_length=15)
prompts = [
# Will have 21 non-special tokens + 6 <|im_start|> +
# 6 role + 6 <|im_end|> + 12\n tokens (introduced by chat template)
{
"prompt": [
{"role": "system", "content": "one"},
{"role": "user", "content": "two two"},
{"role": "assistant", "content": "three three three"},
{"role": "user", "content": "four four four four"},
{"role": "assistant", "content": "five five five five five"},
{"role": "user", "content": "six six six six six six"},
]
}
]
truncated_ids, truncated_mask, truncated_text = trainer._get_prompt_inputs(prompts)

# last three turns + final assistant start tags add up to 21 tokens with chat template preserved.
expected_prompt_message = [
[
{"role": "user", "content": "four four four four"},
{"role": "assistant", "content": "five five five five five"},
{"role": "user", "content": "six six six six six six"},
]
]
expected_truncated_text = trainer.processing_class.apply_chat_template(
expected_prompt_message, tokenize=False, add_generation_prompt=True
)
assert truncated_text == expected_truncated_text

# A max length of 13 will cut off 2 "fours"
# from {"role": "user", "content": "four four four four"},
trainer.max_prompt_length = 13
_, _, truncated_text = trainer._get_prompt_inputs(prompts)
expected_prompt_message = [
[
{"role": "user", "content": " four four"},
{"role": "assistant", "content": "five five five five five"},
{"role": "user", "content": "six six six six six six"},
]
]
expected_truncated_text = trainer.processing_class.apply_chat_template(
expected_prompt_message, tokenize=False, add_generation_prompt=True
)
assert truncated_text == expected_truncated_text

# A max length of 8 will cut off 3 "fives"
# from {"role": "assistant", "content": "five five five five five"},
trainer.max_prompt_length = 8
_, _, truncated_text = trainer._get_prompt_inputs(prompts)
expected_prompt_message = [
[
{"role": "assistant", "content": " five five"},
{"role": "user", "content": "six six six six six six"},
]
]
expected_truncated_text = trainer.processing_class.apply_chat_template(
expected_prompt_message, tokenize=False, add_generation_prompt=True
)
assert truncated_text == expected_truncated_text

# A max length of 1 will cut off all the messages except the last "six"
# in the last turn
trainer.max_prompt_length = 1
_, _, truncated_text = trainer._get_prompt_inputs(prompts)
expected_prompt_message = [[{"role": "user", "content": " six"}]]
expected_truncated_text = trainer.processing_class.apply_chat_template(
expected_prompt_message, tokenize=False, add_generation_prompt=True
)
assert truncated_text == expected_truncated_text

# A test case for batch size > 1
prompts = [
{
"prompt": [
{"role": "system", "content": "one"},
{"role": "user", "content": "two two"},
{"role": "assistant", "content": "three three three"},
{"role": "user", "content": "four four four four"},
{"role": "assistant", "content": "five five five five five"},
{"role": "user", "content": "six six six six six six"},
]
},
{
"prompt": [
{"role": "system", "content": "one"},
{"role": "user", "content": "two two"},
{"role": "assistant", "content": "three three three"},
{"role": "user", "content": "four four four four"},
]
},
]
trainer.max_prompt_length = 9
_, _, truncated_text = trainer._get_prompt_inputs(prompts)
expected_prompt_message = [
[
{"role": "assistant", "content": " five five five"},
{"role": "user", "content": "six six six six six six"},
],
[
{"role": "user", "content": "two two"},
{"role": "assistant", "content": "three three three"},
{"role": "user", "content": "four four four four"},
],
]
expected_truncated_text = trainer.processing_class.apply_chat_template(
expected_prompt_message, tokenize=False, add_generation_prompt=True
)
assert truncated_text == expected_truncated_text
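
For reference, a quick way to run just the truncation tests added in this file (a sketch assuming the repository's standard pytest setup):

# Select the new truncation tests by name substring; assumes pytest is installed.
import pytest

pytest.main(["-q", "tests/test_grpo_trainer.py", "-k", "truncat"])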
9 changes: 7 additions & 2 deletions trl/trainer/grpo_config.py
@@ -49,7 +49,9 @@ class GRPOConfig(TrainingArguments):
Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that
requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`.
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.
Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left. If the
prompt is conversational, we only truncate the message tokens starting from the top of the conversation
and do not account for any tokens introduced by the chat template.
num_generations (`int` or `None`, *optional*, defaults to `8`):
Number of generations per prompt to sample. The effective batch size (num_processes * per_device_batch_size
* gradient_accumulation_steps) must be evenly divisible by this value.
@@ -260,7 +262,10 @@ class GRPOConfig(TrainingArguments):
max_prompt_length: Optional[int] = field(
default=512,
metadata={
"help": "Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left."
"help": "Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left. If the "
"prompt is conversational, we only truncate the message tokens starting from the top of the conversation"
"and do not account for any tokens introduced"

Collaborator

Suggested change
- "prompt is conversational, we only truncate the message tokens starting from the top of the conversation"
- "and do not account for any tokens introduced"
+ "prompt is conversational, we only truncate the message tokens starting from the top of the conversation "
+ "and do not account for any tokens introduced "

"by the chat template."
},
)
num_generations: Optional[int] = field(
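To make the documented behaviour concrete: a minimal, self-contained sketch of the conversational budget described in max_prompt_length above. It assumes that only message-content tokens count toward the budget (chat-template tokens do not), and the helper name truncate_conversation is made up for this illustration; it is not part of the PR.

from transformers import AutoTokenizer

def truncate_conversation(messages, tokenizer, max_prompt_length):
    # Token count of each message's content, with no special tokens added.
    counts = [len(tokenizer(m["content"], add_special_tokens=False)["input_ids"]) for m in messages]
    to_drop = sum(counts) - max_prompt_length  # content tokens in excess of the budget
    truncated = []
    for msg, count in zip(messages, counts):
        if to_drop <= 0:
            truncated.append(msg)  # budget already satisfied, keep the message as-is
        elif count <= to_drop:
            to_drop -= count       # drop this whole message from the top of the conversation
        else:
            ids = tokenizer(msg["content"], add_special_tokens=False)["input_ids"]
            truncated.append({**msg, "content": tokenizer.decode(ids[to_drop:])})
            to_drop = 0            # keep only the tail of this message
    return truncated

# Usage sketch; the checkpoint is the tiny model used in the tests above, any chat model would do.
tok = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
msgs = [
    {"role": "user", "content": "four four four four"},
    {"role": "user", "content": "six six six six six six"},
]
print(truncate_conversation(msgs, tok, max_prompt_length=7))  # with one token per word, keeps " four" plus the whole last message

Truncating from the top mirrors the left-truncation applied to plain-text prompts while keeping every surviving message a valid turn for the chat template.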
77 changes: 63 additions & 14 deletions trl/trainer/grpo_trainer.py
@@ -46,7 +46,7 @@
from transformers.trainer_utils import seed_worker
from transformers.utils import is_datasets_available, is_peft_available, is_rich_available

from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from ..data_utils import apply_chat_template, is_conversational
from ..extras.profiling import profiling_context, profiling_decorator
from ..extras.vllm_client import VLLMClient
from ..import_utils import is_liger_kernel_available, is_vllm_available
@@ -1066,26 +1066,75 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list):
rewards_per_func = gather(rewards_per_func)
return rewards_per_func

def _get_prompt_inputs(self, prompts: Union[list[str], list[list[dict[str, str]]]]) -> tuple:
# Check whether the prompt is conversational and truncate the input prompt accordingly.
# If it is conversational, the truncation preserves the chat template.
if not is_conversational(prompts[0]):
prompt_text = [x["prompt"] for x in prompts]
prompt_inputs = self.processing_class(
text=prompt_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
)
prompt_inputs = super()._prepare_inputs(prompt_inputs)
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

if self.max_prompt_length is not None and prompt_mask.sum(-1).max() > self.max_prompt_length:
Collaborator:
could you motivate the need for prompt_mask.sum(-1).max() > self.max_prompt_length? is it to avoid unnecessary decode if we don't need to truncate?

Collaborator Author:
Yes exactly, gets rid of any redundant ops.

Collaborator (@LeonEricsson, Jul 1, 2025):
Got it. Given we've already padded, won't prompt_mask.sum(-1).max() always equal prompt_ids.shape[-1] or prompt_mask.shape[-1] (we're doing 'longest' padding - pad to the longest sequence in the batch)?

Collaborator Author:
Yeah that's fair. I'll change that to just use the shape.
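
As a standalone check of the point above (a sketch, not part of the PR): with "longest" padding, the longest row in the batch contains no padding, so the maximum per-row mask sum is exactly the padded width.

import torch

# Left-padded batch using "longest" padding: row 0 has 2 real tokens, row 1 has 4.
prompt_mask = torch.tensor([[0, 0, 1, 1],
                            [1, 1, 1, 1]])

# The longest row has no padding, so the max per-row sum equals the padded width.
assert prompt_mask.sum(-1).max().item() == prompt_mask.shape[-1]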

prompt_ids = prompt_ids[:, -self.max_prompt_length :]
prompt_mask = prompt_mask[:, -self.max_prompt_length :]
prompt_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=False)
else:
# Get the token counts of the content of each message
messages_token_counts = [
[
self.processing_class(msg["content"], add_special_tokens=False, return_tensors="pt")[
"attention_mask"
]
.sum()
.item()
for msg in prompts[i]["prompt"]
]
for i in range(len(prompts))
]
# Compute the number of tokens that the contents of all the messages in a prompt consume
prompts_token_count = [sum(prompt_token_count) for prompt_token_count in messages_token_counts]
truncated_messages = []
for i in range(len(prompts)):
if prompts_token_count[i] <= self.max_prompt_length:
truncated_messages.append(prompts[i])
else:
num_tokens_to_truncate = prompts_token_count[i] - self.max_prompt_length
truncated_messages.append([])
for ind, msg in enumerate(prompts[i]["prompt"]):
if num_tokens_to_truncate == 0:
truncated_messages[-1].append(msg)
else:
if messages_token_counts[i][ind] <= num_tokens_to_truncate:
num_tokens_to_truncate -= messages_token_counts[i][ind]
else:
tokens = self.processing_class(msg["content"], add_special_tokens=False)["input_ids"]
tokens = tokens[num_tokens_to_truncate:]
truncated_message = self.processing_class.decode(tokens)
msg["content"] = truncated_message
num_tokens_to_truncate = 0
truncated_messages[-1].append(msg)

prompt_inputs = self.processing_class.apply_chat_template(
truncated_messages, return_dict=True, add_generation_prompt=True
)
prompt_inputs = super()._prepare_inputs(prompt_inputs)
prompt_ids = prompt_inputs["input_ids"]
prompt_mask = prompt_inputs["attention_mask"]
prompt_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=False)
Collaborator:
this should be outside/after the for-loop, right?

Collaborator Author:
Yep, good catch.
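
A minimal sketch of what the agreed fix could look like (an assumption about the follow-up commit, reusing the illustrative truncate_conversation helper and tok tokenizer from the sketch after the grpo_config.py diff): per-prompt truncation stays in the loop, while the chat template and the decode run once afterwards.

prompts = [
    {"prompt": [{"role": "user", "content": "six six six six six six"}]},
    {"prompt": [{"role": "user", "content": "four four four four"}]},
]

truncated_messages = []
for example in prompts:
    # Per-prompt truncation happens inside the loop...
    truncated_messages.append(truncate_conversation(example["prompt"], tok, max_prompt_length=4))

# ...while the batched template application and the batched decode happen once, outside it.
prompt_inputs = tok.apply_chat_template(
    truncated_messages, return_dict=True, add_generation_prompt=True, padding=True, return_tensors="pt"
)
prompt_text = tok.batch_decode(prompt_inputs["input_ids"], skip_special_tokens=False)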


return prompt_ids, prompt_mask, prompt_text

def _generate_and_score_completions(
self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
) -> dict[str, Union[torch.Tensor, Any]]:
device = self.accelerator.device
mode = "train" if self.model.training else "eval"

prompts = [x["prompt"] for x in inputs]
prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
prompt_inputs = self.processing_class(
text=prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
)
prompt_inputs = super()._prepare_inputs(prompt_inputs)
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

if self.max_prompt_length is not None:
prompt_ids = prompt_ids[:, -self.max_prompt_length :]
prompt_mask = prompt_mask[:, -self.max_prompt_length :]
prompts_text = self.processing_class.batch_decode(
prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
)
prompt_ids, prompt_mask, prompts_text = self._get_prompt_inputs(inputs)

# Generate completions using either vLLM or regular generation
if self.use_vllm: