22"""Implementation of model loader service."""
33
44from pathlib import Path
5- from typing import Callable , Optional , Type
5+ from typing import Any , Callable , Optional
66
77from picklescan .scanner import scan_file_path
88from safetensors .torch import load_file as safetensors_load_file
1111from invokeai .app .services .config import InvokeAIAppConfig
1212from invokeai .app .services .invoker import Invoker
1313from invokeai .app .services .model_load .model_load_base import ModelLoadServiceBase
-from invokeai.backend.model_manager.config import AnyModelConfig
-from invokeai.backend.model_manager.load import (
-    LoadedModel,
-    LoadedModelWithoutConfig,
-    ModelLoaderRegistry,
-    ModelLoaderRegistryBase,
-)
-from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
-from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
+from invokeai.backend.model_manager.config import AnyModelConfig, Diffusers_Config_Base
+from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
+from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache, get_model_cache_key
+from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger


class ModelLoadService(ModelLoadServiceBase):
-    """Wrapper around ModelLoaderRegistry."""
+    """Model loading service using config-based loading."""

    def __init__(
        self,
        app_config: InvokeAIAppConfig,
        ram_cache: ModelCache,
-        registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry,
    ):
        """Initialize the model load service."""
        logger = InvokeAILogger.get_logger(self.__class__.__name__)
        logger.setLevel(app_config.log_level.upper())
        self._logger = logger
        self._app_config = app_config
        self._ram_cache = ram_cache
-        self._registry = registry

    def start(self, invoker: Invoker) -> None:
        self._invoker = invoker
@@ -63,18 +56,49 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
        if hasattr(self, "_invoker"):
            self._invoker.services.events.emit_model_load_started(model_config, submodel_type)

-        implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type)  # type: ignore
-        loaded_model: LoadedModel = implementation(
-            app_config=self._app_config,
-            logger=self._logger,
-            ram_cache=self._ram_cache,
-        ).load_model(model_config, submodel_type)
+        loaded_model = self._load_model_from_config(model_config, submodel_type)

        if hasattr(self, "_invoker"):
            self._invoker.services.events.emit_model_load_complete(model_config, submodel_type)

        return loaded_model

+    def _load_model_from_config(
+        self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
+    ) -> LoadedModel:
+        """Load a model using the config's load_model method."""
+        model_path = Path(model_config.path)
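+        # stats_name identifies this model in the cache's usage statistics.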
+        stats_name = ":".join([model_config.base, model_config.type, model_config.name, (submodel_type or "")])
+
+        # Check if model is already in cache
+        try:
+            cache_record = self._ram_cache.get(key=get_model_cache_key(model_config.key, submodel_type), stats_name=stats_name)
+            return LoadedModel(config=model_config, cache_record=cache_record, cache=self._ram_cache)
+        except IndexError:
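+            # Cache miss; fall through and load the model from disk.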
+            pass
+
+        # Make room in cache
+        variant = model_config.repo_variant if isinstance(model_config, Diffusers_Config_Base) else None
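+        # The repo variant (e.g. "fp16") determines which weight files count toward the size estimate.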
+        model_size = calc_model_size_by_fs(
+            model_path=model_path,
+            subfolder=submodel_type.value if submodel_type else None,
+            variant=variant,
+        )
+        self._ram_cache.make_room(model_size)
+
+        # Load the model using the config's load_model method
+        raw_model = model_config.load_model(submodel_type)
+
+        # Cache the loaded model
+        self._ram_cache.put(
+            get_model_cache_key(model_config.key, submodel_type),
+            model=raw_model,
+        )
+
+        # Retrieve from cache and return
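+        # get() wraps the model in the CacheRecord that LoadedModel expects,
+        # so fetch the entry just stored rather than using raw_model directly.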
+        cache_record = self._ram_cache.get(key=get_model_cache_key(model_config.key, submodel_type), stats_name=stats_name)
+        return LoadedModel(config=model_config, cache_record=cache_record, cache=self._ram_cache)
+
    def load_model_from_path(
        self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
    ) -> LoadedModelWithoutConfig:
@@ -107,12 +131,31 @@ def torch_load_file(checkpoint: Path) -> AnyModel:
            return result

        def diffusers_load_directory(directory: Path) -> AnyModel:
-            load_class = GenericDiffusersLoader(
-                app_config=self._app_config,
-                logger=self._logger,
-                ram_cache=self._ram_cache,
-                convert_cache=self.convert_cache,
-            ).get_hf_load_class(directory)
+            import sys
+
+            from diffusers.configuration_utils import ConfigMixin
+
+            class ConfigLoader(ConfigMixin):
+                """Subclass of ConfigMixin for loading diffusers configuration files."""
+
+                @classmethod
+                def load_config(cls, *args: Any, **kwargs: Any) -> dict[str, Any]:  # type: ignore
+                    """Load a diffusers ConfigMixin configuration."""
+                    cls.config_name = kwargs.pop("config_name")
+                    return super().load_config(*args, **kwargs)  # type: ignore
+
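+            # Read config.json from the model directory to discover which class to load.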
+            config = ConfigLoader.load_config(directory, config_name="config.json")
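+            # Diffusers configs record the class under "_class_name"; transformers
+            # configs list it under "architectures".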
+            if class_name := config.get("_class_name"):
+                load_class = getattr(sys.modules["diffusers"], class_name)
+            elif class_name := config.get("architectures"):
+                load_class = getattr(sys.modules["transformers"], class_name[0])
+            else:
+                raise Exception("Unable to determine load class from config.json")
+
            return load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype())

        loader = loader or (