
Commit ec25cb8

[Tokenizer] Support reading Tiktoken tokenizer.model. (#9215)
* add support of tiktoken tokenizer, refactor some code
* add case of built-in tokenizers to handle CI error
1 parent f5f9d85 commit ec25cb8

File tree

19 files changed: +1240 −184 lines changed

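The tiktoken-specific loading code lives in the tokenizer utilities among the 19 changed files; only the auto/albert plumbing is reproduced below. For orientation, a tiktoken-format tokenizer.model is a plain-text file of base64-encoded byte sequences and their BPE ranks, which the tiktoken package can wrap into an Encoding. The following is a minimal, hypothetical sketch of that idea, not code from this commit; the pattern string and special token are assumptions:

# Hypothetical sketch only -- not taken from this commit.
# Assumes the `tiktoken` package; pat_str and the special token are illustrative.
import tiktoken
from tiktoken.load import load_tiktoken_bpe


def load_tiktoken_model(model_path: str) -> tiktoken.Encoding:
    # Each line of a tiktoken tokenizer.model is "<base64 token> <rank>".
    mergeable_ranks = load_tiktoken_bpe(model_path)
    special_tokens = {"<|endoftext|>": len(mergeable_ranks)}  # assumed special token
    return tiktoken.Encoding(
        name="custom_bpe",
        pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
        mergeable_ranks=mergeable_ranks,
        special_tokens=special_tokens,
    )


# enc = load_tiktoken_model("tokenizer.model")
# enc.encode("hello world")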

paddlenlp/transformers/albert/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -11,3 +11,5 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from .tokenizer import *

paddlenlp/transformers/albert/tokenizer.py

Lines changed: 2 additions & 2 deletions

@@ -20,9 +20,9 @@
 
 import sentencepiece as spm
 
-from .. import PretrainedTokenizer, BertTokenizer, AddedToken
+from .. import AddedToken, BertTokenizer, PretrainedTokenizer
 
-__all__ = ["AlbertTokenizer"]
+__all__ = ["AlbertTokenizer", "AlbertChineseTokenizer", "AlbertEnglishTokenizer"]
 
 SPIECE_UNDERLINE = "▁"
 
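With the re-export added in __init__.py and the wider __all__ above, the Chinese and English Albert tokenizers become importable directly from the albert subpackage. A quick usage sketch (the checkpoint name is an assumption, not verified against the hub):

# Illustrative usage; the checkpoint name is an assumption.
from paddlenlp.transformers.albert import AlbertChineseTokenizer

tokenizer = AlbertChineseTokenizer.from_pretrained("albert-chinese-tiny")
print(tokenizer("这是一个测试")["input_ids"])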

paddlenlp/transformers/auto/configuration.py

Lines changed: 287 additions & 2 deletions

@@ -13,11 +13,12 @@
 # limitations under the License.
 from __future__ import annotations
 
+import importlib
 import inspect
 import io
 import json
 import os
-from collections import defaultdict
+from collections import OrderedDict, defaultdict
 from typing import Dict, List, Type
 
 from ...utils.download import resolve_file_path
@@ -30,6 +31,250 @@
     "AutoConfig",
 ]
 
+CONFIG_MAPPING_NAMES = OrderedDict(
+    [
+        ("albert", "AlbertConfig"),
+        ("artist", "ArtistConfig"),
+        ("bart", "BartConfig"),
+        ("bert", "BertConfig"),
+        ("bigbird", "BigBirdConfig"),
+        ("bit", "BitConfig"),
+        ("blenderbot", "BlenderbotConfig"),
+        ("blenderbot_small", "BlenderbotSmallConfig"),
+        ("blip", "BlipConfig"),
+        ("blip2", "Blip2Config"),
+        ("bloom", "BloomConfig"),
+        ("chatglm", "ChatGLMConfig"),
+        ("chatglm_v2", "ChatGLMv2Config"),
+        ("chinesebert", "ChineseBertConfig"),
+        ("chineseclip", "ChineseCLIPConfig"),
+        ("clap", "ClapConfig"),
+        ("clip", "CLIPConfig"),
+        ("codegen", "CodeGenConfig"),
+        ("convbert", "ConvBertConfig"),
+        ("ctrl", "CTRLConfig"),
+        ("dallebart", "DalleBartConfig"),
+        ("deberta", "DebertaConfig"),
+        ("debertav2", "DebertaV2Config"),
+        ("distilbert", "DistilBertConfig"),
+        ("dpt", "DPTConfig"),
+        ("electra", "ElectraConfig"),
+        ("ernie", "ErnieConfig"),
+        ("ernie_code", "ErnieCodeConfig"),
+        ("ernie_ctm", "ErnieCtmConfig"),
+        ("ernie_doc", "ErnieDocConfig"),
+        ("ernie_gram", "ErnieGramConfig"),
+        ("ernie_layout", "ErnieLayoutConfig"),
+        ("ernie_m", "ErnieMConfig"),
+        ("ernie_vil", "ErnieViLConfig"),
+        ("fnet", "FNetConfig"),
+        ("funnel", "FunnelConfig"),
+        ("gau_alpha", "GAUAlphaConfig"),
+        ("gemma", "GemmaConfig"),
+        ("glm", "GLMConfig"),
+        ("gpt", "GPTConfig"),
+        ("gptj", "GPTJConfig"),
+        ("jamba", "JambaConfig"),
+        ("layoutlm", "LayoutLMConfig"),
+        ("layoutlmv2", "LayoutLMv2Config"),
+        ("layoutxlm", "LayoutXLMConfig"),
+        ("llama", "LlamaConfig"),
+        ("luke", "LukeConfig"),
+        ("mamba", "MambaConfig"),
+        ("mbart", "MBartConfig"),
+        ("megatronbert", "MegatronBertConfig"),
+        ("minigpt4", "MiniGPT4Config"),
+        ("mistral", "MistralConfig"),
+        ("mixtral", "MixtralConfig"),
+        ("mobilebert", "MobileBertConfig"),
+        ("mpnet", "MPNetConfig"),
+        ("mt5", "MT5Config"),
+        ("nezha", "NeZhaConfig"),
+        ("nystromformer", "NystromformerConfig"),
+        ("opt", "OPTConfig"),
+        ("pegasus", "PegasusConfig"),
+        ("ppminilm", "PPMiniLMConfig"),
+        ("prophetnet", "ProphetNetConfig"),
+        ("qwen", "QWenConfig"),
+        ("qwen2", "Qwen2Config"),
+        ("qwen2_moe", "Qwen2MoeConfig"),
+        ("reformer", "ReformerConfig"),
+        ("rembert", "RemBertConfig"),
+        ("roberta", "RobertaConfig"),
+        ("roformer", "RoFormerConfig"),
+        ("roformerv2", "RoFormerv2Config"),
+        ("rw", "RWConfig"),
+        ("skep", "SkepConfig"),
+        ("speecht5", "SpeechT5Config"),
+        ("squeezebert", "SqueezeBertConfig"),
+        ("t5", "T5Config"),
+        ("tinybert", "TinyBertConfig"),
+        ("unified_transformer", "UnifiedTransformerConfig"),
+        ("unimo", "UNIMOConfig"),
+        ("visualglm", "VisualGLMConfig"),
+        ("xlm", "XLMConfig"),
+        ("xlnet", "XLNetConfig"),
+        ("yuan", "YuanConfig"),
+    ]
+)
+
+
+MODEL_NAMES_MAPPING = OrderedDict(
+    # Base model mapping
+    [
+        ("albert", "Albert"),
+        ("artist", "Artist"),
+        ("bart", "Bart"),
+        ("bert", "Bert"),
+        ("bigbird", "BigBird"),
+        ("bit", "Bit"),
+        ("blenderbot", "Blenderbot"),
+        ("blenderbot_small", "BlenderbotSmall"),
+        ("blip", "Blip"),
+        ("blip2", "Blip2"),
+        ("bloom", "Bloom"),
+        ("chatglm", "ChatGLM"),
+        ("chatglm_v2", "ChatGLMv2"),
+        ("chinesebert", "ChineseBert"),
+        ("chineseclip", "ChineseCLIPText"),
+        ("clap", "CLAP"),
+        ("clip", "CLIP"),
+        ("codegen", "CodeGen"),
+        ("convbert", "ConvBert"),
+        ("ctrl", "CTRL"),
+        ("dallebart", "DalleBart"),
+        ("deberta", "Deberta"),
+        ("debertav2", "DebertaV2"),
+        ("distilbert", "DistilBert"),
+        ("dpt", "DPT"),
+        ("electra", "Electra"),
+        ("ernie", "Ernie"),
+        ("ernie_code", "ErnieCode"),
+        ("ernie_ctm", "ErnieCtm"),
+        ("ernie_doc", "ErnieDoc"),
+        ("ernie_gram", "ErnieGram"),
+        ("ernie_layout", "ErnieLayout"),
+        ("ernie_m", "ErnieM"),
+        ("ernie_vil", "ErnieViL"),
+        ("fnet", "FNet"),
+        ("funnel", "Funnel"),
+        ("gau_alpha", "GAUAlpha"),
+        ("gemma", "Gemma"),
+        ("glm", "GLM"),
+        ("gpt", "GPT"),
+        ("gptj", "GPTJ"),
+        ("jamba", "Jamba"),
+        ("layoutlm", "LayoutLM"),
+        ("layoutlmv2", "LayoutLMv2"),
+        ("layoutxlm", "LayoutXLM"),
+        ("llama", "Llama"),
+        ("luke", "Luke"),
+        ("mamba", "Mamba"),
+        ("mbart", "MBart"),
+        ("megatronbert", "MegatronBert"),
+        ("minigpt4", "MiniGPT4"),
+        ("mistral", "Mistral"),
+        ("mixtral", "Mixtral"),
+        ("mobilebert", "MobileBert"),
+        ("mpnet", "MPNet"),
+        ("mt5", "MT5"),
+        ("nezha", "NeZha"),
+        ("nystromformer", "Nystromformer"),
+        ("opt", "OPT"),
+        ("pegasus", "Pegasus"),
+        ("ppminilm", "PPMiniLM"),
+        ("prophetnet", "ProphetNet"),
+        ("qwen", "QWen"),
+        ("qwen2", "Qwen2"),
+        ("qwen2_moe", "Qwen2Moe"),
+        ("reformer", "Reformer"),
+        ("rembert", "RemBert"),
+        ("roberta", "Roberta"),
+        ("roformer", "RoFormer"),
+        ("roformerv2", "RoFormerv2"),
+        ("rw", "RW"),
+        ("skep", "Skep"),
+        ("speecht5", "SpeechT5"),
+        ("squeezebert", "SqueezeBert"),
+        ("t5", "T5"),
+        ("tinybert", "TinyBert"),
+        ("unified_transformer", "UnifiedTransformer"),
+        ("unimo", "UNIMO"),
+        ("visualglm", "VisualGLM"),
+        ("xlm", "XLM"),
+        ("xlnet", "XLNet"),
+        ("yuan", "Yuan"),
+    ]
+)
+
+
+def config_class_to_model_type(config):
+    """Converts a config class name to the corresponding model type"""
+    for key, cls in CONFIG_MAPPING_NAMES.items():
+        if cls == config:
+            return key
+    # if key not found check in extra content
+    for key, cls in CONFIG_MAPPING._extra_content.items():
+        if cls.__name__ == config:
+            return key
+    return None
+
+
+class _LazyConfigMapping(OrderedDict):
+    """
+    A dictionary that lazily load its values when they are requested.
+    """
+
+    def __init__(self, mapping):
+        self._mapping = mapping
+        self._extra_content = {}
+        self._modules = {}
+
+    def __getitem__(self, key):
+        if key in self._extra_content:
+            return self._extra_content[key]
+        if key not in self._mapping:
+            raise KeyError(key)
+        value = self._mapping[key]
+        module_name = model_type_to_module_name(key)
+        if module_name not in self._modules:
+            self._modules[module_name] = importlib.import_module(
+                f".{module_name}.configuration", "paddlenlp.transformers"
+            )
+        if hasattr(self._modules[module_name], value):
+            return getattr(self._modules[module_name], value)
+
+        # Some of the mappings have entries model_type -> config of another model type. In that case we try to grab the
+        # object at the top level.
+        transformers_module = importlib.import_module("paddlenlp")
+        return getattr(transformers_module, value)
+
+    def keys(self):
+        return list(self._mapping.keys()) + list(self._extra_content.keys())
+
+    def values(self):
+        return [self[k] for k in self._mapping.keys()] + list(self._extra_content.values())
+
+    def items(self):
+        return [(k, self[k]) for k in self._mapping.keys()] + list(self._extra_content.items())
+
+    def __iter__(self):
+        return iter(list(self._mapping.keys()) + list(self._extra_content.keys()))
+
+    def __contains__(self, item):
+        return item in self._mapping or item in self._extra_content
+
+    def register(self, key, value, exist_ok=False):
+        """
+        Register a new configuration in this mapping.
+        """
+        if key in self._mapping.keys() and not exist_ok:
+            raise ValueError(f"'{key}' is already used by a Transformers config, pick another name.")
+        self._extra_content[key] = value
+
+
+CONFIG_MAPPING = _LazyConfigMapping(CONFIG_MAPPING_NAMES)
+
 
 def get_configurations() -> Dict[str, List[Type[PretrainedConfig]]]:
     """load the configurations of PretrainedConfig mapping: {<model-name>: [<class-name>, <class-name>, ...], }
@@ -64,6 +309,12 @@ def get_configurations() -> Dict[str, List[Type[PretrainedConfig]]]:
     return mappings
 
 
+def model_type_to_module_name(key):
+    """Converts a config key to the corresponding module."""
+    key = key.replace("-", "_")
+    return key
+
+
 class AutoConfig(PretrainedConfig):
     """
     AutoConfig is a generic config class that will be instantiated as one of the
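_LazyConfigMapping defers the importlib.import_module call until a model type is first looked up, so importing AutoConfig no longer needs every model's configuration module eagerly; model_type_to_module_name just normalises dashes to underscores when building the import path, and config_class_to_model_type does the reverse lookup by class name. Roughly how that behaves (a sketch of expected behaviour, not a test from this commit; the dashed key is hypothetical since the mapping above uses underscores):

# Sketch of the lazy lookup and helpers defined above.
from paddlenlp.transformers.auto.configuration import (
    CONFIG_MAPPING,
    config_class_to_model_type,
    model_type_to_module_name,
)

# First access imports paddlenlp.transformers.llama.configuration on demand,
# then returns the LlamaConfig class from that module.
llama_config_cls = CONFIG_MAPPING["llama"]
print(llama_config_cls.__name__)  # LlamaConfig

# Reverse direction: class name -> model type (None if unknown).
print(config_class_to_model_type("LlamaConfig"))  # llama
print(config_class_to_model_type("NotARealConfig"))  # None

# Dashes become underscores when resolving module paths ("some-new-model" is hypothetical).
print(model_type_to_module_name("some-new-model"))  # some_new_model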
@@ -191,12 +442,29 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwar
             from_hf_hub=from_hf_hub,
             from_aistudio=from_aistudio,
         )
-        if config_file is not None and os.path.exists(config_file):
+        config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
+        if "model_type" in config_dict:
+            try:
+                config_class = CONFIG_MAPPING[config_dict["model_type"]]
+            except KeyError:
+                raise ValueError(
+                    f"The checkpoint you are trying to load has model type `{config_dict['model_type']}` "
+                    "but Transformers does not recognize this architecture. This could be because of an "
+                    "issue with the checkpoint, or because your version of Transformers is out of date."
+                )
+            return config_class.from_dict(config_dict, **unused_kwargs)
+        elif "model_type" not in config_dict and config_file is not None and os.path.exists(config_file):
             config_class = cls._get_config_class_from_config(pretrained_model_name_or_path, config_file)
             logger.info("We are using %s to load '%s'." % (config_class, pretrained_model_name_or_path))
             if config_class is cls:
                 return cls.from_file(config_file)
             return config_class.from_pretrained(config_file, *model_args, **kwargs)
+        elif config_file is None:
+            # Fallback: use pattern matching on the string.
+            # We go from longer names to shorter names to catch roberta before bert (for instance)
+            for pattern in sorted(CONFIG_MAPPING.keys(), key=len, reverse=True):
+                if pattern in str(pretrained_model_name_or_path):
+                    return CONFIG_MAPPING[pattern].from_dict(config_dict, **unused_kwargs)
         else:
             raise RuntimeError(
                 f"Can't load config for '{pretrained_model_name_or_path}'.\n"
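With this hunk, AutoConfig.from_pretrained first reads config.json via PretrainedConfig.get_config_dict; when it carries a model_type field, the concrete class comes straight from CONFIG_MAPPING, otherwise the legacy architecture-based lookup runs, and when no config file can be resolved at all the model name is pattern-matched against the mapping keys, longest first so that, for example, "roberta" is tried before "bert". A hedged usage sketch (the checkpoint name is an assumption):

# Illustrative only; the checkpoint name is an assumption.
from paddlenlp.transformers import AutoConfig

# If the downloaded config.json contains {"model_type": "llama", ...},
# the class is resolved through CONFIG_MAPPING and LlamaConfig is returned.
config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b")
print(type(config).__name__)  # LlamaConfig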
@@ -205,3 +473,20 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwar
                 "- or a correct model-identifier of community-contributed pretrained models,\n"
                 "- or the correct path to a directory containing relevant config files.\n"
             )
+
+    @staticmethod
+    def register(model_type, config, exist_ok=False):
+        """
+        Register a new configuration for this class.
+
+        Args:
+            model_type (`str`): The model type like "bert" or "gpt".
+            config ([`PretrainedConfig`]): The config to register.
+        """
+        if issubclass(config, PretrainedConfig) and config.model_type != model_type:
+            raise ValueError(
+                "The config you are passing has a `model_type` attribute that is not consistent with the model type "
+                f"you passed (config has {config.model_type} and you passed {model_type}. Fix one of those so they "
+                "match!"
+            )
+        CONFIG_MAPPING.register(model_type, config, exist_ok=exist_ok)
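The new AutoConfig.register hook lets downstream code add a custom config class to CONFIG_MAPPING so that the model_type resolution above can find it. A minimal sketch; the model type and class are made up:

# Minimal sketch; "my-model" and MyModelConfig are hypothetical.
from paddlenlp.transformers import AutoConfig, PretrainedConfig


class MyModelConfig(PretrainedConfig):
    model_type = "my-model"

    def __init__(self, hidden_size=256, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size


AutoConfig.register("my-model", MyModelConfig)
# A checkpoint whose config.json declares {"model_type": "my-model"}
# now resolves to MyModelConfig via CONFIG_MAPPING.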
