Skip to content

Commit ed373ed

Browse files
authored
Add IPAdapter advanced weighting [Instant Style] (#2767)
* rename variable
* nit
* add transformer index
* nit
* API support
* Fix transformer_index alignment issue
* Add weight control panel
* fix sdxl
* free edit
* clear advanced weighting
* nit
* Add assertion
1 parent fa72f3e commit ed373ed

File tree

9 files changed

+636
-81
lines changed

9 files changed

+636
-81
lines changed

internal_controlnet/external_code.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,9 @@ class ControlNetUnit:
195195
# For T2IAdapter
196196
# - SD1.5: 5 weights (4 encoder block + 1 middle block)
197197
# - SDXL: 4 weights (3 encoder block + 1 middle block)
198+
# For IPAdapter
199+
# - SD1.5: 16 weights (6 input blocks + 9 output blocks + 1 middle block)
200+
# - SDXL: 11 weights (4 input blocks + 6 output blocks + 1 middle block)
198201
# Note1: Setting advanced weighting will disable `soft_injection`, i.e.
199202
# It is recommended to set ControlMode = BALANCED when using `advanced_weighting`.
200203
# Note2: The field `weight` is still used in some places, e.g. reference_only,

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@ opencv-python>=4.8.0
44
svglib
55
addict
66
yapf
7-
albumentations==1.4.3
7+
albumentations==1.4.3
8+
matplotlib

scripts/controlnet.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -927,7 +927,7 @@ def controlnet_main_entry(self, p):
927927
logger.info(f"unit_separate = {batch_option_uint_separate}, style_align = {batch_option_style_align}")
928928

929929
detected_maps = []
930-
forward_params = []
930+
forward_params: List[ControlParams] = []
931931
post_processors = []
932932

933933
# cache stuff
@@ -1177,10 +1177,21 @@ def recolor_intensity_post_processing(x, i):
11771177

11781178
for i, param in enumerate(forward_params):
11791179
if param.control_model_type == ControlModelType.IPAdapter:
1180+
if param.advanced_weighting is not None:
1181+
logger.info(f"IP-Adapter using advanced weighting {param.advanced_weighting}")
1182+
assert len(param.advanced_weighting) == global_state.get_sd_version().transformer_block_num
1183+
# Convert advanced weighting list to dict
1184+
weight = {
1185+
i: w
1186+
for i, w in enumerate(param.advanced_weighting)
1187+
if w > 0
1188+
}
1189+
else:
1190+
weight = param.weight
11801191
param.control_model.hook(
11811192
model=unet,
11821193
preprocessor_outputs=param.hint_cond,
1183-
weight=param.weight,
1194+
weight=weight,
11841195
dtype=torch.float32,
11851196
start=param.start_guidance_percent,
11861197
end=param.stop_guidance_percent
scripts/controlnet_ui/advanced_weight_control.py

Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
import gradio as gr
2+
import matplotlib.pyplot as plt
3+
from matplotlib.patches import Patch
4+
import io
5+
import json
6+
from PIL import Image
7+
from typing import List
8+
9+
from scripts.enums import StableDiffusionVersion
10+
from scripts.global_state import get_sd_version
11+
from scripts.ipadapter.weight import calc_weights
12+
13+
14+
INPUT_BLOCK_COLOR = "#61bdee"
15+
MIDDLE_BLOCK_COLOR = "#e2e2e2"
16+
OUTPUT_BLOCK_COLOR = "#dc6e55"
17+
18+
19+
def get_bar_colors(
    sd_version: StableDiffusionVersion, input_color, middle_color, output_color
):
    """Return one bar color per transformer block of `sd_version`.

    Indices below the middle-block index get `input_color`, the middle-block
    index itself gets `middle_color`, and every index above it gets
    `output_color`. The middle-block index is 4 for SDXL and 6 otherwise.
    The result has `sd_version.transformer_block_num` entries.
    """
    if sd_version == StableDiffusionVersion.SDXL:
        middle_at = 4
    else:
        middle_at = 6

    palette = []
    for block_idx in range(sd_version.transformer_block_num):
        if block_idx == middle_at:
            palette.append(middle_color)
        elif block_idx < middle_at:
            palette.append(input_color)
        else:
            palette.append(output_color)
    return palette
33+
34+
35+
def plot_weights(
    numbers: List[float],
    colors: List[str],
):
    """Draw the weights as a bar chart and return it as a PIL image.

    Args:
        numbers: Bar heights; the x position of each bar is its list index.
        colors: Bar colors, passed to matplotlib's `bar` as `color`.

    Returns:
        A PIL image opened from the in-memory PNG rendering of the chart.
    """
    figure = plt.figure(figsize=(8, 4))
    try:
        # Draw through this figure's own axes instead of pyplot's implicit
        # "current figure", so the chart always lands on the figure created
        # above.
        axes = figure.add_subplot()
        axes.bar(range(len(numbers)), numbers, color=colors)
        axes.set_xlabel("Transformer Index")
        axes.set_ylabel("Weight")
        axes.legend(
            handles=[
                Patch(color=color, label=label)
                for color, label in (
                    (INPUT_BLOCK_COLOR, "Input Block"),
                    (MIDDLE_BLOCK_COLOR, "Middle Block"),
                    (OUTPUT_BLOCK_COLOR, "Output Block"),
                )
            ],
            loc="best",
        )

        # Save the plot to a BytesIO buffer
        buffer = io.BytesIO()
        figure.savefig(buffer, format="png")
    finally:
        # Release the figure even when drawing or saving raises; otherwise
        # pyplot keeps the figure registered and it is never freed.
        plt.close(figure)
    buffer.seek(0)

    # Convert the buffer to a PIL image and return it
    return Image.open(buffer)
65+
66+
67+
class AdvancedWeightControl:
    """Gradio panel for choosing and hand-editing per-block weights.

    `render` creates the components (weight-type dropdown, composition
    slider, free-text weight editor and weight plot); `register_callbacks`
    wires them to the unit's weight slider, control-type selector and the
    `advanced_weighting` state.
    """

    def __init__(self):
        # All components are created in `render`; they are None until then.
        self.group = None
        self.weight_type = None
        self.weight_plot = None
        self.weight_editor = None
        self.weight_composition = None

    @staticmethod
    def _parse_weights(weights_string: str):
        """Parse the weight editor text into a list of numbers.

        Returns:
            The parsed list when `weights_string` is a JSON list whose
            elements are all ints or floats (JSON booleans are rejected);
            otherwise None.
        """
        try:
            weights = json.loads(weights_string)
        except (TypeError, ValueError, RecursionError):
            # TypeError: non-string input (e.g. None); ValueError covers
            # json.JSONDecodeError; RecursionError: absurdly nested input.
            return None
        # Explicit checks instead of `assert`, which is stripped under -O.
        if not isinstance(weights, list):
            return None
        for w in weights:
            if isinstance(w, bool) or not isinstance(w, (int, float)):
                return None
        return weights

    def render(self):
        """Create the panel's components; everything starts hidden."""
        # NOTE(review): nesting below is reconstructed from syntax because
        # the original indentation was not visible — confirm the layout.
        with gr.Group(visible=False) as self.group:
            with gr.Row():
                self.weight_type = gr.Dropdown(
                    choices=[
                        "normal",
                        "ease in",
                        "ease out",
                        "ease in-out",
                        "reverse in-out",
                        "weak input",
                        "weak output",
                        "weak middle",
                        "strong middle",
                        "style transfer",
                        "composition",
                        "strong style transfer",
                        "style and composition",
                        "strong style and composition",
                    ],
                    label="Weight Type",
                    value="normal",
                )
                self.weight_composition = gr.Slider(
                    label="Composition Weight",
                    minimum=0,
                    maximum=2.0,
                    value=1.0,
                    step=0.01,
                    visible=False,
                )
            self.weight_editor = gr.Textbox(label="Weights", visible=False)

            self.weight_plot = gr.Image(
                value=None,
                label="Weight Plot",
                interactive=False,
                visible=False,
            )

    def register_callbacks(
        self,
        weight_input: gr.Slider,
        advanced_weighting: gr.State,
        control_type: gr.Radio,
        update_unit_counter: gr.Number,
    ):
        """Wire the panel's components to the surrounding unit UI.

        Args:
            weight_input: The unit's weight slider; changes recompute the
                weight editor text.
            advanced_weighting: State that receives the parsed weight list
                (or None when the editor text is invalid).
            control_type: Control-type selector; the panel is shown only for
                "IP-Adapter" and "Instant-ID".
            update_unit_counter: Counter incremented after each state update.
        """

        def advanced_weighting_supported(control_type: str) -> bool:
            return control_type in ("IP-Adapter", "Instant-ID")

        # The composition slider only applies to the two
        # "style and composition" weight types.
        self.weight_type.change(
            fn=lambda weight_type: gr.update(
                visible=weight_type
                in ("style and composition", "strong style and composition")
            ),
            inputs=[self.weight_type],
            outputs=[self.weight_composition],
        )

        def update_weight_textbox(
            control_type: str,
            weight_type: str,
            weight: float,
            weight_composition: float,
        ):
            """Recompute the editor text from the current preset inputs."""
            if not advanced_weighting_supported(control_type):
                return gr.update()

            sd_version = get_sd_version()
            weights = calc_weights(weight_type, weight, sd_version, weight_composition)
            return gr.update(value=str([round(w, 2) for w in weights]), visible=True)

        trigger_inputs = [self.weight_type, weight_input, self.weight_composition]
        for trigger_input in trigger_inputs:
            trigger_input.change(
                fn=update_weight_textbox,
                inputs=[
                    control_type,
                    self.weight_type,
                    weight_input,
                    self.weight_composition,
                ],
                outputs=[self.weight_editor],
            )

        def update_plot(weights_string: str):
            """Redraw the plot from the editor text; hide it when invalid."""
            weights = self._parse_weights(weights_string)
            if weights is None:
                return gr.update(visible=False)

            sd_version = get_sd_version()
            weight_plot = plot_weights(
                weights,
                get_bar_colors(
                    sd_version,
                    input_color=INPUT_BLOCK_COLOR,
                    middle_color=MIDDLE_BLOCK_COLOR,
                    output_color=OUTPUT_BLOCK_COLOR,
                ),
            )
            return gr.update(value=weight_plot, visible=True)

        def update_advanced_weighting(weights_string: str):
            """Return the parsed weights, or None when the text is invalid."""
            return self._parse_weights(weights_string)

        self.weight_editor.change(
            fn=update_plot,
            inputs=[self.weight_editor],
            outputs=[self.weight_plot],
        )

        self.weight_editor.change(
            fn=update_advanced_weighting,
            inputs=[self.weight_editor],
            outputs=[advanced_weighting],
        ).then(
            fn=lambda x: gr.update(value=x + 1),
            inputs=[update_unit_counter],
            outputs=[update_unit_counter],
        )  # Necessary to flush gr.State change to unit state.

        # TODO: Expose advanced weighting control for other control types.
        def control_type_change(control_type: str, old_weights):
            """Show the panel for supported types; otherwise hide and reset.

            Returns updates for (group, advanced_weighting, weight_editor,
            weight_plot), in that order.
            """
            supported = advanced_weighting_supported(control_type)
            if supported:
                return (
                    gr.update(visible=supported),
                    old_weights,
                    gr.update(),
                    gr.update(),
                )
            else:
                # Clear the stored weights so an unsupported control type
                # does not keep stale advanced weighting.
                return (
                    gr.update(visible=supported),
                    None,
                    gr.update(visible=False),
                    gr.update(visible=False),
                )

        control_type.change(
            fn=control_type_change,
            inputs=[control_type, advanced_weighting],
            outputs=[
                self.group,
                advanced_weighting,
                self.weight_editor,
                self.weight_plot,
            ],
        )

scripts/controlnet_ui/controlnet_ui_group.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from scripts.controlnet_ui.preset import ControlNetPresetUI
1919
from scripts.controlnet_ui.tool_button import ToolButton
2020
from scripts.controlnet_ui.photopea import Photopea
21+
from scripts.controlnet_ui.advanced_weight_control import AdvancedWeightControl
2122
from scripts.enums import InputMode
2223
from modules import shared
2324
from modules.ui_components import FormRow
@@ -289,6 +290,7 @@ def __init__(
289290
self.input_mode = gr.State(InputMode.SIMPLE)
290291
self.inpaint_crop_input_image = None
291292
self.hr_option = None
293+
self.advanced_weight_control = AdvancedWeightControl()
292294
self.batch_image_dir_state = None
293295
self.output_dir_state = None
294296

@@ -640,6 +642,8 @@ def render(self, tabname: str, elem_id_tabname: str) -> None:
640642
visible=False,
641643
)
642644

645+
self.advanced_weight_control.render()
646+
643647
self.preset_panel = ControlNetPresetUI(
644648
id_prefix=f"{elem_id_tabname}_{tabname}_"
645649
)
@@ -675,6 +679,8 @@ def render(self, tabname: str, elem_id_tabname: str) -> None:
675679
self.control_mode,
676680
self.inpaint_crop_input_image,
677681
self.hr_option,
682+
self.save_detected_map,
683+
self.advanced_weighting,
678684
)
679685

680686
unit = gr.State(self.default_unit)
@@ -1256,6 +1262,12 @@ def register_core_callbacks(self):
12561262
for key in vars(external_code.ControlNetUnit()).keys()
12571263
],
12581264
)
1265+
self.advanced_weight_control.register_callbacks(
1266+
self.weight,
1267+
self.advanced_weighting,
1268+
self.type_filter,
1269+
self.update_unit_counter,
1270+
)
12591271
if self.is_img2img:
12601272
self.register_img2img_same_input()
12611273

0 commit comments

Comments
 (0)