
Commit 364754d

pramodith, shirinyamani, kashif, and qgallouedec authored and committed
✂️ [BUG when vllm and prompt_truncation are used]: Strip out pad tokens in truncated prompt text (huggingface#3698)
Co-authored-by: Shirin Yamani <[email protected]>
Co-authored-by: Kashif Rasul <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
1 parent a310042 commit 364754d
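
For context, the sketch below is not part of the commit; it illustrates the bug and the fix in isolation. It assumes a hypothetical tokenizer whose pad token is the literal string "<pad>" and made-up chat-template markers ("<|user|>", "<|assistant|>"); the real trainer operates on self.processing_class and its actual pad_token.

import re

# Assumption for illustration only: the tokenizer left-pads and its pad token
# is "<pad>"; the chat-template markers below are likewise invented.
pad_token = "<pad>"

# A short prompt in a left-padded batch, after keeping only the last
# `max_prompt_length` tokens and decoding with skip_special_tokens=False:
# the leading pad tokens survive truncation and leak into the text that is
# later handed to vLLM as the prompt.
decoded_prompt = "<pad><pad><|user|>Hello<|assistant|>"

# The fix: strip only the leading run of pad tokens, keeping every other
# special token that generation still needs.
cleaned = re.sub(rf"^({re.escape(pad_token)})+", "", decoded_prompt)
print(cleaned)  # <|user|>Hello<|assistant|>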

File tree: 1 file changed (+7, -0 lines)


trl/trainer/grpo_trainer.py

Lines changed: 7 additions & 0 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import os
+import re
 import textwrap
 import warnings
 from collections import defaultdict, deque
@@ -1153,11 +1154,17 @@ def _generate_and_score_completions(
         prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
 
         if self.max_prompt_length is not None:
+            # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
+            # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text,
+            # because we can't use `skip_special_tokens=True` (some special tokens are still needed for generation).
             prompt_ids = prompt_ids[:, -self.max_prompt_length :]
             prompt_mask = prompt_mask[:, -self.max_prompt_length :]
             prompts_text = self.processing_class.batch_decode(
                 prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
             )
+            prompts_text = [
+                re.sub(rf"^({re.escape(self.processing_class.pad_token)})+", "", text) for text in prompts_text
+            ]
 
         # Generate completions using either vLLM or regular generation
         if self.use_vllm:
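
On the design choice: decoding with skip_special_tokens=True would also have removed the pads, but it would drop the chat-template and other special tokens the model still needs at generation time, which is why the patch strips only the leading pad run. The sketch below is a hedged illustration of that contrast over a batch, again with made-up "<pad>" and chat-marker strings rather than the trainer's real tokenizer.

import re

# Assumed values for illustration; the real trainer reads these from
# self.processing_class (the tokenizer).
pad_token = "<pad>"
special_tokens = ["<pad>", "<|im_start|>", "<|im_end|>"]

prompts_text = [
    "<pad><pad><|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n",
    "<|im_start|>user\nWrite a haiku<|im_end|>\n<|im_start|>assistant\n",
]

def drop_all_specials(text: str) -> str:
    # Roughly what skip_special_tokens=True would do: every special token is
    # removed, including the chat-template markers the model needs to see.
    for tok in special_tokens:
        text = text.replace(tok, "")
    return text

# What the patch does instead: remove only a leading run of pad tokens,
# mirroring the list comprehension added in the diff above.
patched = [re.sub(rf"^({re.escape(pad_token)})+", "", t) for t in prompts_text]

print(drop_all_specials(prompts_text[0]))  # chat markers lost
print(patched[0])                          # markers kept, leading pads removed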
