1313# limitations under the License.
1414from __future__ import annotations
1515
16- import json
1716import os
18- from functools import partial
1917
20- import numpy as np
2118import paddle
22- from tqdm import tqdm
2319
24- from paddlenlp .transformers import AutoConfig
2520from paddlenlp .transformers .model_utils import (
26- _add_variant ,
2721 dtype_guard ,
28- load_state_dict ,
22+ load_tp_checkpoint ,
2923 no_init_weights ,
3024)
3125from paddlenlp .transformers .utils import (
3226 ContextManagers ,
3327 is_paddle_support_lazy_init ,
3428 is_safetensors_available ,
35- paddlenlp_load ,
3629)
37- from paddlenlp .utils .env import (
38- PADDLE_WEIGHTS_INDEX_NAME ,
39- SAFE_MASTER_WEIGHTS_INDEX_NAME ,
40- SAFE_PEFT_WEIGHTS_INDEX_NAME ,
41- SAFE_WEIGHTS_INDEX_NAME ,
42- )
43-
44- try :
45- from paddlenlp .utils .safetensors import fast_load_file as safe_load_file
46- from paddlenlp .utils .safetensors import fast_safe_open as safe_open
47- except :
48- from safetensors import safe_open
49- from safetensors .numpy import load_file as safe_load_file
50-
51-
def load_sharded_checkpoint(folder, variant=None, return_numpy=False):
    """
    Load a checkpoint folder into a single state dict.

    Single-file checkpoints (``.pdparams`` or ``.safetensors``) are loaded
    directly; otherwise an index file is located and each shard is read one
    at a time, so only one shard is held in RAM at any moment.

    Args:
        folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint.
        variant (`str`): The model variant.
        return_numpy (`bool`): Whether to return numpy array instead of paddle tensor.
    """
    # Single-file pdparams checkpoints (plain or LoRA) win over everything else.
    for filename in ("model_state.pdparams", "lora_model_state.pdparams"):
        candidate = os.path.join(folder, _add_variant(filename, variant))
        if os.path.isfile(candidate):
            return paddle.load(candidate, return_numpy=return_numpy)

    single_safetensors = os.path.join(folder, _add_variant("model.safetensors", variant))
    if os.path.isfile(single_safetensors):
        loaded = safe_load_file(single_safetensors)
        if not return_numpy:
            # safe_load_file yields numpy arrays; convert in place without copying.
            for name in list(loaded.keys()):
                if isinstance(loaded[name], np.ndarray):
                    loaded[name] = paddle.Tensor(loaded.pop(name), zero_copy=True)
        return loaded

    # Sharded checkpoint: locate an index file, preferring safetensors formats.
    index_file = os.path.join(folder, _add_variant(PADDLE_WEIGHTS_INDEX_NAME, variant))
    safe_index_file = os.path.join(folder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant))
    safe_master_file = os.path.join(folder, _add_variant(SAFE_MASTER_WEIGHTS_INDEX_NAME, variant))
    safe_peft_file = os.path.join(folder, _add_variant(SAFE_PEFT_WEIGHTS_INDEX_NAME, variant))

    load_safe = True
    if os.path.isfile(safe_index_file):
        load_index = safe_index_file  # load safe due to preference
    elif os.path.isfile(safe_master_file):
        load_index = safe_master_file
    elif os.path.isfile(index_file):
        load_safe = False
        load_index = index_file
    elif os.path.isfile(safe_peft_file):
        load_index = safe_peft_file
    else:
        raise ValueError(f"Could not find {index_file} or {safe_index_file} or {safe_peft_file}")

    with open(load_index, "r", encoding="utf-8") as f:
        index = json.load(f)

    # The weight map points many tensors at the same shard; visit each shard once.
    shard_files = list(set(index["weight_map"].values()))
    if load_safe:
        loader = safe_load_file
    else:
        loader = partial(paddlenlp_load, map_location="np" if return_numpy else "cpu")

    merged = {}
    for shard_file in tqdm(shard_files):
        merged.update(loader(os.path.join(folder, shard_file)))

    if not return_numpy:
        # Promote any remaining numpy arrays to paddle tensors without copying.
        for name in list(merged.keys()):
            if isinstance(merged[name], np.ndarray):
                merged[name] = paddle.Tensor(merged.pop(name), zero_copy=True)

    return merged
