@@ -290,14 +290,39 @@ def forge_loader(sd, additional_state_dicts=None):
290
290
if component is not None :
291
291
huggingface_components [component_name ] = component
292
292
293
- # Fix Huggingface prediction type using estimated config detection
293
+ yaml_config = None
294
+ yaml_config_prediction_type = None
295
+
296
+ try :
297
+ import yaml
298
+ from pathlib import Path
299
+ config_filename = os .path .splitext (sd )[0 ] + '.yaml'
300
+ if Path (config_filename ).is_file ():
301
+ with open (config_filename , 'r' ) as stream :
302
+ yaml_config = yaml .safe_load (stream )
303
+ except ImportError :
304
+ pass
305
+
306
+ # Fix Huggingface prediction type using .yaml config or estimated config detection
294
307
prediction_types = {
295
308
'EPS' : 'epsilon' ,
296
309
'V_PREDICTION' : 'v_prediction' ,
297
310
'EDM' : 'edm' ,
298
311
}
299
- if 'scheduler' in huggingface_components and hasattr (huggingface_components ['scheduler' ], 'config' ) and 'prediction_type' in huggingface_components ['scheduler' ].config :
300
- huggingface_components ['scheduler' ].config .prediction_type = prediction_types .get (estimated_config .model_type .name , huggingface_components ['scheduler' ].config .prediction_type )
312
+
313
+ has_prediction_type = 'scheduler' in huggingface_components and hasattr (huggingface_components ['scheduler' ], 'config' ) and 'prediction_type' in huggingface_components ['scheduler' ].config
314
+
315
+ if yaml_config is not None :
316
+ model_config_params = config .get ('model' , {}).get ('params' , {})
317
+ if "parameterization" in model_config_params :
318
+ if model_config_params ["parameterization" ] == "v" :
319
+ yaml_config_prediction_type = 'v_prediction'
320
+
321
+ if has_prediction_type :
322
+ if yaml_config_prediction_type is not None :
323
+ huggingface_components ['scheduler' ].config .prediction_type = yaml_config_prediction_type
324
+ else :
325
+ huggingface_components ['scheduler' ].config .prediction_type = prediction_types .get (estimated_config .model_type .name , huggingface_components ['scheduler' ].config .prediction_type )
301
326
302
327
for M in possible_models :
303
328
if any (isinstance (estimated_config , x ) for x in M .matched_guesses ):
0 commit comments