
Commit 3412f50

fix
1 parent d46655c commit 3412f50


4 files changed: 210 additions, 5 deletions


paddlenlp/transformers/llama/tokenizer.py

Lines changed: 2 additions & 4 deletions
@@ -72,9 +72,7 @@ def __init__(
         self.add_bos_token = add_bos_token
         self.add_eos_token = add_eos_token
         self.decode_with_prefix_space = decode_with_prefix_space
-        # self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
-        self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False))
-        self.sp_model.Load(vocab_file)
+        self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", True))
 
     @property
     def vocab_size(self):
@@ -101,7 +99,7 @@ def bos_token_id(self) -> Optional[int]:
     def eos_token_id(self) -> Optional[int]:
         return self.sp_model.eos_id()
 
-    def get_spm_processor(self, from_slow=False):
+    def get_spm_processor(self, from_slow=True):
         tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
         if from_slow: # no dependency on protobuf
             tokenizer.Load(self.vocab_file)
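Illustrative sketch (not from this commit): what the new `from_slow=True` default means when the LLaMA tokenizer builds its SentencePiece processor. The helper name below is an assumption used only for illustration, not PaddleNLP API; with `from_slow=True` the `.model` vocab file is loaded directly via `Load`, so no protobuf dependency is needed.

# Minimal sketch, assuming vocab_file points to a SentencePiece .model file.
# "build_sp_processor" is a hypothetical name, not part of PaddleNLP.
import sentencepiece as spm


def build_sp_processor(vocab_file, from_slow=True, **sp_model_kwargs):
    tokenizer = spm.SentencePieceProcessor(**sp_model_kwargs)
    if from_slow:  # no dependency on protobuf
        tokenizer.Load(vocab_file)
        return tokenizer
    # from_slow=False would take the protobuf-backed branch (not shown in this hunk).
    raise NotImplementedError("fast path omitted in this sketch")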

paddlenlp/transformers/tokenizer_utils_base.py

Lines changed: 4 additions & 1 deletion
@@ -1600,7 +1600,10 @@ def _from_pretrained(
         from_hf_hub=False,
         **kwargs,
     ):
-        from_slow = kwargs.get("from_slow", False)
+        if cls.__name__.endswith("Fast"):
+            from_slow = kwargs.get("from_slow", False)
+        else:
+            from_slow = kwargs.get("from_slow", True)
         has_tokenizer_file = resolved_vocab_files.get("tokenizer_file", None) is not None
         if (from_slow or not has_tokenizer_file) and cls.slow_tokenizer_class is not None:
             slow_tokenizer = (cls.slow_tokenizer_class)._from_pretrained(
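Illustrative sketch (not from this commit): combined with the llama change above, the rule in `_from_pretrained` is that classes whose names end in "Fast" keep the old default of `from_slow=False`, while slow tokenizer classes now default to `from_slow=True`. The helper and class names below are assumptions used only to demonstrate the rule.

# Hypothetical helper mirroring the if/else added in this hunk.
def resolve_from_slow(cls_name: str, kwargs: dict) -> bool:
    if cls_name.endswith("Fast"):
        return kwargs.get("from_slow", False)
    return kwargs.get("from_slow", True)


assert resolve_from_slow("LlamaTokenizer", {}) is True              # slow class -> new default True
assert resolve_from_slow("LlamaTokenizerFast", {}) is False         # fast class -> old default False
assert resolve_from_slow("LlamaTokenizer", {"from_slow": False}) is False  # explicit value still wins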
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
Lines changed: 191 additions & 0 deletions
@@ -0,0 +1,191 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2021 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tempfile
+import unittest
+
+from paddlenlp.transformers import SPIECE_UNDERLINE, MBart50Tokenizer
+from paddlenlp.transformers.mbart.modeling import shift_tokens_right
+
+from ...testing_utils import get_tests_dir, nested_simplify
+from ..test_tokenizer_common import TokenizerTesterMixin
+
+SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
+
+EN_CODE = 250004
+RO_CODE = 250020
+
+
+class MBart50TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    tokenizer_class = MBart50Tokenizer
+    test_sentencepiece = True
+
+    test_offsets = False
+
+    def setUp(self):
+        super().setUp()
+
+        # We have a SentencePiece fixture for testing
+        tokenizer = MBart50Tokenizer(SAMPLE_VOCAB, src_lang="en_XX", tgt_lang="ro_RO", keep_accents=True)
+        tokenizer.save_pretrained(self.tmpdirname)
+
+    def test_convert_token_and_id(self):
+        """Test ``_convert_token_to_id`` and ``_convert_id_to_token``."""
+        token = "<s>"
+        token_id = 0
+
+        self.assertEqual(self.get_tokenizer()._convert_token_to_id(token), token_id)
+        self.assertEqual(self.get_tokenizer()._convert_id_to_token(token_id), token)
+
+    def test_get_vocab(self):
+        vocab_keys = list(self.get_tokenizer().get_vocab().keys())
+
+        self.assertEqual(vocab_keys[0], "<s>")
+        self.assertEqual(vocab_keys[1], "<pad>")
+        self.assertEqual(vocab_keys[-1], "<mask>")
+        self.assertEqual(len(vocab_keys), 1_054)
+
+    def test_vocab_size(self):
+        self.assertEqual(self.get_tokenizer().vocab_size, 1_054)
+
+    def test_full_tokenizer(self):
+        tokenizer = MBart50Tokenizer(SAMPLE_VOCAB, src_lang="en_XX", tgt_lang="ro_RO", keep_accents=True)
+
+        tokens = tokenizer.tokenize("This is a test")
+        self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])
+
+        self.assertListEqual(
+            tokenizer.convert_tokens_to_ids(tokens),
+            [value + tokenizer.fairseq_offset for value in [285, 46, 10, 170, 382]],
+        )
+
+        tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
+        self.assertListEqual(
+            tokens,
+            # fmt: off
+            [
+                SPIECE_UNDERLINE + "I", SPIECE_UNDERLINE + "was",
+                SPIECE_UNDERLINE + "b", "or", "n", SPIECE_UNDERLINE + "in",
+                SPIECE_UNDERLINE + "", "9", "2", "0", "0", "0", ",",
+                SPIECE_UNDERLINE + "and", SPIECE_UNDERLINE + "this",
+                SPIECE_UNDERLINE + "is", SPIECE_UNDERLINE + "f", "al", "s", "é",
+                "."
+            ],
+            # fmt: on
+        )
+        ids = tokenizer.convert_tokens_to_ids(tokens)
+        self.assertListEqual(
+            ids,
+            [
+                value + tokenizer.fairseq_offset
+                for value in [8, 21, 84, 55, 24, 19, 7, 2, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 2, 4]
+            ],
+        )
+
+        back_tokens = tokenizer.convert_ids_to_tokens(ids)
+        self.assertListEqual(
+            back_tokens,
+            # fmt: off
+            [
+                SPIECE_UNDERLINE + "I", SPIECE_UNDERLINE + "was",
+                SPIECE_UNDERLINE + "b", "or", "n", SPIECE_UNDERLINE + "in",
+                SPIECE_UNDERLINE + "", "<unk>", "2", "0", "0", "0", ",",
+                SPIECE_UNDERLINE + "and", SPIECE_UNDERLINE + "this",
+                SPIECE_UNDERLINE + "is", SPIECE_UNDERLINE + "f", "al", "s",
+                "<unk>", "."
+            ],
+            # fmt: on
+        )
+
+
+class MBart50OneToManyIntegrationTest(unittest.TestCase):
+    checkpoint_name = "mbart-large-50-one-to-many-mmt"
+    src_text = [
+        " UN Chief Says There Is No Military Solution in Syria",
+        """ Secretary-General Ban Ki-moon says his response to Russia's stepped up military support for Syria is that "there is no military solution" to the nearly five-year conflict and more weapons will only worsen the violence and misery for millions of people.""",
+    ]
+    tgt_text = [
+        "Şeful ONU declară că nu există o soluţie militară în Siria",
+        "Secretarul General Ban Ki-moon declară că răspunsul său la intensificarea sprijinului militar al Rusiei"
+        ' pentru Siria este că "nu există o soluţie militară" la conflictul de aproape cinci ani şi că noi arme nu vor'
+        " face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.",
+    ]
+    expected_src_tokens = [EN_CODE, 8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2]
+
+    @classmethod
+    def setUpClass(cls):
+        cls.tokenizer: MBart50Tokenizer = MBart50Tokenizer.from_pretrained(
+            cls.checkpoint_name, src_lang="en_XX", tgt_lang="ro_RO"
+        )
+        cls.pad_token_id = 1
+        return cls
+
+    def check_language_codes(self):
+        self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["ar_AR"], 250001)
+        self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["en_EN"], 250004)
+        self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["ro_RO"], 250020)
+        self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["mr_IN"], 250038)
+
+    def test_tokenizer_decode_ignores_language_codes(self):
+        self.assertIn(RO_CODE, self.tokenizer.all_special_ids)
+        generated_ids = [RO_CODE, 884, 9019, 96, 9, 916, 86792, 36, 18743, 15596, 5, 2]
+        result = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
+        expected_romanian = self.tokenizer.decode(generated_ids[1:], skip_special_tokens=True)
+        self.assertEqual(result, expected_romanian)
+        self.assertNotIn(self.tokenizer.eos_token, result)
+
+    def test_tokenizer_truncation(self):
+        src_text = ["this is gunna be a long sentence " * 20]
+        assert isinstance(src_text[0], str)
+        desired_max_length = 10
+        ids = self.tokenizer(src_text, max_length=desired_max_length, truncation=True).input_ids[0]
+        self.assertEqual(ids[0], EN_CODE)
+        self.assertEqual(ids[-1], 2)
+        self.assertEqual(len(ids), desired_max_length)
+
+    def test_mask_token(self):
+        self.assertListEqual(self.tokenizer.convert_tokens_to_ids(["<mask>", "ar_AR"]), [250053, 250001])
+
+    def test_special_tokens_unaffacted_by_save_load(self):
+        tmpdirname = tempfile.mkdtemp()
+        original_special_tokens = self.tokenizer.fairseq_tokens_to_ids
+        self.tokenizer.save_pretrained(tmpdirname)
+        new_tok = MBart50Tokenizer.from_pretrained(tmpdirname)
+        self.assertDictEqual(new_tok.fairseq_tokens_to_ids, original_special_tokens)
+
+    def test_seq2seq_max_target_length(self):
+        batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pd")
+        targets = self.tokenizer(self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pd")
+        labels = targets["input_ids"]
+        batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
+
+        self.assertEqual(batch.input_ids.shape[1], 3)
+        self.assertEqual(batch.decoder_input_ids.shape[1], 10)
+
+    def test_tokenizer_translation(self):
+        inputs = self.tokenizer._build_translation_inputs(
+            "A test", return_tensors="pd", src_lang="en_XX", tgt_lang="ar_AR"
+        )
+
+        self.assertEqual(
+            nested_simplify(inputs),
+            {
+                # en_XX, A, test, EOS
+                "input_ids": [[250004, 62, 3034, 2]],
+                "attention_mask": [[1, 1, 1, 1]],
+                # ar_AR
+                "forced_bos_token_id": 250001,
+            },
+        )
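Illustrative only: the new test case can be run with plain unittest. The module path below is an assumption, since this diff view does not show the new file's name.

# Hypothetical invocation; the import path is assumed, not shown in this diff view.
import unittest

from tests.transformers.mbart50.test_tokenizer import MBart50TokenizationTest  # assumed path

suite = unittest.defaultTestLoader.loadTestsFromTestCase(MBart50TokenizationTest)
unittest.TextTestRunner(verbosity=2).run(suite)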
