 import warnings
 from collections import defaultdict
 from collections.abc import Sequence
+from itertools import takewhile
 from typing import Any, Callable, Optional, TypeVar, Union
 
 import numpy as np
@@ -121,37 +122,32 @@ def apply_chat_template(
             prompt_chosen = tokenizer.apply_chat_template(
                 example["prompt"] + example["chosen"], tools=tools, tokenize=False
             )
+            # DeepSeek-R1 inserts a <think> token when using `add_generation_prompt`, which can cause discrepancies
+            # between the prompt alone and the combined prompt+completion. To ensure consistency, we extract the
+            # common prefix between the two. In most cases, this is a no-op.
+            prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_chosen)))
+
             chosen = prompt_chosen[len(prompt) :]
         if "rejected" in example and "prompt" in example:  # explicit prompt
             prompt_rejected = tokenizer.apply_chat_template(
                 example["prompt"] + example["rejected"], tools=tools, tokenize=False
             )
+            # Handle DeepSeek-R1 <think> token, see the above comment for details
+            prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_rejected)))
             rejected = prompt_rejected[len(prompt) :]
         if "completion" in example:
             prompt_completion = tokenizer.apply_chat_template(
                 example["prompt"] + example["completion"], tools=tools, tokenize=False
             )
+            # Handle DeepSeek-R1 <think> token, see the above comment for details
+            prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_completion)))
             completion = prompt_completion[len(prompt) :]
     else:  # implicit prompt case
         if "chosen" in example:
             chosen = tokenizer.apply_chat_template(example["chosen"], tools=tools, tokenize=False)
         if "rejected" in example:
             rejected = tokenizer.apply_chat_template(example["rejected"], tools=tools, tokenize=False)
 
-    # Ensure that the prompt is the initial part of the prompt-completion string
-    if "prompt" in example:
-        error_message = (
-            "The chat template applied to the prompt + completion does not start with the chat template applied to "
-            "the prompt alone. This can indicate that the chat template is not supported by TRL."
-            "\n**Prompt**:\n{}\n\n**Prompt + Completion**:\n{}"
-        )
-        if "chosen" in example and not prompt_chosen.startswith(prompt):
-            raise ValueError(error_message.format(prompt, prompt_chosen))
-        if "rejected" in example and not prompt_rejected.startswith(prompt):
-            raise ValueError(error_message.format(prompt, prompt_rejected))
-        if "completion" in example and not prompt_completion.startswith(prompt):
-            raise ValueError(error_message.format(prompt, prompt_completion))
-
     # Extract the completion by removing the prompt part from the prompt-completion string
     output = {}
     if "messages" in example:
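For readers unfamiliar with the `takewhile` idiom introduced above, the following standalone sketch shows how the character-level common prefix is extracted. The two template strings are made-up stand-ins for what `tokenizer.apply_chat_template` might return with a DeepSeek-R1-style template (the trailing `<think>` on the standalone prompt is the assumed discrepancy); they are not actual outputs of any tokenizer.

```python
from itertools import takewhile

# Hypothetical chat-template outputs (illustrative only, not real tokenizer output):
# with a DeepSeek-R1-style template, the standalone prompt ends with an extra "<think>\n"
# that the combined prompt+completion string does not contain.
prompt = "<|User|>What is 2+2?<|Assistant|><think>\n"
prompt_chosen = "<|User|>What is 2+2?<|Assistant|>The answer is 4."

# Keep only the character-level common prefix: zip pairs the two strings position by
# position, and takewhile stops at the first pair of characters that differ.
prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_chosen)))

print(repr(prompt))                       # '<|User|>What is 2+2?<|Assistant|>'
print(repr(prompt_chosen[len(prompt):]))  # 'The answer is 4.'
```

Because the prompt+completion string normally extends the prompt verbatim, the only mismatch comes from the injected `<think>` token, so trimming to the common prefix is a no-op for templates that do not insert one.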