💭 [Data] Fix DeepSeek-R1 case #3522

Merged · 12 commits · Jun 4, 2025
12 changes: 6 additions & 6 deletions tests/test_data_utils.py
@@ -144,12 +144,12 @@ class ApplyChatTemplateTester(unittest.TestCase):
     ]
 
     non_conversational_examples = [
-        {"prompt": "The sky is", "completion": " blue."},
-        {"text": "The sky is blue."},
-        {"prompt": "The sky is"},
-        {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."},
-        {"chosen": "The sky is blue.", "rejected": "The sky is green."},
-        {"prompt": "The sky is", "completion": " blue.", "label": True},
+        {"text": "The sky is blue."},  # Language modeling
+        {"prompt": "The sky is"},  # Prompt only
+        {"prompt": "The sky is", "completion": " blue."},  # Prompt-completion
+        {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."},  # Preference
+        {"chosen": "The sky is blue.", "rejected": "The sky is green."},  # Preference with implicit prompt
+        {"prompt": "The sky is", "completion": " blue.", "label": True},  # Unpaired preference
     ]
 
     @parameterized.expand(itertools.product(tokenizers, conversational_examples))
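For reference, apply_chat_template (patched in trl/data_utils.py below) consumes the conversational counterparts of the formats labeled above, in which each field holds a list of chat messages rather than a plain string; that is why the diff below concatenates message lists such as example["prompt"] + example["chosen"]. A minimal sketch of two such examples, with illustrative message content (the values are assumptions, not copied from the test file):

# Conversational counterparts of the "Prompt-completion" and "Preference" formats
# above (illustrative values only).
prompt_completion_example = {
    "prompt": [{"role": "user", "content": "What color is the sky?"}],
    "completion": [{"role": "assistant", "content": "It is blue."}],
}
preference_example = {
    "prompt": [{"role": "user", "content": "What color is the sky?"}],
    "chosen": [{"role": "assistant", "content": "It is blue."}],
    "rejected": [{"role": "assistant", "content": "It is green."}],
}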
24 changes: 10 additions & 14 deletions trl/data_utils.py
@@ -15,6 +15,7 @@
 import warnings
 from collections import defaultdict
 from collections.abc import Sequence
+from itertools import takewhile
 from typing import Any, Callable, Optional, TypeVar, Union
 
 import numpy as np
@@ -121,37 +122,32 @@ def apply_chat_template(
             prompt_chosen = tokenizer.apply_chat_template(
                 example["prompt"] + example["chosen"], tools=tools, tokenize=False
             )
+            # DeepSeek-R1 inserts a <think> token when using `add_generation_prompt`, which can cause discrepancies
+            # between the prompt alone and the combined prompt+completion. To ensure consistency, we extract the
+            # common prefix between the two. In most cases, this is a no-op.
+            prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_chosen)))
+
             chosen = prompt_chosen[len(prompt) :]
         if "rejected" in example and "prompt" in example:  # explicit prompt
             prompt_rejected = tokenizer.apply_chat_template(
                 example["prompt"] + example["rejected"], tools=tools, tokenize=False
             )
+            # Handle DeepSeek-R1 <think> token, see the above comment for details
+            prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_rejected)))
             rejected = prompt_rejected[len(prompt) :]
         if "completion" in example:
             prompt_completion = tokenizer.apply_chat_template(
                 example["prompt"] + example["completion"], tools=tools, tokenize=False
             )
+            # Handle DeepSeek-R1 <think> token, see the above comment for details
+            prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_completion)))
             completion = prompt_completion[len(prompt) :]
     else:  # implicit prompt case
         if "chosen" in example:
             chosen = tokenizer.apply_chat_template(example["chosen"], tools=tools, tokenize=False)
         if "rejected" in example:
             rejected = tokenizer.apply_chat_template(example["rejected"], tools=tools, tokenize=False)
 
-    # Ensure that the prompt is the initial part of the prompt-completion string
-    if "prompt" in example:
-        error_message = (
-            "The chat template applied to the prompt + completion does not start with the chat template applied to "
-            "the prompt alone. This can indicate that the chat template is not supported by TRL."
-            "\n**Prompt**:\n{}\n\n**Prompt + Completion**:\n{}"
-        )
-        if "chosen" in example and not prompt_chosen.startswith(prompt):
-            raise ValueError(error_message.format(prompt, prompt_chosen))
-        if "rejected" in example and not prompt_rejected.startswith(prompt):
-            raise ValueError(error_message.format(prompt, prompt_rejected))
-        if "completion" in example and not prompt_completion.startswith(prompt):
-            raise ValueError(error_message.format(prompt, prompt_completion))
-
Comment on lines -141 to -154
Member: This can't occur anymore
     # Extract the completion by removing the prompt part from the prompt-completion string
     output = {}
     if "messages" in example:
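To see why the common-prefix extraction fixes the DeepSeek-R1 case, here is a minimal, self-contained sketch using toy strings (the rendered prompts below are invented to mimic an R1-style template, not produced by the real tokenizer): with `add_generation_prompt=True` the prompt rendering ends in an opening `<think>` tag that the prompt+completion rendering does not share, so the old startswith check fails, while the character-wise common prefix still yields a consistent prompt/completion split.

from itertools import takewhile

# Toy renderings standing in for tokenizer.apply_chat_template output (assumed strings,
# not the real DeepSeek-R1 chat template). The generation prompt ends with "<think>",
# which the prompt + completion rendering does not reproduce.
prompt = "<|User|>The sky is<|Assistant|><think>"
prompt_chosen = "<|User|>The sky is<|Assistant|> blue.<|end_of_sentence|>"

# Old behaviour: this check fails because of the trailing "<think>".
assert not prompt_chosen.startswith(prompt)

# New behaviour: shrink the prompt to the common prefix of the two renderings, then slice.
prompt = "".join(x for x, _ in takewhile(lambda pair: pair[0] == pair[1], zip(prompt, prompt_chosen)))
chosen = prompt_chosen[len(prompt):]

print(repr(prompt))  # '<|User|>The sky is<|Assistant|>'
print(repr(chosen))  # ' blue.<|end_of_sentence|>'

For templates whose prompt rendering is an exact prefix of the prompt+completion rendering, the common-prefix step leaves the prompt unchanged, which is why the diff's comment calls it a no-op in most cases.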