123-
124-
def load_tp_checkpoint(folder, cls, config, return_numpy=False):
    """
    Load a tensor-parallel checkpoint efficiently on CPU, without initializing the model.

    Args:
        folder (`str` or `os.PathLike`): A path to a folder containing the model checkpoint.
        cls (`str`): The model class.
        config (`AutoConfig`): The model config.
        return_numpy (bool): Whether load the tp checkpoint as numpy.
    """
    # NOTE(review): the caller-supplied ``config`` is immediately replaced by the
    # one loaded from ``folder`` — confirm no caller relies on passing a modified config.
    config = AutoConfig.from_pretrained(folder)
    if config.tensor_parallel_degree == 1 or config.tensor_parallel_degree == -1:
        # No tensor parallelism: defer to the plain (possibly sharded) loader.
        return load_sharded_checkpoint(folder, return_numpy=return_numpy)

    # Zero-padded two-digit rank (e.g. ``model_state.tp00.pdparams``). Using the
    # ``:0>2d`` format spec instead of a hard-coded "0" prefix keeps the same
    # names for ranks 0-9 and produces correct names (``tp10``) for ranks >= 10.
    rank_model_path = os.path.join(folder, f"model_state.tp{config.tensor_parallel_rank:0>2d}.pdparams")
    model_path = os.path.join(folder, "model_state.pdparams")
    safe_model_path = os.path.join(folder, "model.safetensors")

    if os.path.exists(rank_model_path):
        # A checkpoint pre-split for this rank exists: load it directly
        # (paddle.load already honors return_numpy).
        return paddle.load(rank_model_path, return_numpy=return_numpy)

    if os.path.exists(model_path):
        # Full unsplit pdparams checkpoint: split it for this rank on the fly.
        state_dict = cls.convert_tensor_parallel(model_path, config)
    elif os.path.exists(safe_model_path):
        # Single safetensors file: derive the tp split actions from its keys, then load.
        with safe_open(safe_model_path, framework="np", device="cpu") as f:
            loaded_keys = f.keys()
        tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_keys)
        state_dict = load_state_dict(safe_model_path, tp_actions)
    else:
        # Sharded safetensors: resolve the shard list and merge shard by shard.
        _, resolved_sharded_files, sharded_metadata, _ = cls._resolve_model_file_path(
            pretrained_model_name_or_path=folder,
            use_safetensors=True,
        )
        if len(resolved_sharded_files) > 1:
            resolved_sharded_files = tqdm(resolved_sharded_files, desc="Loading checkpoint shards")
        loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
        tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_state_dict_keys, ignore_error=True)
        state_dict = {}
        for shard_file in resolved_sharded_files:
            shard_state_dict = load_state_dict(
                shard_file,
                tp_actions,
                loaded_state_dict_keys,
            )
            state_dict.update(shard_state_dict)

    if return_numpy:
        # Move any paddle tensors back to CPU numpy arrays.
        for k in list(state_dict.keys()):
            if not isinstance(state_dict[k], np.ndarray):
                state_dict[k] = state_dict.pop(k).cpu().numpy()
    return state_dict
17530
17631
177- def infererence_model_from_pretrained (cls , pretrained_model_name_or_path , args , kwargs ):
32+ def infererence_model_from_pretrained (cls , pretrained_model_name_or_path , args , kwargs , return_numpy = True ):
17833 r"""
17934 Instantiate a pretrained model configuration from a pre-trained model name or path.
18035 """
@@ -203,7 +58,7 @@ def infererence_model_from_pretrained(cls, pretrained_model_name_or_path, args,
20358 with ContextManagers (init_contexts ):
20459 model = cls (config )
20560
206- resolved_archive_file , resolved_sharded_files , sharded_metadata , is_sharded = cls ._resolve_model_file_path (
61+ resolved_archive_file , _ , _ , _ = cls ._resolve_model_file_path (
20762 pretrained_model_name_or_path ,
20863 cache_dir = cache_dir ,
20964 subfolder = subfolder ,
@@ -216,7 +71,7 @@ def infererence_model_from_pretrained(cls, pretrained_model_name_or_path, args,
21671 )
21772
21873 model_path = os .path .dirname (resolved_archive_file )
219- state_dict = load_tp_checkpoint (model_path , cls , config , return_numpy = True )
74+ state_dict = load_tp_checkpoint (model_path , cls , config , return_numpy = return_numpy )
22075 model .set_state_dict (state_dict )
22176
22277 return model
0 commit comments