@@ -66,63 +66,64 @@ def run_torch_compile(model, backend='openvino', dynamic=None, options=None, chi
 def create_text_gen_model(model_path, device, memory_data_collector, **kwargs):
     model_path = Path(model_path)
-    from_pretrain_time = 0
-    if model_path.exists():
-        if model_path.is_dir() and len(os.listdir(model_path)) != 0:
-            log.info(f'Load text model from model path:{model_path}')
-            default_model_type = DEFAULT_MODEL_CLASSES[kwargs['use_case']]
-            model_type = kwargs.get('model_type', default_model_type)
-            model_class = PT_MODEL_CLASSES_MAPPING.get(model_type, PT_MODEL_CLASSES_MAPPING[default_model_type])
-            token_class = TOKENIZE_CLASSES_MAPPING.get(model_type, TOKENIZE_CLASSES_MAPPING[default_model_type])
-            if kwargs.get("mem_consumption"):
-                memory_data_collector.start()
-            start = time.perf_counter()
-            trust_remote_code = False
-            try:
-                model = model_class.from_pretrained(model_path, trust_remote_code=trust_remote_code)
-            except Exception:
-                start = time.perf_counter()
-                trust_remote_code = True
-                model = model_class.from_pretrained(model_path, trust_remote_code=trust_remote_code)
-            tokenizer = token_class.from_pretrained(model_path, trust_remote_code=trust_remote_code)
-            end = time.perf_counter()
-            from_pretrain_time = end - start
-            if kwargs.get("mem_consumption"):
-                memory_data_collector.stop_and_collect_data('from_pretrained_phase')
-                memory_data_collector.log_data(compilation_phase=True)
-        else:
-            raise RuntimeError(f'==Failure ==: model path:{model_path} is not directory or directory is empty')
-    else:
+    is_gguf_model = model_path.suffix == '.gguf'
+    if not model_path.exists():
         raise RuntimeError(f'==Failure ==: model path:{model_path} is not exist')
+    if not is_gguf_model and not (model_path.is_dir() and len(os.listdir(model_path)) != 0):
+        raise RuntimeError(f'==Failure ==: model path:{model_path} is not directory or directory is empty')
+    if not device:
+        raise RuntimeError('==Failure ==: no device to load')
+
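A note on the suffix check above: a substring test such as `model_path.suffix in '.gguf'` would also match the empty suffix of a directory path, so every model directory would be misclassified as a GGUF file and the directory-emptiness check would be skipped; the equality comparison avoids that. A minimal illustration (paths are hypothetical):

```python
from pathlib import Path

# '' in '.gguf' is True, so a substring test misclassifies directories,
# whose suffix is the empty string.
assert '' in '.gguf'

# Equality only matches actual .gguf files (hypothetical paths).
assert Path('models/llama-7b').suffix != '.gguf'              # model directory
assert Path('models/llama-7b.Q4_K_M.gguf').suffix == '.gguf'  # GGUF checkpoint
```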
+    log.info(f'Load text model from model path:{model_path}')
+    default_model_type = DEFAULT_MODEL_CLASSES[kwargs['use_case']]
+    model_type = kwargs.get('model_type', default_model_type)
+    model_class = PT_MODEL_CLASSES_MAPPING.get(model_type, PT_MODEL_CLASSES_MAPPING[default_model_type])
+    token_class = TOKENIZE_CLASSES_MAPPING.get(model_type, TOKENIZE_CLASSES_MAPPING[default_model_type])
+    if kwargs.get("mem_consumption"):
+        memory_data_collector.start()
+    start = time.perf_counter()
+    load_model_kwargs = {'trust_remote_code': False}
+    if is_gguf_model:
+        load_model_kwargs |= {'gguf_file': str(model_path)}
+        model_path = model_path.parent
+    try:
+        model = model_class.from_pretrained(model_path, **load_model_kwargs)
+    except Exception:
+        start = time.perf_counter()
+        load_model_kwargs['trust_remote_code'] = True
+        model = model_class.from_pretrained(model_path, **load_model_kwargs)
+    tokenizer = token_class.from_pretrained(model_path, **load_model_kwargs)
+    end = time.perf_counter()
+    from_pretrain_time = end - start
+    if kwargs.get("mem_consumption"):
+        memory_data_collector.stop_and_collect_data('from_pretrained_phase')
+        memory_data_collector.log_data(compilation_phase=True)

     log.info(f'model path:{model_path}, from pretrained time: {from_pretrain_time:.2f}s')
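For context on the `gguf_file` branch: recent `transformers` releases (with the `gguf` package installed) can load GGUF checkpoints by passing the containing directory as the model path and naming the file via the `gguf_file` keyword, dequantizing the weights on load. That is why the code swaps `model_path` for its parent before calling `from_pretrained`. A minimal sketch of the same pattern, assuming `AutoModelForCausalLM`/`AutoTokenizer` and a hypothetical local file:

```python
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = Path('models/llama-7b.Q4_K_M.gguf')  # hypothetical checkpoint

load_model_kwargs = {'trust_remote_code': False}
if model_path.suffix == '.gguf':
    # Name the GGUF file explicitly, then point from_pretrained at its directory.
    # The benchmark passes the full path; transformers examples typically pass
    # just the file name.
    load_model_kwargs |= {'gguf_file': str(model_path)}
    model_path = model_path.parent

model = AutoModelForCausalLM.from_pretrained(model_path, **load_model_kwargs)
tokenizer = AutoTokenizer.from_pretrained(model_path, **load_model_kwargs)
```

The try/except in the diff then retries once with `trust_remote_code=True` and restarts the timer, so models that ship custom code still load and the failed first attempt is excluded from the reported from_pretrained time.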
-    if device is not None:
-        gptjfclm = 'transformers.models.gptj.modeling_gptj.GPTJForCausalLM'
-        lfclm = 'transformers.models.llama.modeling_llama.LlamaForCausalLM'
-        bfclm = 'transformers.models.bloom.modeling_bloom.BloomForCausalLM'
-        gpt2lmhm = 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'
-        gptneoxclm = 'transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM'
-        chatglmfcg = 'transformers_modules.pytorch_original.modeling_chatglm.ChatGLMForConditionalGeneration'
-        real_base_model_name = str(type(model)).lower()
-        log.info(f'Real base model={real_base_model_name}')
-        # bfclm will trigger a crash in generate.
+    gptjfclm = 'transformers.models.gptj.modeling_gptj.GPTJForCausalLM'
+    lfclm = 'transformers.models.llama.modeling_llama.LlamaForCausalLM'
+    bfclm = 'transformers.models.bloom.modeling_bloom.BloomForCausalLM'
+    gpt2lmhm = 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'
+    gptneoxclm = 'transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM'
+    chatglmfcg = 'transformers_modules.pytorch_original.modeling_chatglm.ChatGLMForConditionalGeneration'
+    real_base_model_name = str(type(model)).lower()
+    log.info(f'Real base model={real_base_model_name}')
+    # bfclm will trigger a crash in generate.
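One caveat worth flagging in this base-model detection: `real_base_model_name` is lowercased, while the class-path constants keep their original capitalization, so the `any(...)` membership test a few lines below can never succeed as written and the `set_bf16` branch is unreachable; lowercasing the constants (or skipping `.lower()` on the name) would restore the intended match. A self-contained demonstration:

```python
# Stand-in for str(type(model)) on a loaded LlamaForCausalLM.
real_base_model_name = "<class 'transformers.models.llama.modeling_llama.LlamaForCausalLM'>".lower()

lfclm = 'transformers.models.llama.modeling_llama.LlamaForCausalLM'
print(lfclm in real_base_model_name)          # False: the constant keeps its capitals
print(lfclm.lower() in real_base_model_name)  # True: compare both sides lowercased
```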
-        # If the device is set to GPU, substitute it with 'cuda' so it is accepted by PyTorch
-        if device.upper() == 'GPU':
-            device = torch.device('cuda') if torch.cuda.is_available() else log.info('CUDA device is unavailable')
-        else:
-            device = torch.device(device.lower())
-        log.info(f'Torch device was set to: {device}')
+    # If the device is set to GPU, substitute it with 'cuda' so it is accepted by PyTorch
+    if device.upper() == 'GPU':
+        device = torch.device('cuda') if torch.cuda.is_available() else log.info('CUDA device is unavailable')
+    else:
+        device = torch.device(device.lower())
+    log.info(f'Torch device was set to: {device}')
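A subtlety in the GPU branch: the conditional expression assigns the return value of `log.info(...)`, which is `None`, whenever CUDA is unavailable, so the later `model.to(device)` receives `None` and the move is effectively a silent no-op. A more defensive variant (a sketch with a hypothetical helper name, not what the benchmark does) would fall back to CPU explicitly:

```python
import logging

import torch

log = logging.getLogger(__name__)


def resolve_torch_device(device: str) -> torch.device:
    # Hypothetical helper: map the benchmark's device name to a torch.device.
    if device.upper() == 'GPU':
        # PyTorch expects 'cuda' rather than the OpenVINO-style 'GPU' name.
        if torch.cuda.is_available():
            return torch.device('cuda')
        log.info('CUDA device is unavailable, falling back to CPU')
        return torch.device('cpu')
    return torch.device(device.lower())
```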
-        if any(x in real_base_model_name for x in [gptjfclm, lfclm, bfclm, gpt2lmhm, gptneoxclm, chatglmfcg]):
-            model = set_bf16(model, device, **kwargs)
-        else:
-            if len(kwargs['config']) > 0 and kwargs['config'].get('PREC_BF16') and kwargs['config']['PREC_BF16'] is True:
-                log.info('Param [bf16/prec_bf16] will not work.')
-            model.to(device)
+    if any(x in real_base_model_name for x in [gptjfclm, lfclm, bfclm, gpt2lmhm, gptneoxclm, chatglmfcg]):
+        model = set_bf16(model, device, **kwargs)
     else:
-        raise RuntimeError('==Failure ==: no device to load')
+        if len(kwargs['config']) > 0 and kwargs['config'].get('PREC_BF16') and kwargs['config']['PREC_BF16'] is True:
+            log.info('Param [bf16/prec_bf16] will not work.')
+        model.to(device)
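A small readability note on the `PREC_BF16` condition: `dict.get` already returns `None` when the key is missing or the config is empty, so the three-clause test collapses to a single identity comparison with no change in behavior. A sketch with a hypothetical helper:

```python
def prec_bf16_requested(config: dict) -> bool:
    # Equivalent to the three-clause test above:
    # len(config) > 0 and config.get('PREC_BF16') and config['PREC_BF16'] is True
    return config.get('PREC_BF16') is True

assert prec_bf16_requested({'PREC_BF16': True})
assert not prec_bf16_requested({})
assert not prec_bf16_requested({'PREC_BF16': 1})  # truthy, but not the literal True
```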
     bench_hook = hook_common.get_bench_hook(kwargs['num_beams'], model)