Commit 81e94de

Add a warning when bundled embedding names conflict
Choose the standalone embedding (in the /embeddings folder) first
1 parent 2282eb8 commit 81e94de

File tree: 2 files changed (+81, -32 lines)
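With this change, a LoRA file that bundles an embedding whose name already exists as a standalone file in the embeddings folder no longer overrides it: the standalone copy keeps priority and a warning is printed instead. Given the log format introduced below, the console line would look roughly like the following (the embedding name "my_style" is a made-up example, and the WARNING tag is colored yellow in ANSI-capable terminals):

[lora]-WARNING: Skip bundle embedding: "my_style" as it was already loaded from embeddings folder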

extensions-builtin/Lora/lora_logger.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+import sys
+import copy
+import logging
+
+
+class ColoredFormatter(logging.Formatter):
+    COLORS = {
+        "DEBUG": "\033[0;36m", # CYAN
+        "INFO": "\033[0;32m", # GREEN
+        "WARNING": "\033[0;33m", # YELLOW
+        "ERROR": "\033[0;31m", # RED
+        "CRITICAL": "\033[0;37;41m", # WHITE ON RED
+        "RESET": "\033[0m", # RESET COLOR
+    }
+
+    def format(self, record):
+        colored_record = copy.copy(record)
+        levelname = colored_record.levelname
+        seq = self.COLORS.get(levelname, self.COLORS["RESET"])
+        colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}"
+        return super().format(colored_record)
+
+
+logger = logging.getLogger("lora")
+logger.propagate = False
+
+
+if not logger.handlers:
+    handler = logging.StreamHandler(sys.stdout)
+    handler.setFormatter(
+        ColoredFormatter("[%(name)s]-%(levelname)s: %(message)s")
+    )
+    logger.addHandler(handler)
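A minimal usage sketch for this module (illustrative only; within this commit the single call site is logger.warning in networks.py below, and the messages here are made up):

from lora_logger import logger

# Messages print to stdout as "[lora]-LEVEL: message"; the level name is
# wrapped in ANSI color codes, so it renders colored in capable terminals.
logger.info("parsed bundled embeddings")            # hypothetical message
logger.warning("embedding name conflict detected")  # hypothetical message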

extensions-builtin/Lora/networks.py

Lines changed: 48 additions & 32 deletions
@@ -17,6 +17,8 @@
 from modules import shared, devices, sd_models, errors, scripts, sd_hijack
 from modules.textual_inversion.textual_inversion import Embedding
 
+from lora_logger import logger
+
 module_types = [
     network_lora.ModuleTypeLora(),
     network_hada.ModuleTypeHada(),
@@ -206,7 +208,40 @@ def load_network(name, network_on_disk):
 
         net.modules[key] = net_module
 
-    net.bundle_embeddings = bundle_embeddings
+    embeddings = {}
+    for emb_name, data in bundle_embeddings.items():
+        # textual inversion embeddings
+        if 'string_to_param' in data:
+            param_dict = data['string_to_param']
+            param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
+            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
+            emb = next(iter(param_dict.items()))[1]
+            vec = emb.detach().to(devices.device, dtype=torch.float32)
+            shape = vec.shape[-1]
+            vectors = vec.shape[0]
+        elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
+            vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
+            shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
+            vectors = data['clip_g'].shape[0]
+        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
+            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+
+            emb = next(iter(data.values()))
+            if len(emb.shape) == 1:
+                emb = emb.unsqueeze(0)
+            vec = emb.detach().to(devices.device, dtype=torch.float32)
+            shape = vec.shape[-1]
+            vectors = vec.shape[0]
+        else:
+            raise Exception(f"Couldn't identify {emb_name} in lora: {name} as neither textual inversion embedding nor diffuser concept.")
+
+        embedding = Embedding(vec, emb_name)
+        embedding.vectors = vectors
+        embedding.shape = shape
+        embedding.loaded = None
+        embeddings[emb_name] = embedding
+
+    net.bundle_embeddings = embeddings
 
     if keys_failed_to_match:
         logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")
@@ -229,8 +264,9 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
     for net in loaded_networks:
         if net.name in names:
             already_loaded[net.name] = net
-            for emb_name in net.bundle_embeddings:
-                emb_db.register_embedding_by_name(None, shared.sd_model, emb_name)
+            for emb_name, embedding in net.bundle_embeddings.items():
+                if embedding.loaded:
+                    emb_db.register_embedding_by_name(None, shared.sd_model, emb_name)
 
     loaded_networks.clear()
 
@@ -273,37 +309,17 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
         net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
         loaded_networks.append(net)
 
-        for emb_name, data in net.bundle_embeddings.items():
-            # textual inversion embeddings
-            if 'string_to_param' in data:
-                param_dict = data['string_to_param']
-                param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
-                assert len(param_dict) == 1, 'embedding file has multiple terms in it'
-                emb = next(iter(param_dict.items()))[1]
-                vec = emb.detach().to(devices.device, dtype=torch.float32)
-                shape = vec.shape[-1]
-                vectors = vec.shape[0]
-            elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
-                vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
-                shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
-                vectors = data['clip_g'].shape[0]
-            elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
-                assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
-                emb = next(iter(data.values()))
-                if len(emb.shape) == 1:
-                    emb = emb.unsqueeze(0)
-                vec = emb.detach().to(devices.device, dtype=torch.float32)
-                shape = vec.shape[-1]
-                vectors = vec.shape[0]
-            else:
-                raise Exception(f"Couldn't identify {emb_name} in lora: {name} as neither textual inversion embedding nor diffuser concept.")
-
-            embedding = Embedding(vec, emb_name)
-            embedding.vectors = vectors
-            embedding.shape = shape
+        for emb_name, embedding in net.bundle_embeddings.items():
+            if embedding.loaded is None and emb_name in emb_db.word_embeddings:
+                logger.warning(
+                    f'Skip bundle embedding: "{emb_name}"'
+                    ' as it was already loaded from embeddings folder'
+                )
+                continue
 
+            embedding.loaded = False
             if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape:
+                embedding.loaded = True
                 emb_db.register_embedding(embedding, shared.sd_model)
             else:
                 emb_db.skipped_embeddings[name] = embedding
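Taken together, the hunks above give each bundled embedding a loaded state (None = never activated, e.g. skipped because of a name conflict; False = shape mismatch; True = registered) and make standalone embeddings from the /embeddings folder win over bundled ones with the same name. A self-contained sketch of that decision rule, using a hypothetical helper name and plain attributes rather than the webui's real classes:

from types import SimpleNamespace

def should_register_bundle_embedding(emb_name, embedding, word_embeddings, expected_shape):
    # Standalone embeddings already present in the database take priority;
    # the commit logs a warning and leaves `loaded` at None in this case.
    if embedding.loaded is None and emb_name in word_embeddings:
        return False
    # Otherwise register only if the vector width matches the loaded model
    # (expected_shape == -1 means "no constraint").
    embedding.loaded = expected_shape in (-1, embedding.shape)
    return embedding.loaded

bundled = SimpleNamespace(loaded=None, shape=768)
print(should_register_bundle_embedding("my_style", bundled, {"my_style": object()}, 768))  # False: the /embeddings copy wins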
