@@ -13,11 +13,12 @@
 # limitations under the License.
 from __future__ import annotations

+import importlib
 import inspect
 import io
 import json
 import os
-from collections import defaultdict
+from collections import OrderedDict, defaultdict
 from typing import Dict, List, Type

 from ...utils.download import resolve_file_path
3031 "AutoConfig" ,
3132]
3233
34+ CONFIG_MAPPING_NAMES = OrderedDict (
35+ [
36+ ("albert" , "AlbertConfig" ),
37+ ("artist" , "ArtistConfig" ),
38+ ("bart" , "BartConfig" ),
39+ ("bert" , "BertConfig" ),
40+ ("bigbird" , "BigBirdConfig" ),
41+ ("bit" , "BitConfig" ),
42+ ("blenderbot" , "BlenderbotConfig" ),
43+ ("blenderbot_small" , "BlenderbotSmallConfig" ),
44+ ("blip" , "BlipConfig" ),
45+ ("blip2" , "Blip2Config" ),
46+ ("bloom" , "BloomConfig" ),
47+ ("chatglm" , "ChatGLMConfig" ),
48+ ("chatglm_v2" , "ChatGLMv2Config" ),
49+ ("chinesebert" , "ChineseBertConfig" ),
50+ ("chineseclip" , "ChineseCLIPConfig" ),
51+ ("clap" , "ClapConfig" ),
52+ ("clip" , "CLIPConfig" ),
53+ ("codegen" , "CodeGenConfig" ),
54+ ("convbert" , "ConvBertConfig" ),
55+ ("ctrl" , "CTRLConfig" ),
56+ ("dallebart" , "DalleBartConfig" ),
57+ ("deberta" , "DebertaConfig" ),
58+ ("debertav2" , "DebertaV2Config" ),
59+ ("distilbert" , "DistilBertConfig" ),
60+ ("dpt" , "DPTConfig" ),
61+ ("electra" , "ElectraConfig" ),
62+ ("ernie" , "ErnieConfig" ),
63+ ("ernie_code" , "ErnieCodeConfig" ),
64+ ("ernie_ctm" , "ErnieCtmConfig" ),
65+ ("ernie_doc" , "ErnieDocConfig" ),
66+ ("ernie_gram" , "ErnieGramConfig" ),
67+ ("ernie_layout" , "ErnieLayoutConfig" ),
68+ ("ernie_m" , "ErnieMConfig" ),
69+ ("ernie_vil" , "ErnieViLConfig" ),
70+ ("fnet" , "FNetConfig" ),
71+ ("funnel" , "FunnelConfig" ),
72+ ("gau_alpha" , "GAUAlphaConfig" ),
73+ ("gemma" , "GemmaConfig" ),
74+ ("glm" , "GLMConfig" ),
75+ ("gpt" , "GPTConfig" ),
76+ ("gptj" , "GPTJConfig" ),
77+ ("jamba" , "JambaConfig" ),
78+ ("layoutlm" , "LayoutLMConfig" ),
79+ ("layoutlmv2" , "LayoutLMv2Config" ),
80+ ("layoutxlm" , "LayoutXLMConfig" ),
81+ ("llama" , "LlamaConfig" ),
82+ ("luke" , "LukeConfig" ),
83+ ("mamba" , "MambaConfig" ),
84+ ("mbart" , "MBartConfig" ),
85+ ("megatronbert" , "MegatronBertConfig" ),
86+ ("minigpt4" , "MiniGPT4Config" ),
87+ ("mistral" , "MistralConfig" ),
88+ ("mixtral" , "MixtralConfig" ),
89+ ("mobilebert" , "MobileBertConfig" ),
90+ ("mpnet" , "MPNetConfig" ),
91+ ("mt5" , "MT5Config" ),
92+ ("nezha" , "NeZhaConfig" ),
93+ ("nystromformer" , "NystromformerConfig" ),
94+ ("opt" , "OPTConfig" ),
95+ ("pegasus" , "PegasusConfig" ),
96+ ("ppminilm" , "PPMiniLMConfig" ),
97+ ("prophetnet" , "ProphetNetConfig" ),
98+ ("qwen" , "QWenConfig" ),
99+ ("qwen2" , "Qwen2Config" ),
100+ ("qwen2_moe" , "Qwen2MoeConfig" ),
101+ ("reformer" , "ReformerConfig" ),
102+ ("rembert" , "RemBertConfig" ),
103+ ("roberta" , "RobertaConfig" ),
104+ ("roformer" , "RoFormerConfig" ),
105+ ("roformerv2" , "RoFormerv2Config" ),
106+ ("rw" , "RWConfig" ),
107+ ("skep" , "SkepConfig" ),
108+ ("speecht5" , "SpeechT5Config" ),
109+ ("squeezebert" , "SqueezeBertConfig" ),
110+ ("t5" , "T5Config" ),
111+ ("tinybert" , "TinyBertConfig" ),
112+ ("unified_transformer" , "UnifiedTransformerConfig" ),
113+ ("unimo" , "UNIMOConfig" ),
114+ ("visualglm" , "VisualGLMConfig" ),
115+ ("xlm" , "XLMConfig" ),
116+ ("xlnet" , "XLNetConfig" ),
117+ ("yuan" , "YuanConfig" ),
118+ ]
119+ )
+
+
+MODEL_NAMES_MAPPING = OrderedDict(
+    # Base model mapping
+    [
+        ("albert", "Albert"),
+        ("artist", "Artist"),
+        ("bart", "Bart"),
+        ("bert", "Bert"),
+        ("bigbird", "BigBird"),
+        ("bit", "Bit"),
+        ("blenderbot", "Blenderbot"),
+        ("blenderbot_small", "BlenderbotSmall"),
+        ("blip", "Blip"),
+        ("blip2", "Blip2"),
+        ("bloom", "Bloom"),
+        ("chatglm", "ChatGLM"),
+        ("chatglm_v2", "ChatGLMv2"),
+        ("chinesebert", "ChineseBert"),
+        ("chineseclip", "ChineseCLIPText"),
+        ("clap", "CLAP"),
+        ("clip", "CLIP"),
+        ("codegen", "CodeGen"),
+        ("convbert", "ConvBert"),
+        ("ctrl", "CTRL"),
+        ("dallebart", "DalleBart"),
+        ("deberta", "Deberta"),
+        ("debertav2", "DebertaV2"),
+        ("distilbert", "DistilBert"),
+        ("dpt", "DPT"),
+        ("electra", "Electra"),
+        ("ernie", "Ernie"),
+        ("ernie_code", "ErnieCode"),
+        ("ernie_ctm", "ErnieCtm"),
+        ("ernie_doc", "ErnieDoc"),
+        ("ernie_gram", "ErnieGram"),
+        ("ernie_layout", "ErnieLayout"),
+        ("ernie_m", "ErnieM"),
+        ("ernie_vil", "ErnieViL"),
+        ("fnet", "FNet"),
+        ("funnel", "Funnel"),
+        ("gau_alpha", "GAUAlpha"),
+        ("gemma", "Gemma"),
+        ("glm", "GLM"),
+        ("gpt", "GPT"),
+        ("gptj", "GPTJ"),
+        ("jamba", "Jamba"),
+        ("layoutlm", "LayoutLM"),
+        ("layoutlmv2", "LayoutLMv2"),
+        ("layoutxlm", "LayoutXLM"),
+        ("llama", "Llama"),
+        ("luke", "Luke"),
+        ("mamba", "Mamba"),
+        ("mbart", "MBart"),
+        ("megatronbert", "MegatronBert"),
+        ("minigpt4", "MiniGPT4"),
+        ("mistral", "Mistral"),
+        ("mixtral", "Mixtral"),
+        ("mobilebert", "MobileBert"),
+        ("mpnet", "MPNet"),
+        ("mt5", "MT5"),
+        ("nezha", "NeZha"),
+        ("nystromformer", "Nystromformer"),
+        ("opt", "OPT"),
+        ("pegasus", "Pegasus"),
+        ("ppminilm", "PPMiniLM"),
+        ("prophetnet", "ProphetNet"),
+        ("qwen", "QWen"),
+        ("qwen2", "Qwen2"),
+        ("qwen2_moe", "Qwen2Moe"),
+        ("reformer", "Reformer"),
+        ("rembert", "RemBert"),
+        ("roberta", "Roberta"),
+        ("roformer", "RoFormer"),
+        ("roformerv2", "RoFormerv2"),
+        ("rw", "RW"),
+        ("skep", "Skep"),
+        ("speecht5", "SpeechT5"),
+        ("squeezebert", "SqueezeBert"),
+        ("t5", "T5"),
+        ("tinybert", "TinyBert"),
+        ("unified_transformer", "UnifiedTransformer"),
+        ("unimo", "UNIMO"),
+        ("visualglm", "VisualGLM"),
+        ("xlm", "XLM"),
+        ("xlnet", "XLNet"),
+        ("yuan", "Yuan"),
+    ]
+)
+
+
+def config_class_to_model_type(config):
+    """Converts a config class name to the corresponding model type."""
+    for key, cls in CONFIG_MAPPING_NAMES.items():
+        if cls == config:
+            return key
+    # if the key is not found, check the extra content registered at runtime
+    for key, cls in CONFIG_MAPPING._extra_content.items():
+        if cls.__name__ == config:
+            return key
+    return None
+
+
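A quick illustration of the reverse lookup above (a hypothetical usage sketch; note the argument is the class *name* as a string, not the class object):

    config_class_to_model_type("BertConfig")       # -> "bert"
    config_class_to_model_type("ChatGLMv2Config")  # -> "chatglm_v2"
    config_class_to_model_type("NotAConfig")       # -> None: unknown names fall through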
+class _LazyConfigMapping(OrderedDict):
+    """
+    A dictionary that lazily loads its values when they are requested.
+    """
+
+    def __init__(self, mapping):
+        self._mapping = mapping
+        self._extra_content = {}
+        self._modules = {}
+
+    def __getitem__(self, key):
+        if key in self._extra_content:
+            return self._extra_content[key]
+        if key not in self._mapping:
+            raise KeyError(key)
+        value = self._mapping[key]
+        module_name = model_type_to_module_name(key)
+        if module_name not in self._modules:
+            self._modules[module_name] = importlib.import_module(
+                f".{module_name}.configuration", "paddlenlp.transformers"
+            )
+        if hasattr(self._modules[module_name], value):
+            return getattr(self._modules[module_name], value)
+
+        # Some of the mappings have entries model_type -> config of another model type.
+        # In that case we try to grab the object at the top level.
+        transformers_module = importlib.import_module("paddlenlp")
+        return getattr(transformers_module, value)
+
+    def keys(self):
+        return list(self._mapping.keys()) + list(self._extra_content.keys())
+
+    def values(self):
+        return [self[k] for k in self._mapping.keys()] + list(self._extra_content.values())
+
+    def items(self):
+        return [(k, self[k]) for k in self._mapping.keys()] + list(self._extra_content.items())
+
+    def __iter__(self):
+        return iter(list(self._mapping.keys()) + list(self._extra_content.keys()))
+
+    def __contains__(self, item):
+        return item in self._mapping or item in self._extra_content
+
+    def register(self, key, value, exist_ok=False):
+        """
+        Register a new configuration in this mapping.
+        """
+        if key in self._mapping.keys() and not exist_ok:
+            raise ValueError(f"'{key}' is already used by a Transformers config, pick another name.")
+        self._extra_content[key] = value
+
+
+CONFIG_MAPPING = _LazyConfigMapping(CONFIG_MAPPING_NAMES)
+
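A minimal sketch of the laziness this buys (hypothetical usage; the submodule path follows the `importlib.import_module` call in `__getitem__` above):

    # Building CONFIG_MAPPING imports nothing; the first lookup of a key
    # triggers the import of paddlenlp.transformers.<module>.configuration.
    bert_config_cls = CONFIG_MAPPING["bert"]  # imports bert.configuration, returns BertConfig
    "llama" in CONFIG_MAPPING                 # True; membership checks import nothing
    CONFIG_MAPPING["unknown"]                 # raises KeyError without importing anything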

 def get_configurations() -> Dict[str, List[Type[PretrainedConfig]]]:
     """load the configurations of PretrainedConfig mapping: {<model-name>: [<class-name>, <class-name>, ...], }
@@ -64,6 +309,12 @@ def get_configurations() -> Dict[str, List[Type[PretrainedConfig]]]:
     return mappings


+def model_type_to_module_name(key):
+    """Converts a config key to the corresponding module name."""
+    key = key.replace("-", "_")
+    return key
+
+
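The only normalization today is hyphen to underscore, so a hypothetical dashed model type still resolves to an importable module name:

    model_type_to_module_name("ernie-m")  # -> "ernie_m" (hyphens are invalid in module names)
    model_type_to_module_name("llama")    # -> "llama" (returned unchanged)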

 class AutoConfig(PretrainedConfig):
     """
@@ -191,12 +442,29 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
             from_hf_hub=from_hf_hub,
             from_aistudio=from_aistudio,
         )
-        if config_file is not None and os.path.exists(config_file):
+        config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
+        if "model_type" in config_dict:
+            try:
+                config_class = CONFIG_MAPPING[config_dict["model_type"]]
+            except KeyError:
+                raise ValueError(
+                    f"The checkpoint you are trying to load has model type `{config_dict['model_type']}` "
+                    "but Transformers does not recognize this architecture. This could be because of an "
+                    "issue with the checkpoint, or because your version of Transformers is out of date."
+                )
+            return config_class.from_dict(config_dict, **unused_kwargs)
+        elif "model_type" not in config_dict and config_file is not None and os.path.exists(config_file):
             config_class = cls._get_config_class_from_config(pretrained_model_name_or_path, config_file)
             logger.info("We are using %s to load '%s'." % (config_class, pretrained_model_name_or_path))
             if config_class is cls:
                 return cls.from_file(config_file)
             return config_class.from_pretrained(config_file, *model_args, **kwargs)
+        elif config_file is None:
+            # Fallback: use pattern matching on the string.
+            # We go from longer names to shorter names to catch roberta before bert (for instance).
+            for pattern in sorted(CONFIG_MAPPING.keys(), key=len, reverse=True):
+                if pattern in str(pretrained_model_name_or_path):
+                    return CONFIG_MAPPING[pattern].from_dict(config_dict, **unused_kwargs)
         else:
             raise RuntimeError(
                 f"Can't load config for '{pretrained_model_name_or_path}'.\n"
@@ -205,3 +473,20 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
205473 "- or a correct model-identifier of community-contributed pretrained models,\n "
206474 "- or the correct path to a directory containing relevant config files.\n "
207475 )
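To summarize the dispatch order implemented above, a hedged sketch (the checkpoint names are illustrative, not guaranteed to exist):

    # 1. config.json carries "model_type" -> direct CONFIG_MAPPING lookup.
    AutoConfig.from_pretrained("bert-base-uncased")        # -> a BertConfig instance

    # 2. No "model_type", but a config file was resolved -> legacy resolution
    #    through cls._get_config_class_from_config.
    AutoConfig.from_pretrained("./legacy_checkpoint_dir")  # hypothetical local path

    # 3. No config file at all -> longest-first pattern match over the mapping
    #    keys, so "roberta" is tried before "bert" for "my-roberta-variant".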
+
+    @staticmethod
+    def register(model_type, config, exist_ok=False):
+        """
+        Register a new configuration for this class.
+
+        Args:
+            model_type (`str`): The model type, like "bert" or "gpt".
+            config ([`PretrainedConfig`]): The config to register.
+        """
+        if issubclass(config, PretrainedConfig) and config.model_type != model_type:
+            raise ValueError(
+                "The config you are passing has a `model_type` attribute that is not consistent with the model type "
+                f"you passed (config has {config.model_type} and you passed {model_type}). Fix one of those so they "
+                "match!"
+            )
+        CONFIG_MAPPING.register(model_type, config, exist_ok=exist_ok)
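A hypothetical end-to-end registration sketch (assuming `PretrainedConfig` is importable from `paddlenlp.transformers`):

    from paddlenlp.transformers import PretrainedConfig

    class MyConfig(PretrainedConfig):
        model_type = "my_model"

    AutoConfig.register("my_model", MyConfig)      # ok: model_type matches the key
    AutoConfig.register("not_my_model", MyConfig)  # ValueError: inconsistent model_type

Overriding one of the built-in keys (for example "bert") additionally requires exist_ok=True, which is enforced by _LazyConfigMapping.register.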