Ensure Chat Template Safe Prompt Truncation #3646
Changes from 12 commits
@@ -46,7 +46,7 @@
 from transformers.trainer_utils import seed_worker
 from transformers.utils import is_datasets_available, is_peft_available, is_rich_available

-from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
+from ..data_utils import apply_chat_template, is_conversational
 from ..extras.profiling import profiling_context, profiling_decorator
 from ..extras.vllm_client import VLLMClient
 from ..import_utils import is_liger_kernel_available, is_vllm_available
@@ -1066,26 +1066,75 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list):
         rewards_per_func = gather(rewards_per_func)
         return rewards_per_func

+    def _get_prompt_inputs(self, prompts: Union[list[str], list[list[dict[str, str]]]]) -> tuple:
+        # Checks if the prompt is conversational or not and truncates the input prompt.
+        # If it is conversational the truncation preserves the chat template.
+        if not is_conversational(prompts[0]):
+            prompt_text = [x["prompt"] for x in prompts]
+            prompt_inputs = self.processing_class(
+                text=prompt_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
+            )
+            prompt_inputs = super()._prepare_inputs(prompt_inputs)
+            prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
+
+            if self.max_prompt_length is not None and prompt_mask.sum(-1).max() > self.max_prompt_length:

Review thread on this line:
- could you motivate the need for […]
- Yes exactly, gets rid of any redundant ops.
- Got it. Given we've already padded, won't […]
- Yeah that's fair. I'll change that to just use the shape.
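A small sketch of the point raised in the thread above: once the batch has been left-padded to its longest prompt, the longest row contains no padding, so `prompt_mask.sum(-1).max()` is simply the padded width and the comparison could use `prompt_ids.shape[1]` instead. Toy tensors, with pad id 0 assumed:

```python
import torch

# Toy left-padded batch; 0 is assumed to be the pad id here.
prompt_ids = torch.tensor([[0, 0, 7, 8, 9],    # 3 real tokens, left-padded
                           [3, 4, 5, 6, 7]])   # longest prompt: no padding
prompt_mask = (prompt_ids != 0).long()

# The longest row is unpadded, so the max mask sum equals the padded width.
assert prompt_mask.sum(-1).max().item() == prompt_ids.shape[1] == 5
```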
+                prompt_ids = prompt_ids[:, -self.max_prompt_length :]
+                prompt_mask = prompt_mask[:, -self.max_prompt_length :]
+                prompt_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=False)
+        else:
+            # Get the token counts of the content of each message
+            messages_token_counts = [
+                [
+                    self.processing_class(msg["content"], add_special_tokens=False, return_tensors="pt")[
+                        "attention_mask"
+                    ]
+                    .sum()
+                    .item()
+                    for msg in prompts[i]["prompt"]
+                ]
+                for i in range(len(prompts))
+            ]
+            # Compute the number of tokens that the contents of all the messages in a prompt consume
+            prompts_token_count = [sum(prompt_token_count) for prompt_token_count in messages_token_counts]
+            truncated_messages = []
+            for i in range(len(prompts)):
+                if prompts_token_count[i] <= self.max_prompt_length:
+                    truncated_messages.append(prompts[i])
+                else:
+                    num_tokens_to_truncate = prompts_token_count[i] - self.max_prompt_length
+                    truncated_messages.append([])
+                    for ind, msg in enumerate(prompts[i]["prompt"]):
+                        if num_tokens_to_truncate == 0:
+                            truncated_messages[-1].append(msg)
+                        else:
+                            if messages_token_counts[i][ind] <= num_tokens_to_truncate:
+                                num_tokens_to_truncate -= messages_token_counts[i][ind]
+                            else:
+                                tokens = self.processing_class(msg["content"], add_special_tokens=False)["input_ids"]
+                                tokens = tokens[num_tokens_to_truncate:]
+                                truncated_message = self.processing_class.decode(tokens)
+                                msg["content"] = truncated_message
+                                num_tokens_to_truncate = 0
+                                truncated_messages[-1].append(msg)
+
+                prompt_inputs = self.processing_class.apply_chat_template(
+                    truncated_messages, return_dict=True, add_generation_prompt=True
+                )
+                prompt_inputs = super()._prepare_inputs(prompt_inputs)
+                prompt_ids = prompt_inputs["input_ids"]
+                prompt_mask = prompt_inputs["attention_mask"]
+                prompt_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=False)
Review thread on this line:
- this should be outside/after the for-loop, right?
- Yep, good catch.
+
+        return prompt_ids, prompt_mask, prompt_text
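As a usage illustration of the method above, here is a standalone sketch of the same message-level truncation idea outside the trainer: count the content tokens of each message, drop or trim the leading messages until the budget fits, and only then render the chat template so its special tokens are never cut. The checkpoint name and the 8-token budget are assumptions for the example, not values from the PR:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # assumed checkpoint
max_prompt_length = 8  # illustrative budget

messages = [
    {"role": "system", "content": "You are a terse assistant."},
    {"role": "user", "content": "Summarize the plot of a very long novel in one sentence."},
]

# Token count of each message's content, ignoring template tokens (as in the diff above).
counts = [len(tokenizer(m["content"], add_special_tokens=False)["input_ids"]) for m in messages]
overflow = sum(counts) - max_prompt_length

truncated = []
for count, msg in zip(counts, messages):
    if overflow <= 0:
        truncated.append(msg)          # budget already met: keep the message untouched
    elif count <= overflow:
        overflow -= count              # drop this whole (older) message
    else:
        ids = tokenizer(msg["content"], add_special_tokens=False)["input_ids"][overflow:]
        truncated.append({**msg, "content": tokenizer.decode(ids)})  # keep only the tail
        overflow = 0

# The template is applied only after truncation, so its control tokens stay intact.
print(tokenizer.apply_chat_template(truncated, tokenize=False, add_generation_prompt=True))
```

As in the diff, whole leading messages are dropped before a single partial message is trimmed from the left.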
     def _generate_and_score_completions(
         self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
     ) -> dict[str, Union[torch.Tensor, Any]]:
         device = self.accelerator.device
         mode = "train" if self.model.training else "eval"

         prompts = [x["prompt"] for x in inputs]
-        prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
-        prompt_inputs = self.processing_class(
-            text=prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
-        )
-        prompt_inputs = super()._prepare_inputs(prompt_inputs)
-        prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
-
-        if self.max_prompt_length is not None:
-            prompt_ids = prompt_ids[:, -self.max_prompt_length :]
-            prompt_mask = prompt_mask[:, -self.max_prompt_length :]
-            prompts_text = self.processing_class.batch_decode(
-                prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
-            )
+        prompt_ids, prompt_mask, prompts_text = self._get_prompt_inputs(inputs)

(LeonEricsson marked this conversation as resolved.)

         # Generate completions using either vLLM or regular generation
         if self.use_vllm:
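For contrast, a sketch of what the replaced code path above did: left-truncate the already-templated token sequence, which can slice off the template's own control tokens (for example an opening `<|im_start|>` in ChatML-style templates). The tokenizer checkpoint and the 32-token cut are illustrative assumptions:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # assumed checkpoint
messages = [{"role": "user", "content": "word " * 50}]

# Old behaviour (sketched): render the template first, then keep only the last N tokens.
templated_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
naive = tokenizer.decode(templated_ids[-32:])
print(naive)  # likely starts mid-message, with the leading template tokens cut away
```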