1717from modules import shared , devices , sd_models , errors , scripts , sd_hijack
1818from modules .textual_inversion .textual_inversion import Embedding
1919
20+ from lora_logger import logger
21+
2022module_types = [
2123 network_lora .ModuleTypeLora (),
2224 network_hada .ModuleTypeHada (),
@@ -206,7 +208,40 @@ def load_network(name, network_on_disk):
206208
207209 net .modules [key ] = net_module
208210
209- net .bundle_embeddings = bundle_embeddings
211+ embeddings = {}
212+ for emb_name , data in bundle_embeddings .items ():
213+ # textual inversion embeddings
214+ if 'string_to_param' in data :
215+ param_dict = data ['string_to_param' ]
216+ param_dict = getattr (param_dict , '_parameters' , param_dict ) # fix for torch 1.12.1 loading saved file from torch 1.11
217+ assert len (param_dict ) == 1 , 'embedding file has multiple terms in it'
218+ emb = next (iter (param_dict .items ()))[1 ]
219+ vec = emb .detach ().to (devices .device , dtype = torch .float32 )
220+ shape = vec .shape [- 1 ]
221+ vectors = vec .shape [0 ]
222+ elif type (data ) == dict and 'clip_g' in data and 'clip_l' in data : # SDXL embedding
223+ vec = {k : v .detach ().to (devices .device , dtype = torch .float32 ) for k , v in data .items ()}
224+ shape = data ['clip_g' ].shape [- 1 ] + data ['clip_l' ].shape [- 1 ]
225+ vectors = data ['clip_g' ].shape [0 ]
226+ elif type (data ) == dict and type (next (iter (data .values ()))) == torch .Tensor : # diffuser concepts
227+ assert len (data .keys ()) == 1 , 'embedding file has multiple terms in it'
228+
229+ emb = next (iter (data .values ()))
230+ if len (emb .shape ) == 1 :
231+ emb = emb .unsqueeze (0 )
232+ vec = emb .detach ().to (devices .device , dtype = torch .float32 )
233+ shape = vec .shape [- 1 ]
234+ vectors = vec .shape [0 ]
235+ else :
236+ raise Exception (f"Couldn't identify { emb_name } in lora: { name } as neither textual inversion embedding nor diffuser concept." )
237+
238+ embedding = Embedding (vec , emb_name )
239+ embedding .vectors = vectors
240+ embedding .shape = shape
241+ embedding .loaded = None
242+ embeddings [emb_name ] = embedding
243+
244+ net .bundle_embeddings = embeddings
210245
211246 if keys_failed_to_match :
212247 logging .debug (f"Network { network_on_disk .filename } didn't match keys: { keys_failed_to_match } " )
@@ -229,8 +264,9 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
229264 for net in loaded_networks :
230265 if net .name in names :
231266 already_loaded [net .name ] = net
232- for emb_name in net .bundle_embeddings :
233- emb_db .register_embedding_by_name (None , shared .sd_model , emb_name )
267+ for emb_name , embedding in net .bundle_embeddings .items ():
268+ if embedding .loaded :
269+ emb_db .register_embedding_by_name (None , shared .sd_model , emb_name )
234270
235271 loaded_networks .clear ()
236272
@@ -273,37 +309,17 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
273309 net .dyn_dim = dyn_dims [i ] if dyn_dims else 1.0
274310 loaded_networks .append (net )
275311
276- for emb_name , data in net .bundle_embeddings .items ():
277- # textual inversion embeddings
278- if 'string_to_param' in data :
279- param_dict = data ['string_to_param' ]
280- param_dict = getattr (param_dict , '_parameters' , param_dict ) # fix for torch 1.12.1 loading saved file from torch 1.11
281- assert len (param_dict ) == 1 , 'embedding file has multiple terms in it'
282- emb = next (iter (param_dict .items ()))[1 ]
283- vec = emb .detach ().to (devices .device , dtype = torch .float32 )
284- shape = vec .shape [- 1 ]
285- vectors = vec .shape [0 ]
286- elif type (data ) == dict and 'clip_g' in data and 'clip_l' in data : # SDXL embedding
287- vec = {k : v .detach ().to (devices .device , dtype = torch .float32 ) for k , v in data .items ()}
288- shape = data ['clip_g' ].shape [- 1 ] + data ['clip_l' ].shape [- 1 ]
289- vectors = data ['clip_g' ].shape [0 ]
290- elif type (data ) == dict and type (next (iter (data .values ()))) == torch .Tensor : # diffuser concepts
291- assert len (data .keys ()) == 1 , 'embedding file has multiple terms in it'
292-
293- emb = next (iter (data .values ()))
294- if len (emb .shape ) == 1 :
295- emb = emb .unsqueeze (0 )
296- vec = emb .detach ().to (devices .device , dtype = torch .float32 )
297- shape = vec .shape [- 1 ]
298- vectors = vec .shape [0 ]
299- else :
300- raise Exception (f"Couldn't identify { emb_name } in lora: { name } as neither textual inversion embedding nor diffuser concept." )
301-
302- embedding = Embedding (vec , emb_name )
303- embedding .vectors = vectors
304- embedding .shape = shape
312+ for emb_name , embedding in net .bundle_embeddings .items ():
313+ if embedding .loaded is None and emb_name in emb_db .word_embeddings :
314+ logger .warning (
315+ f'Skip bundle embedding: "{ emb_name } "'
316+ ' as it was already loaded from embeddings folder'
317+ )
318+ continue
305319
320+ embedding .loaded = False
306321 if emb_db .expected_shape == - 1 or emb_db .expected_shape == embedding .shape :
322+ embedding .loaded = True
307323 emb_db .register_embedding (embedding , shared .sd_model )
308324 else :
309325 emb_db .skipped_embeddings [name ] = embedding
0 commit comments