Skip to content

Commit 035f2af

Browse files
authored
Merge branch 'AUTOMATIC1111:master' into fix/alternating-words-emphasis
2 parents 7e45fba + 45a8b75 commit 035f2af

File tree

6 files changed

+46
-7
lines changed

6 files changed

+46
-7
lines changed

javascript/imageviewer.js

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ function showGalleryImage() {
151151
e.addEventListener('mousedown', function (evt) {
152152
if(!opts.js_modal_lightbox || evt.button != 0) return;
153153
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
154+
evt.preventDefault()
154155
showModal(evt)
155156
}, true);
156157
}

modules/extras.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
import sys
55
import traceback
6+
import shutil
67

78
import numpy as np
89
from PIL import Image
@@ -248,7 +249,32 @@ def run_pnginfo(image):
248249
return '', geninfo, info
249250

250251

251-
def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
252+
def create_config(ckpt_result, config_source, a, b, c):
253+
def config(x):
254+
return sd_models.find_checkpoint_config(x) if x else None
255+
256+
if config_source == 0:
257+
cfg = config(a) or config(b) or config(c)
258+
elif config_source == 1:
259+
cfg = config(b)
260+
elif config_source == 2:
261+
cfg = config(c)
262+
else:
263+
cfg = None
264+
265+
if cfg is None:
266+
return
267+
268+
filename, _ = os.path.splitext(ckpt_result)
269+
checkpoint_filename = filename + ".yaml"
270+
271+
print("Copying config:")
272+
print(" from:", cfg)
273+
print(" to:", checkpoint_filename)
274+
shutil.copyfile(cfg, checkpoint_filename)
275+
276+
277+
def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
252278
shared.state.begin()
253279
shared.state.job = 'model-merge'
254280

@@ -356,6 +382,8 @@ def add_difference(theta0, theta1_2_diff, alpha):
356382

357383
sd_models.list_models()
358384

385+
create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
386+
359387
print("Checkpoint saved.")
360388
shared.state.textinfo = "Checkpoint saved to " + output_modelname
361389
shared.state.end()

modules/scripts.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def basedir():
152152

153153
scripts_data = []
154154
ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])
155-
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir"])
155+
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
156156

157157

158158
def list_scripts(scriptdirname, extension):
@@ -206,7 +206,7 @@ def load_scripts():
206206

207207
for key, script_class in module.__dict__.items():
208208
if type(script_class) == type and issubclass(script_class, Script):
209-
scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir))
209+
scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
210210

211211
except Exception:
212212
print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
@@ -241,7 +241,7 @@ def initialize_scripts(self, is_img2img):
241241
self.alwayson_scripts.clear()
242242
self.selectable_scripts.clear()
243243

244-
for script_class, path, basedir in scripts_data:
244+
for script_class, path, basedir, script_module in scripts_data:
245245
script = script_class()
246246
script.filename = path
247247
script.is_txt2img = not is_img2img

modules/sd_models.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,10 +333,14 @@ def load_model(checkpoint_info=None):
333333

334334
timer = Timer()
335335

336+
sd_model = None
336337
try:
337338
with sd_disable_initialization.DisableInitialization():
338339
sd_model = instantiate_from_config(sd_config.model)
339340
except Exception as e:
341+
pass
342+
343+
if sd_model is None:
340344
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
341345
sd_model = instantiate_from_config(sd_config.model)
342346

modules/textual_inversion/textual_inversion.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import html
1010
import datetime
1111
import csv
12+
import safetensors.torch
1213

1314
from PIL import Image, PngImagePlugin
1415

@@ -150,6 +151,8 @@ def load_from_file(self, path, filename):
150151
name = data.get('name', name)
151152
elif ext in ['.BIN', '.PT']:
152153
data = torch.load(path, map_location="cpu")
154+
elif ext in ['.SAFETENSORS']:
155+
data = safetensors.torch.load_file(path, device="cpu")
153156
else:
154157
return
155158

modules/ui.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1129,7 +1129,7 @@ def update_orig(image, state):
11291129
with gr.Column(variant='panel'):
11301130
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
11311131

1132-
with gr.Row():
1132+
with FormRow():
11331133
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
11341134
create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
11351135

@@ -1143,11 +1143,13 @@ def update_orig(image, state):
11431143
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
11441144
interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
11451145

1146-
with gr.Row():
1146+
with FormRow():
11471147
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
11481148
save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
11491149

1150-
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
1150+
config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
1151+
1152+
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
11511153

11521154
with gr.Column(variant='panel'):
11531155
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
@@ -1703,6 +1705,7 @@ def modelmerger(*args):
17031705
save_as_half,
17041706
custom_name,
17051707
checkpoint_format,
1708+
config_source,
17061709
],
17071710
outputs=[
17081711
submit_result,

0 commit comments

Comments
 (0)