Skip to content

Commit b691b1e

Browse files
authored
Fix .yaml config loading (#2224)
1 parent ecd4d28 commit b691b1e

File tree

1 file changed

+28
-3
lines changed

1 file changed

+28
-3
lines changed

backend/loader.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -290,14 +290,39 @@ def forge_loader(sd, additional_state_dicts=None):
290290
if component is not None:
291291
huggingface_components[component_name] = component
292292

293-
# Fix Huggingface prediction type using estimated config detection
293+
yaml_config = None
294+
yaml_config_prediction_type = None
295+
296+
try:
297+
import yaml
298+
from pathlib import Path
299+
config_filename = os.path.splitext(sd)[0] + '.yaml'
300+
if Path(config_filename).is_file():
301+
with open(config_filename, 'r') as stream:
302+
yaml_config = yaml.safe_load(stream)
303+
except ImportError:
304+
pass
305+
306+
# Fix Huggingface prediction type using .yaml config or estimated config detection
294307
prediction_types = {
295308
'EPS': 'epsilon',
296309
'V_PREDICTION': 'v_prediction',
297310
'EDM': 'edm',
298311
}
299-
if 'scheduler' in huggingface_components and hasattr(huggingface_components['scheduler'], 'config') and 'prediction_type' in huggingface_components['scheduler'].config:
300-
huggingface_components['scheduler'].config.prediction_type = prediction_types.get(estimated_config.model_type.name, huggingface_components['scheduler'].config.prediction_type)
312+
313+
has_prediction_type = 'scheduler' in huggingface_components and hasattr(huggingface_components['scheduler'], 'config') and 'prediction_type' in huggingface_components['scheduler'].config
314+
315+
if yaml_config is not None:
316+
model_config_params = config.get('model', {}).get('params', {})
317+
if "parameterization" in model_config_params:
318+
if model_config_params["parameterization"] == "v":
319+
yaml_config_prediction_type = 'v_prediction'
320+
321+
if has_prediction_type:
322+
if yaml_config_prediction_type is not None:
323+
huggingface_components['scheduler'].config.prediction_type = yaml_config_prediction_type
324+
else:
325+
huggingface_components['scheduler'].config.prediction_type = prediction_types.get(estimated_config.model_type.name, huggingface_components['scheduler'].config.prediction_type)
301326

302327
for M in possible_models:
303328
if any(isinstance(estimated_config, x) for x in M.matched_guesses):

0 commit comments

Comments
 (0)