Skip to content

Commit a475ec2

Browse files
Clean up HunyuanDiT controlnets.
Use the "ControlNetApply SD3 and HunyuanDiT" node instead.
1 parent 06eb9fb commit a475ec2

File tree

4 files changed

+60
-194
lines changed

4 files changed

+60
-194
lines changed

comfy/controlnet.py

Lines changed: 42 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,24 @@
1+
"""
2+
This file is part of ComfyUI.
3+
Copyright (C) 2024 Comfy
4+
5+
This program is free software: you can redistribute it and/or modify
6+
it under the terms of the GNU General Public License as published by
7+
the Free Software Foundation, either version 3 of the License, or
8+
(at your option) any later version.
9+
10+
This program is distributed in the hope that it will be useful,
11+
but WITHOUT ANY WARRANTY; without even the implied warranty of
12+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13+
GNU General Public License for more details.
14+
15+
You should have received a copy of the GNU General Public License
16+
along with this program. If not, see <https://www.gnu.org/licenses/>.
17+
"""
18+
19+
120
import torch
21+
from enum import Enum
222
import math
323
import os
424
import logging
@@ -33,6 +53,10 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
3353
else:
3454
return torch.cat([tensor] * batched_number, dim=0)
3555

56+
class StrengthType(Enum):
57+
CONSTANT = 1
58+
LINEAR_UP = 2
59+
3660
class ControlBase:
3761
def __init__(self, device=None):
3862
self.cond_hint_original = None
@@ -51,6 +75,8 @@ def __init__(self, device=None):
5175
device = comfy.model_management.get_torch_device()
5276
self.device = device
5377
self.previous_controlnet = None
78+
self.extra_conds = []
79+
self.strength_type = StrengthType.CONSTANT
5480

5581
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
5682
self.cond_hint_original = cond_hint
@@ -93,6 +119,8 @@ def copy_to(self, c):
93119
c.latent_format = self.latent_format
94120
c.extra_args = self.extra_args.copy()
95121
c.vae = self.vae
122+
c.extra_conds = self.extra_conds.copy()
123+
c.strength_type = self.strength_type
96124

97125
def inference_memory_requirements(self, dtype):
98126
if self.previous_controlnet is not None:
@@ -113,7 +141,10 @@ def control_merge(self, control, control_prev, output_dtype):
113141

114142
if x not in applied_to: #memory saving strategy, allow shared tensors and only apply strength to shared tensors once
115143
applied_to.add(x)
116-
x *= self.strength
144+
if self.strength_type == StrengthType.CONSTANT:
145+
x *= self.strength
146+
elif self.strength_type == StrengthType.LINEAR_UP:
147+
x *= (self.strength ** float(len(control_output) - i))
117148

118149
if x.dtype != output_dtype:
119150
x = x.to(output_dtype)
@@ -142,7 +173,7 @@ def set_extra_arg(self, argument, value=None):
142173

143174

144175
class ControlNet(ControlBase):
145-
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None):
176+
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=[], strength_type=StrengthType.CONSTANT):
146177
super().__init__(device)
147178
self.control_model = control_model
148179
self.load_device = load_device
@@ -154,6 +185,8 @@ def __init__(self, control_model=None, global_average_pooling=False, compression
154185
self.model_sampling_current = None
155186
self.manual_cast_dtype = manual_cast_dtype
156187
self.latent_format = latent_format
188+
self.extra_conds += extra_conds
189+
self.strength_type = strength_type
157190

158191
def get_control(self, x_noisy, t, cond, batched_number):
159192
control_prev = None
@@ -192,7 +225,7 @@ def get_control(self, x_noisy, t, cond, batched_number):
192225

193226
context = cond.get('crossattn_controlnet', cond['c_crossattn'])
194227
extra = self.extra_args.copy()
195-
for c in ["y", "guidance"]: #TODO
228+
for c in self.extra_conds:
196229
temp = cond.get(c, None)
197230
if temp is not None:
198231
extra[c] = temp.to(dtype)
@@ -382,116 +415,22 @@ def load_controlnet_mmdit(sd):
382415
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
383416
return control
384417

385-
class ControlNetWarperHunyuanDiT(ControlNet):
386-
def get_control(self, x_noisy, t, cond, batched_number):
387-
control_prev = None
388-
if self.previous_controlnet is not None:
389-
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
390-
391-
if self.timestep_range is not None:
392-
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
393-
if control_prev is not None:
394-
return control_prev
395-
else:
396-
return None
418+
def load_controlnet_hunyuandit(controlnet_data):
419+
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data)
397420

