1 file changed: +7 -0 lines changed
Original file line number | Diff line number | Diff line change
13
13
# limitations under the License.
14
14
15
15
import os
16
+ import re
16
17
import textwrap
17
18
import warnings
18
19
from collections import defaultdict , deque
@@ -1153,11 +1154,17 @@ def _generate_and_score_completions(
1153
1154
prompt_ids , prompt_mask = prompt_inputs ["input_ids" ], prompt_inputs ["attention_mask" ]
1154
1155
1155
1156
if self .max_prompt_length is not None :
1157
+ # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
1158
+ # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text,
1159
+ # because we can't use `skip_special_tokens=True` (some special tokens are still needed for generation).
1156
1160
prompt_ids = prompt_ids [:, - self .max_prompt_length :]
1157
1161
prompt_mask = prompt_mask [:, - self .max_prompt_length :]
1158
1162
prompts_text = self .processing_class .batch_decode (
1159
1163
prompt_ids , skip_special_tokens = False , clean_up_tokenization_spaces = False
1160
1164
)
1165
+ prompts_text = [
1166
+ re .sub (rf"^({ re .escape (self .processing_class .pad_token )} )+" , "" , text ) for text in prompts_text
1167
+ ]
1161
1168
1162
1169
# Generate completions using either vLLM or regular generation
1163
1170
if self .use_vllm :
You can’t perform that action at this time.
0 commit comments