@@ -278,3 +278,65 @@ def get_processor_type_from_user_config(user_processor_type: Optional[Union[str,
278278 else :
279279 raise NotImplementedError (f"Unsupported processor type: { user_processor_type } " )
280280 return processor_type
281+
282+
def dowload_hf_model(repo_id, cache_dir=None, repo_type=None, revision=None):
    """Return a local snapshot directory for a Hugging Face Hub repo, downloading if needed.

    The local cache is inspected first: when the requested revision already has a
    snapshot directory under ``cache_dir`` it is returned without calling
    ``snapshot_download``. Otherwise the repo is fetched with
    ``huggingface_hub.snapshot_download`` using the same ``cache_dir``,
    ``repo_type`` and ``revision`` that were used for the cache lookup.

    Args:
        repo_id: Repository identifier on the Hub.
        cache_dir: Cache root directory. Defaults to ``HUGGINGFACE_HUB_CACHE``.
        repo_type: Repository type. Defaults to ``"model"``.
        revision: Branch name, tag, or commit hash. Defaults to ``DEFAULT_REVISION``.

    Returns:
        Path of the snapshot directory for the requested revision.
    """
    import os

    from huggingface_hub.constants import DEFAULT_REVISION, HUGGINGFACE_HUB_CACHE
    from huggingface_hub.file_download import REGEX_COMMIT_HASH, repo_folder_name

    if cache_dir is None:
        cache_dir = HUGGINGFACE_HUB_CACHE
    if revision is None:
        revision = DEFAULT_REVISION
    if repo_type is None:
        repo_type = "model"

    storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type))

    # Resolve the revision to a commit hash: either it already is one, or a
    # file under "refs/" maps the branch/tag name to the hash.
    commit_hash = None
    if REGEX_COMMIT_HASH.match(revision):
        commit_hash = revision
    else:
        ref_path = os.path.join(storage_folder, "refs", revision)
        if os.path.isfile(ref_path):
            with open(ref_path, encoding="utf-8") as f:
                # Strip so stray whitespace/newlines cannot corrupt the snapshot path.
                commit_hash = f.read().strip()

    if commit_hash:
        pointer_path = os.path.join(storage_folder, "snapshots", commit_hash)
        if os.path.isdir(pointer_path):
            return pointer_path

    # Cache miss (unknown ref, or a known hash whose snapshot directory is
    # absent): always download instead of silently returning None, and forward
    # the caller's options so the download matches the lookup above.
    from huggingface_hub import snapshot_download  # pragma: no cover

    return snapshot_download(  # pragma: no cover
        repo_id,
        repo_type=repo_type,
        revision=revision,
        cache_dir=cache_dir,
    )
315+
316+
def load_empty_model(pretrained_model_name_or_path, cls=None, **kwargs):
    """Instantiate a model from its config inside accelerate's ``init_empty_weights`` context.

    Args:
        pretrained_model_name_or_path: A local model directory, or an identifier
            that is resolved to a local directory through ``dowload_hf_model``.
        cls: Model class to instantiate. Defaults to ``AutoModelForCausalLM``.
        **kwargs: Forwarded to the config's ``from_pretrained`` call.

    Returns:
        The model in eval mode, with ``tie_weights()`` applied and a ``path``
        attribute set to ``pretrained_model_name_or_path``.
    """
    import os

    from accelerate import init_empty_weights
    from transformers import AutoConfig, AutoModelForCausalLM
    from transformers.models.auto.auto_factory import _BaseAutoModelClass

    if cls is None:
        cls = AutoModelForCausalLM

    # An existing directory is used as-is; anything else goes through the hub helper.
    if os.path.isdir(pretrained_model_name_or_path):  # pragma: no cover
        model_dir = pretrained_model_name_or_path
    else:
        model_dir = dowload_hf_model(pretrained_model_name_or_path)

    # Classes deriving directly from the auto-model base are built via
    # ``from_config``; other classes are constructed from their own config class.
    uses_auto_factory = cls.__base__ == _BaseAutoModelClass
    if uses_auto_factory:
        model_config = AutoConfig.from_pretrained(model_dir, **kwargs)
    else:  # pragma: no cover
        model_config = cls.config_class.from_pretrained(model_dir, **kwargs)

    with init_empty_weights():
        if uses_auto_factory:
            empty_model = cls.from_config(model_config)
        else:  # pragma: no cover
            empty_model = cls(model_config)

    empty_model.tie_weights()
    empty_model.eval()
    # Record the caller-supplied name/path (not the resolved directory) on the model.
    empty_model.path = pretrained_model_name_or_path
    return empty_model
0 commit comments