Ensure Chat Template Safe Prompt Truncation #3646

201 changes: 201 additions & 0 deletions tests/test_grpo_trainer.py
@@ -1272,6 +1272,41 @@ def test_training_with_generation_kwargs(self):
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

def test_training_with_prompt_truncation(self):
"""
Test that training works with prompt truncation.
This is a regression test for a bug where the trainer would not handle prompt truncation correctly.
"""
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_prompt_length=128, # reduce the prompt length to test truncation
max_completion_length=8, # reduce the completion length to reduce memory usage
report_to="none",
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset,
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

def test_training_with_reward_func_accessing_trainer_state(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

@@ -1297,3 +1332,169 @@ def reward_func(completions, **kwargs):
train_dataset=dataset,
)
trainer.train()

class TruncatePromptTester(unittest.TestCase):
def setUp(self):
self.dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
self.reward_model_id = "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5"

def create_trainer(self, max_steps=None, tmp_dir=None, **kwargs):
args_kwargs = dict(
output_dir=tmp_dir,
learning_rate=0.1,
per_device_train_batch_size=8,
report_to="none",
bf16=False,
**kwargs,
)
if max_steps is not None:
args_kwargs["max_steps"] = max_steps
training_args = GRPOConfig(**args_kwargs)
return GRPOTrainer(
model=self.model_id,
reward_funcs=self.reward_model_id,
args=training_args,
train_dataset=self.dataset,
)

def test_truncate_prompt_standard(self):
with tempfile.TemporaryDirectory() as tmp_dir:
trainer = self.create_trainer(tmp_dir=tmp_dir, max_prompt_length=4)
# Simulate non-conversational: prompt_ids shape (2, 6)
prompts = [{"prompt": "a b c d e f"}, {"prompt": "g h i j k l"}]
prompt_inputs = trainer.processing_class([x["prompt"] for x in prompts], return_tensors="pt")
prompt_ids = prompt_inputs["input_ids"]
# is_dataset_conversational returns False
truncated_ids, truncated_mask, truncated_text = trainer._get_prompt_inputs(prompts)
self.assertEqual(truncated_ids.shape, (2, 4))
self.assertEqual(truncated_mask.shape, (2, 4))
# Should keep only last 4 tokens
self.assertTrue(torch.equal(truncated_ids, torch.stack([prompt_ids[0][-4:], prompt_ids[1][-4:]])))
# Decoded text should match the truncated ids
self.assertEqual(truncated_text, [" c d e f", " i j k l"])

def test_truncate_prompt_no_truncation(self):
with tempfile.TemporaryDirectory() as tmp_dir:
trainer = self.create_trainer(tmp_dir=tmp_dir, max_prompt_length=10)
prompts = [{"prompt": "a b c"}, {"prompt": "d e f"}]
prompt_inputs = trainer.processing_class([x["prompt"] for x in prompts], return_tensors="pt")
prompt_ids = prompt_inputs["input_ids"]
prompt_mask = prompt_inputs["attention_mask"]
truncated_ids, truncated_mask, truncated_text = trainer._get_prompt_inputs(prompts)
# Should be unchanged
self.assertTrue(torch.equal(truncated_ids, prompt_ids))
self.assertTrue(torch.equal(truncated_mask, prompt_mask))
self.assertEqual(truncated_text, [x["prompt"] for x in prompts])

def test_truncate_prompt_conversational(self):
with tempfile.TemporaryDirectory() as tmp_dir:
trainer = self.create_trainer(tmp_dir=tmp_dir, max_prompt_length=15)
prompts = [
# The full conversation has 21 content tokens, plus 6 <|im_start|>, 6 role,
# 6 <|im_end|>, and 12 "\n" tokens introduced by the chat template.
{
"prompt": [
{"role": "system", "content": "one"},
{"role": "user", "content": "two two"},
{"role": "assistant", "content": "three three three"},
{"role": "user", "content": "four four four four"},
{"role": "assistant", "content": "five five five five five"},
{"role": "user", "content": "six six six six six six"},
]
}
]
truncated_ids, truncated_mask, truncated_text = trainer._get_prompt_inputs(prompts)

# With max_prompt_length=15, only the last three messages (4 + 5 + 6 = 15 content tokens) fit; chat-template tokens are preserved and not counted.
expected_prompt_message = [
[
{"role": "user", "content": "four four four four"},
{"role": "assistant", "content": "five five five five five"},
{"role": "user", "content": "six six six six six six"},
]
]
expected_truncated_text = trainer.processing_class.apply_chat_template(
expected_prompt_message, tokenize=False, add_generation_prompt=True
)
assert truncated_text == expected_truncated_text

# A max length of 13 cuts off two "four" tokens from
# {"role": "user", "content": "four four four four"}
trainer.max_prompt_length = 13
_, _, truncated_text = trainer._get_prompt_inputs(prompts)
expected_prompt_message = [
[
{"role": "user", "content": " four four"},
{"role": "assistant", "content": "five five five five five"},
{"role": "user", "content": "six six six six six six"},
]
]
expected_truncated_text = trainer.processing_class.apply_chat_template(
expected_prompt_message, tokenize=False, add_generation_prompt=True
)
assert truncated_text == expected_truncated_text

# A max length of 8 drops the "four" message entirely and cuts off three "fives" from
# {"role": "assistant", "content": "five five five five five"}
trainer.max_prompt_length = 8
_, _, truncated_text = trainer._get_prompt_inputs(prompts)
expected_prompt_message = [
[
{"role": "assistant", "content": " five five"},
{"role": "user", "content": "six six six six six six"},
]
]
expected_truncated_text = trainer.processing_class.apply_chat_template(
expected_prompt_message, tokenize=False, add_generation_prompt=True
)
assert truncated_text == expected_truncated_text

# A max length of 1 cuts off everything except the last "six"
# of the final turn
trainer.max_prompt_length = 1
_, _, truncated_text = trainer._get_prompt_inputs(prompts)
expected_prompt_message = [[{"role": "user", "content": " six"}]]
expected_truncated_text = trainer.processing_class.apply_chat_template(
expected_prompt_message, tokenize=False, add_generation_prompt=True
)
assert truncated_text == expected_truncated_text

# A test case for batch size > 1
prompts = [
{
"prompt": [
{"role": "system", "content": "one"},
{"role": "user", "content": "two two"},
{"role": "assistant", "content": "three three three"},
{"role": "user", "content": "four four four four"},
{"role": "assistant", "content": "five five five five five"},
{"role": "user", "content": "six six six six six six"},
]
},
{
"prompt": [
{"role": "system", "content": "one"},
{"role": "user", "content": "two two"},
{"role": "assistant", "content": "three three three"},
{"role": "user", "content": "four four four four"},
]
},
]
trainer.max_prompt_length = 9
_, _, truncated_text = trainer._get_prompt_inputs(prompts)
expected_prompt_message = [
[
{"role": "assistant", "content": " five five five"},
{"role": "user", "content": "six six six six six six"},
],
[
{"role": "user", "content": "two two"},
{"role": "assistant", "content": "three three three"},
{"role": "user", "content": "four four four four"},
],
]
expected_truncated_text = trainer.processing_class.apply_chat_template(
expected_prompt_message, tokenize=False, add_generation_prompt=True
)
assert truncated_text == expected_truncated_text
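
For reference, the behaviour these conversational tests pin down can be sketched outside the trainer. The snippet below is a minimal illustration, not the PR's `_get_prompt_inputs` implementation: `truncate_conversation` is a hypothetical helper that budgets only message-content tokens, drops tokens starting from the oldest message, left-truncates a partially kept message, and re-applies the chat template so its special tokens are never cut.

```python
from transformers import AutoTokenizer

def truncate_conversation(messages, tokenizer, max_prompt_length):
    """Hypothetical helper mirroring the truncation behaviour the tests above expect."""
    budget = max_prompt_length
    kept = []
    # Walk from the newest message to the oldest, spending the token budget
    # on message contents only (chat-template tokens are not counted).
    for message in reversed(messages):
        if budget <= 0:
            break
        content_ids = tokenizer(message["content"], add_special_tokens=False)["input_ids"]
        if len(content_ids) <= budget:
            kept.append(message)  # the whole message fits
            budget -= len(content_ids)
        else:
            # Partially keep the message: drop tokens from its beginning (left truncation).
            partial = tokenizer.decode(content_ids[-budget:])
            kept.append({"role": message["role"], "content": partial})
            budget = 0
    kept.reverse()
    # Re-apply the chat template so the <|im_start|>/<|im_end|> markup stays intact.
    return tokenizer.apply_chat_template(kept, tokenize=False, add_generation_prompt=True)

tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
conversation = [
    {"role": "system", "content": "one"},
    {"role": "user", "content": "two two"},
    {"role": "assistant", "content": "three three three"},
    {"role": "user", "content": "four four four four"},
    {"role": "assistant", "content": "five five five five five"},
    {"role": "user", "content": "six six six six six six"},
]
# With a budget of 13, the oldest surviving message is left-truncated to " four four".
print(truncate_conversation(conversation, tokenizer, max_prompt_length=13))
```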
8 changes: 6 additions & 2 deletions trl/trainer/grpo_config.py
@@ -49,7 +49,9 @@ class GRPOConfig(TrainingArguments):
Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that
requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`.
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.
Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left. If the
prompt is conversational, we only truncate the message tokens starting from the top of the conversation
and do not account for any tokens introduced by the chat template.
num_generations (`int` or `None`, *optional*, defaults to `8`):
Number of generations per prompt to sample. The effective batch size (num_processes * per_device_batch_size
* gradient_accumulation_steps) must be evenly divisible by this value.
@@ -260,7 +262,9 @@ class GRPOConfig(TrainingArguments):
max_prompt_length: Optional[int] = field(
default=512,
metadata={
"help": "Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left."
"help": "Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left. If the "
"prompt is conversational, we only truncate the message tokens starting from the top of the conversation "
"and do not account for any tokens introduced by the chat template."
},
)
num_generations: Optional[int] = field(
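
On the user-facing side, the documented option is set through `GRPOConfig`, as in the tests. A minimal usage sketch (dataset, model, and reward-model ids reused from the test suite; the output directory is a placeholder, and a conversational dataset would additionally exercise the message-level truncation path):

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

training_args = GRPOConfig(
    output_dir="grpo-prompt-truncation",  # hypothetical output directory
    max_prompt_length=128,   # prompts longer than this are truncated left;
                             # conversational prompts drop the oldest message tokens first
    max_completion_length=8,
    num_generations=3,
    per_device_train_batch_size=3,
    report_to="none",
)

trainer = GRPOTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```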