Commit 9eb3cfe

[Tokenizer] Fix tokenizer of llama3.3 (#9641)
* fix tokenizer of llama3 and add test case
* fix paddle.where
1 parent dc0ca03 commit 9eb3cfe

3 files changed: +32 −3 lines changed
paddlenlp/transformers/llama/modeling.py

Lines changed: 2 additions & 1 deletion
@@ -1601,7 +1601,8 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
             expanded_attn_mask = expanded_attn_mask.astype(dtype)
             expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
         else:
-            expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype)
+            expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), 0.0, paddle.finfo(dtype).min)
+            expanded_attn_mask = expanded_attn_mask.astype(dtype)
         return expanded_attn_mask

     @paddle.jit.not_to_static
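
The change in _prepare_decoder_attention_mask is that the 0/1 mask is explicitly cast to bool before paddle.where, and the conversion to the target dtype happens in a separate step. Below is a minimal sketch of what the fixed branch computes; the mask values, shape, and dtype are illustrative, and it assumes a Paddle release where paddle.where accepts Python scalars for x and y (as the patched line itself does).

# Illustrative sketch of the patched branch; values, shape, and dtype are made up.
import paddle

dtype = paddle.float16
# A 0/1 attention mask of shape [batch, 1, tgt_len, src_len]; 1 = attend, 0 = masked.
expanded_attn_mask = paddle.to_tensor([[[[1, 1, 1, 0]]]], dtype="int64")

# paddle.where expects a boolean condition, so the integer mask is cast first.
# Kept positions become 0.0; masked positions become the most negative value
# representable in the target dtype.
expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), 0.0, paddle.finfo(dtype).min)
expanded_attn_mask = expanded_attn_mask.astype(dtype)
print(expanded_attn_mask)  # [[[[0., 0., 0., -65504.]]]] for float16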

paddlenlp/transformers/llama/tokenizer.py

Lines changed: 4 additions & 2 deletions
@@ -340,9 +340,11 @@ def __init__(
         self.eos_token = ENDOFTEXT
         self.bos_token_id = self.bod_id
         self.eos_token_id = self.eod_id
-        self.pad_token = self.convert_ids_to_tokens(self.eos_token_id)
+        if "pad_token" not in kwargs:
+            self.pad_token = self.convert_ids_to_tokens(self.eos_token_id)
+            kwargs["pad_token"] = self.pad_token

-        super().__init__(pad_token=self.pad_token, **kwargs)
+        super().__init__(**kwargs)

     def __len__(self) -> int:
         return self.tokenizer.n_vocab
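
The motivation for the guard: if kwargs already contains a pad_token (user-supplied, or read from a checkpoint's saved tokenizer config, which appears to be the case for the Llama 3.3 checkpoints this commit targets), the old super().__init__(pad_token=self.pad_token, **kwargs) call would pass pad_token twice and raise "TypeError: got multiple values for keyword argument 'pad_token'". With the fix, an explicit pad_token wins and the eos-derived token is only a fallback. A hedged usage sketch follows; the checkpoint name and pad token string are illustrative, and loading requires the tokenizer files to be available.

# Hedged usage sketch; checkpoint name and pad token string are illustrative.
from paddlenlp.transformers.auto.tokenizer import AutoTokenizer

# Before this fix, a pad_token arriving via kwargs collided with the
# eos-derived default and raised a duplicate-keyword TypeError.
tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-3.3-70B-Instruct",
    pad_token="<|end_of_text|>",
)
print(tokenizer.pad_token)  # the explicitly supplied pad token is kept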

tests/transformers/llama/test_tokenizer.py

Lines changed: 26 additions & 0 deletions
@@ -17,6 +17,8 @@
 import tempfile
 import unittest

+from parameterized import parameterized_class
+
 from paddlenlp.transformers.auto.tokenizer import AutoTokenizer
 from paddlenlp.transformers.llama.tokenizer import LlamaTokenizer
 from paddlenlp.transformers.tokenizer_utils import PretrainedTokenizer
@@ -213,6 +215,30 @@ def test_pretrained_model_lists(self):
         self.assertGreaterEqual(len(list(self.tokenizer_class.pretrained_resource_files_map.values())[0]), 1)


+@parameterized_class(
+    ["model_name_or_path"],
+    [
+        ["facebook/llama-7b"],
+        ["meta-llama/Meta-Llama-3.1-8B"],
+        ["meta-llama/Llama-3.2-1B"],
+        ["meta-llama/Llama-3.3-70B-Instruct"],
+    ],
+)
+class LlamaTokenizationLoadTest(unittest.TestCase):
+    model_name_or_path: str = None
+
+    def get_tokenizer(self, **kwargs) -> PretrainedTokenizer:
+        tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, **kwargs)
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.unk_token
+        return tokenizer
+
+    def test_load_tokenizer(self):
+        tokenizer = self.get_tokenizer()
+        text = "lower newer"
+        tokenizer.tokenize(text, add_prefix_space=True)
+
+
 class TikTokenIntegrationTests(unittest.TestCase):
     """
     A class that regroups important test to make sure that we properly handle the special tokens.
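
parameterized_class (from the parameterized package) builds one concrete unittest.TestCase subclass per parameter row, so test_load_tokenizer runs once for each checkpoint in the list, including the Llama 3.3 one this commit fixes. A small standalone sketch of the mechanism, separate from the commit's test (class and test names here are made up):

# Standalone illustration of parameterized_class; not part of the commit.
import unittest

from parameterized import parameterized_class


@parameterized_class(
    ["model_name_or_path"],
    [
        ["facebook/llama-7b"],
        ["meta-llama/Llama-3.3-70B-Instruct"],
    ],
)
class ExampleLoadTest(unittest.TestCase):
    model_name_or_path: str = None

    def test_name_is_set(self):
        # Each generated subclass sees its own model_name_or_path value.
        self.assertIsNotNone(self.model_name_or_path)


if __name__ == "__main__":
    unittest.main()  # runs one generated ExampleLoadTest_* class per parameter row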