398-
dtype = self.control_model.dtype
399-
if self.manual_cast_dtype is not None:
400-
dtype = self.manual_cast_dtype
401-
402-
output_dtype = x_noisy.dtype
403-
if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
404-
if self.cond_hint is not None:
405-
del self.cond_hint
406-
self.cond_hint = None
407-
compression_ratio = self.compression_ratio
408-
if self.vae is not None:
409-
compression_ratio *= self.vae.downscale_ratio
410-
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
411-
if self.vae is not None:
412-
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
413-
self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
414-
comfy.model_management.load_models_gpu(loaded_models)
415-
if self.latent_format is not None:
416-
self.cond_hint = self.latent_format.process_in(self.cond_hint)
417-
self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
418-
if x_noisy.shape[0] != self.cond_hint.shape[0]:
419-
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
420-
421-
def get_tensor(name):
422-
if name in cond:
423-
if isinstance(cond[name], torch.Tensor):
424-
return cond[name].to(dtype)
425-
else:
426-
return cond[name]
427-
else:
428-
return None
429-
430-
encoder_hidden_states = get_tensor('c_crossattn')
431-
text_embedding_mask = get_tensor('text_embedding_mask')
432-
encoder_hidden_states_t5 = get_tensor('encoder_hidden_states_t5')
433-
text_embedding_mask_t5 = get_tensor('text_embedding_mask_t5')
434-
image_meta_size = get_tensor('image_meta_size')
435-
style = get_tensor('style')
436-
cos_cis_img = get_tensor('cos_cis_img')
437-
sin_cis_img = get_tensor('sin_cis_img')
438-
439-
timestep = self.model_sampling_current.timestep(t)
440-
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
441-
442-
control = self.control_model(
443-
x=x_noisy.to(dtype),
444-
t=timestep.float(),
445-
condition=self.cond_hint,
446-
encoder_hidden_states=encoder_hidden_states,
447-
text_embedding_mask=text_embedding_mask,
448-
encoder_hidden_states_t5=encoder_hidden_states_t5,
449-
text_embedding_mask_t5=text_embedding_mask_t5,
450-
image_meta_size=image_meta_size,
451-
style=style,
452-
cos_cis_img=cos_cis_img,
453-
sin_cis_img=sin_cis_img,
454-
**self.extra_args
455-
)
456-
return self.control_merge(control, control_prev, output_dtype)
457-
458-
def copy(self):
459-
c = ControlNetWarperHunyuanDiT(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
460-
c.control_model = self.control_model
461-
c.control_model_wrapped = self.control_model_wrapped
462-
self.copy_to(c)
463-
return c
464-
465-
def load_controlnet_hunyuandit(controlnet_data):
466-
467-
supported_inference_dtypes = [torch.float16, torch.float32]
468-
469-
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
470-
load_device = comfy.model_management.get_torch_device()
471-
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
472-
if manual_cast_dtype is not None:
473-
operations = comfy.ops.manual_cast
474-
else:
475-
operations = comfy.ops.disable_weight_init
476-
477421
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype)
478-
missing, unexpected = control_model.load_state_dict(controlnet_data)
479-
480-
if len(missing) > 0:
481-
logging.warning("missing controlnet keys: {}".format(missing))
482-
483-
if len(unexpected) > 0:
484-
logging.debug("unexpected controlnet keys: {}".format(unexpected))
422+
control_model = controlnet_load_state_dict(control_model, controlnet_data)
485423

486424
latent_format = comfy.latent_formats.SDXL()
487-
control = ControlNetWarperHunyuanDiT(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
425+
extra_conds = ['text_embedding_mask', 'encoder_hidden_states_t5', 'text_embedding_mask_t5', 'image_meta_size', 'style', 'cos_cis_img', 'sin_cis_img']
426+
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.LINEAR_UP)
488427
return control
489428

490429
def load_controlnet(ckpt_path, model=None):
491430
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
492431
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
493432
return load_controlnet_hunyuandit(controlnet_data)
494-
433+
495434
if "lora_controlnet" in controlnet_data:
496435
return ControlLora(controlnet_data)
497436

comfy/ldm/hydit/controlnet.py

Lines changed: 13 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -16,28 +16,11 @@
1616
from .poolers import AttentionPool
1717

1818
import comfy.latent_formats
19-
from .models import HunYuanDiTBlock
19+
from .models import HunYuanDiTBlock, calc_rope
2020

2121
from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
2222

2323

