@@ -36,7 +36,7 @@ def __init__(
36
36
37
37
self .pruned_model = self .prune_model_input ()
38
38
self .draft_model = None
39
- if getattr (args , 'model_draft' , None ) :
39
+ if hasattr (args , 'model_draft' ) and args . model_draft :
40
40
dm_args = copy .deepcopy (args )
41
41
dm_args .model_draft = None
42
42
self .draft_model = ModelFactory (args .model_draft , dm_args , ignore_stderr = True ).create ()
@@ -63,7 +63,7 @@ def __init__(
63
63
def detect_model_model_type (
64
64
self ,
65
65
) -> Tuple [type [Union [Huggingface , Ollama , OCI , URL ]], Callable [[], Union [Huggingface , Ollama , OCI , URL ]]]:
66
- for prefix in ["huggingface://" , "hf ://" , "hf.co/" ]:
66
+ for prefix in ["huggingface://" , "hfq ://" , "hf.co/" ]:
67
67
if self .model .startswith (prefix ):
68
68
return Huggingface , self .create_huggingface
69
69
for prefix in ["modelscope://" , "ms://" ]:
@@ -137,7 +137,7 @@ def create_oci(self) -> OCI:
137
137
def create_url (self ) -> URL :
138
138
model = URL (self .pruned_model , self .store_path , urlparse (self .model ).scheme )
139
139
model .draft_model = self .draft_model
140
- if getattr (self , 'split_model' , None ):
140
+ if hasattr (self , 'split_model' ):
141
141
model .split_model = self .split_model
142
142
model .mnt_path = self .mnt_path
143
143
return model
0 commit comments