Skip to content

Commit 6ffde23

Browse files
kashif and qgallouedec authored
💭 [Data] Fix DeepSeek-R1 case (#3522)
Co-authored-by: Quentin Gallouédec <[email protected]>
1 parent 6f288c2 commit 6ffde23

File tree

2 files changed

+16
-20
lines changed

2 files changed

+16
-20
lines changed

tests/test_data_utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -144,12 +144,12 @@ class ApplyChatTemplateTester(unittest.TestCase):
144144
]
145145

146146
non_conversational_examples = [
147-
{"prompt": "The sky is", "completion": " blue."},
148-
{"text": "The sky is blue."},
149-
{"prompt": "The sky is"},
150-
{"prompt": "The sky is", "chosen": " blue.", "rejected": " green."},
151-
{"chosen": "The sky is blue.", "rejected": "The sky is green."},
152-
{"prompt": "The sky is", "completion": " blue.", "label": True},
147+
{"text": "The sky is blue."}, # Language modeling
148+
{"prompt": "The sky is"}, # Prompt only
149+
{"prompt": "The sky is", "completion": " blue."}, # Prompt-completion
150+
{"prompt": "The sky is", "chosen": " blue.", "rejected": " green."}, # Preference
151+
{"chosen": "The sky is blue.", "rejected": "The sky is green."}, # Preference with implicit prompt
152+
{"prompt": "The sky is", "completion": " blue.", "label": True}, # Unpaired preference
153153
]
154154

155155
@parameterized.expand(itertools.product(tokenizers, conversational_examples))

trl/data_utils.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import warnings
1616
from collections import defaultdict
1717
from collections.abc import Sequence
18+
from itertools import takewhile
1819
from typing import Any, Callable, Optional, TypeVar, Union
1920

2021
import numpy as np
@@ -121,37 +122,32 @@ def apply_chat_template(
121122
prompt_chosen = tokenizer.apply_chat_template(
122123
example["prompt"] + example["chosen"], tools=tools, tokenize=False
123124
)
125+
# DeepSeek-R1 inserts a <think> token when using `add_generation_prompt`, which can cause discrepancies
126+
# between the prompt alone and the combined prompt+completion. To ensure consistency, we extract the
127+
# common prefix between the two. In most cases, this is a no-op.
128+
prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_chosen)))
129+
124130
chosen = prompt_chosen[len(prompt) :]
125131
if "rejected" in example and "prompt" in example: # explicit prompt
126132
prompt_rejected = tokenizer.apply_chat_template(
127133
example["prompt"] + example["rejected"], tools=tools, tokenize=False
128134
)
135+
# Handle DeepSeek-R1 <think> token, see the above comment for details
136+
prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_rejected)))
129137
rejected = prompt_rejected[len(prompt) :]
130138
if "completion" in example:
131139
prompt_completion = tokenizer.apply_chat_template(
132140
example["prompt"] + example["completion"], tools=tools, tokenize=False
133141
)
142+
# Handle DeepSeek-R1 <think> token, see the above comment for details
143+
prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_completion)))
134144
completion = prompt_completion[len(prompt) :]
135145
else: # implicit prompt case
136146
if "chosen" in example:
137147
chosen = tokenizer.apply_chat_template(example["chosen"], tools=tools, tokenize=False)
138148
if "rejected" in example:
139149
rejected = tokenizer.apply_chat_template(example["rejected"], tools=tools, tokenize=False)
140150

141-
# Ensure that the prompt is the initial part of the prompt-completion string
142-
if "prompt" in example:
143-
error_message = (
144-
"The chat template applied to the prompt + completion does not start with the chat template applied to "
145-
"the prompt alone. This can indicate that the chat template is not supported by TRL."
146-
"\n**Prompt**:\n{}\n\n**Prompt + Completion**:\n{}"
147-
)
148-
if "chosen" in example and not prompt_chosen.startswith(prompt):
149-
raise ValueError(error_message.format(prompt, prompt_chosen))
150-
if "rejected" in example and not prompt_rejected.startswith(prompt):
151-
raise ValueError(error_message.format(prompt, prompt_rejected))
152-
if "completion" in example and not prompt_completion.startswith(prompt):
153-
raise ValueError(error_message.format(prompt, prompt_completion))
154-
155151
# Extract the completion by removing the prompt part from the prompt-completion string
156152
output = {}
157153
if "messages" in example:

0 commit comments

Comments (0)