24-
def zero_module(module):
25-
for p in module.parameters():
26-
nn.init.zeros_(p)
27-
return module
28-
29-
30-
def calc_rope(x, patch_size, head_size):
31-
th = (x.shape[2] + (patch_size // 2)) // patch_size
32-
tw = (x.shape[3] + (patch_size // 2)) // patch_size
33-
base_size = 512 // 8 // patch_size
34-
start, stop = get_fill_resize_and_crop((th, tw), base_size)
35-
sub_args = [start, stop, (th, tw)]
36-
# head_size = HUNYUAN_DIT_CONFIG['DiT-g/2']['hidden_size'] // HUNYUAN_DIT_CONFIG['DiT-g/2']['num_heads']
37-
rope = get_2d_rotary_pos_embed(head_size, *sub_args)
38-
return rope
39-
40-
4124
class HunYuanControlNet(nn.Module):
4225
"""
4326
HunYuanDiT: Diffusion model with a Transformer backbone.
@@ -213,35 +196,32 @@ def __init__(
213196
)
214197

215198
# Input zero linear for the first block
216-
self.before_proj = zero_module(
217-
nn.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
218-
)
199+
self.before_proj = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
200+
219201

220202
# Output zero linear for every block
221203
self.after_proj_list = nn.ModuleList(
222204
[
223-
zero_module(
224-
nn.Linear(
205+
206+
operations.Linear(
225207
self.hidden_size, self.hidden_size, dtype=dtype, device=device
226208
)
227-
)
228209
for _ in range(len(self.blocks))
229210
]
230211
)
231212

232213
def forward(
233214
self,
234-
x: torch.Tensor,
235-
t: torch.Tensor = None,
236-
condition=None,
237-
encoder_hidden_states: Optional[torch.Tensor] = None,
215+
x,
216+
hint,
217+
timesteps,
218+
context,#encoder_hidden_states=None,
238219
text_embedding_mask=None,
239220
encoder_hidden_states_t5=None,
240221
text_embedding_mask_t5=None,
241222
image_meta_size=None,
242223
style=None,
243-
control_weight=1.0,
244-
transformer_options=None,
224+
return_dict=False,
245225
**kwarg,
246226
):
247227
"""
@@ -270,10 +250,11 @@ def forward(
270250
return_dict: bool
271251
Whether to return a dictionary.
272252
"""
253+
condition = hint
273254
if condition.shape[0] == 1:
274255
condition = torch.repeat_interleave(condition, x.shape[0], dim=0)
275256

276-
text_states = encoder_hidden_states # 2,77,1024
257+
text_states = context # 2,77,1024
277258
text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
278259
text_states_mask = text_embedding_mask.bool() # 2,77
279260
text_states_t5_mask = text_embedding_mask_t5.bool() # 2,256
@@ -304,7 +285,7 @@ def forward(
304285
) # (cos_cis_img, sin_cis_img)
305286

306287
# ========================= Build time and image embedding =========================
307-
t = self.t_embedder(t, dtype=self.dtype)
288+
t = self.t_embedder(timesteps, dtype=self.dtype)
308289
x = self.x_embedder(x)
309290

310291
# ========================= Concatenate all extra vectors =========================
@@ -337,12 +318,4 @@ def forward(
337318
x = block(x, c, text_states, freqs_cis_img)
338319
controls.append(self.after_proj_list[layer](x)) # zero linear for output
339320

340-
control_weights = [1.0 * (control_weight ** float(19 - i)) for i in range(19)]
341-
assert len(control_weights) == len(
342-
controls
343-
), "control_weights and controls should have the same length"
344-
controls = [
345-
control * weight for control, weight in zip(controls, control_weights)
346-
]
347-
348321
return {"output": controls}

comfy_extras/nodes_hunyuan.py

Lines changed: 0 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -19,58 +19,7 @@ def encode(self, clip, bert, mt5xl):
1919
cond = output.pop("cond")
2020
return ([[cond, output]], )
2121

22-
23-
class ControlNetApplyAdvancedHunYuan:
24-
@classmethod
25-
def INPUT_TYPES(s):
26-
return {"required": {"positive": ("CONDITIONING", ),
27-
"negative": ("CONDITIONING", ),
28-
"control_net": ("CONTROL_NET", ),
29-
"image": ("IMAGE", ),
30-
"vae": ("VAE", ),
31-
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
32-
"control_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.001}),
33-
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
34-
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
35-
}}
36-
37-
RETURN_TYPES = ("CONDITIONING","CONDITIONING")
38-
RETURN_NAMES = ("positive", "negative")
39-
FUNCTION = "apply_controlnet"
40-
41-
CATEGORY = "conditioning/controlnet"
42-
43-
def apply_controlnet(self, positive, negative, control_net, image, strength, control_weight, start_percent, end_percent, vae=None):
44-
if strength == 0:
45-
return (positive, negative)
46-
47-
control_hint = image.movedim(-1,1)
48-
cnets = {}
49-
50-
out = []
51-
for conditioning in [positive, negative]:
52-
c = []
53-
for t in conditioning:
54-
d = t[1].copy()
55-
56-
prev_cnet = d.get('control', None)
57-
if prev_cnet in cnets:
58-
c_net = cnets[prev_cnet]
59-
else:
60-
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae)
61-
c_net.set_extra_arg('control_weight', control_weight)
62-
63-
c_net.set_previous_controlnet(prev_cnet)
64-
cnets[prev_cnet] = c_net
6522

66-
d['control'] = c_net
67-
d['control_apply_to_uncond'] = False
68-
n = [t[0], d]
69-
c.append(n)
70-
out.append(c)
71-
return (out[0], out[1])
72-
7323
NODE_CLASS_MAPPINGS = {
7424
"CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
75-
"ControlNetApplyAdvancedHunYuan": ControlNetApplyAdvancedHunYuan,
7625
}

comfy_extras/nodes_sd3.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,8 @@ def INPUT_TYPES(s):
100100
"CLIPTextEncodeSD3": CLIPTextEncodeSD3,
101101
"ControlNetApplySD3": ControlNetApplySD3,
102102
}
103+
104+
NODE_DISPLAY_NAME_MAPPINGS = {
105+
# ControlNet
106+
"ControlNetApplySD3": "ControlNetApply SD3 and HunyuanDiT",
107+
}

0 commit comments

Comments
 (0)