1+ """
2+ This file is part of ComfyUI.
3+ Copyright (C) 2024 Comfy
4+
5+ This program is free software: you can redistribute it and/or modify
6+ it under the terms of the GNU General Public License as published by
7+ the Free Software Foundation, either version 3 of the License, or
8+ (at your option) any later version.
9+
10+ This program is distributed in the hope that it will be useful,
11+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13+ GNU General Public License for more details.
14+
15+ You should have received a copy of the GNU General Public License
16+ along with this program. If not, see <https://www.gnu.org/licenses/>.
17+ """
18+
19+
120import torch
21+ from enum import Enum
222import math
323import os
424import logging
@@ -33,6 +53,10 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
3353 else :
3454 return torch .cat ([tensor ] * batched_number , dim = 0 )
3555
class StrengthType(Enum):
    """Selects how a control object's ``strength`` is applied in ``control_merge``."""

    # Each control output tensor is multiplied by ``strength`` as-is.
    CONSTANT = 1
    # Each control output tensor is multiplied by ``strength`` raised to a
    # per-entry exponent, ``len(control_output) - i`` for loop index ``i``.
    LINEAR_UP = 2
59+
3660class ControlBase :
3761 def __init__ (self , device = None ):
3862 self .cond_hint_original = None
@@ -51,6 +75,8 @@ def __init__(self, device=None):
5175 device = comfy .model_management .get_torch_device ()
5276 self .device = device
5377 self .previous_controlnet = None
78+ self .extra_conds = []
79+ self .strength_type = StrengthType .CONSTANT
5480
5581 def set_cond_hint (self , cond_hint , strength = 1.0 , timestep_percent_range = (0.0 , 1.0 ), vae = None ):
5682 self .cond_hint_original = cond_hint
@@ -93,6 +119,8 @@ def copy_to(self, c):
93119 c .latent_format = self .latent_format
94120 c .extra_args = self .extra_args .copy ()
95121 c .vae = self .vae
122+ c .extra_conds = self .extra_conds .copy ()
123+ c .strength_type = self .strength_type
96124
97125 def inference_memory_requirements (self , dtype ):
98126 if self .previous_controlnet is not None :
@@ -113,7 +141,10 @@ def control_merge(self, control, control_prev, output_dtype):
113141
114142 if x not in applied_to : #memory saving strategy, allow shared tensors and only apply strength to shared tensors once
115143 applied_to .add (x )
116- x *= self .strength
144+ if self .strength_type == StrengthType .CONSTANT :
145+ x *= self .strength
146+ elif self .strength_type == StrengthType .LINEAR_UP :
147+ x *= (self .strength ** float (len (control_output ) - i ))
117148
118149 if x .dtype != output_dtype :
119150 x = x .to (output_dtype )
@@ -142,7 +173,7 @@ def set_extra_arg(self, argument, value=None):
142173
143174
144175class ControlNet (ControlBase ):
145- def __init__ (self , control_model = None , global_average_pooling = False , compression_ratio = 8 , latent_format = None , device = None , load_device = None , manual_cast_dtype = None ):
176+ def __init__ (self , control_model = None , global_average_pooling = False , compression_ratio = 8 , latent_format = None , device = None , load_device = None , manual_cast_dtype = None , extra_conds = [], strength_type = StrengthType . CONSTANT ):
146177 super ().__init__ (device )
147178 self .control_model = control_model
148179 self .load_device = load_device
@@ -154,6 +185,8 @@ def __init__(self, control_model=None, global_average_pooling=False, compression
154185 self .model_sampling_current = None
155186 self .manual_cast_dtype = manual_cast_dtype
156187 self .latent_format = latent_format
188+ self .extra_conds += extra_conds
189+ self .strength_type = strength_type
157190
158191 def get_control (self , x_noisy , t , cond , batched_number ):
159192 control_prev = None
@@ -192,7 +225,7 @@ def get_control(self, x_noisy, t, cond, batched_number):
192225
193226 context = cond .get ('crossattn_controlnet' , cond ['c_crossattn' ])
194227 extra = self .extra_args .copy ()
195- for c in [ "y" , "guidance" ]: #TODO
228+ for c in self . extra_conds :
196229 temp = cond .get (c , None )
197230 if temp is not None :
198231 extra [c ] = temp .to (dtype )
@@ -382,116 +415,22 @@ def load_controlnet_mmdit(sd):
382415 control = ControlNet (control_model , compression_ratio = 1 , latent_format = latent_format , load_device = load_device , manual_cast_dtype = manual_cast_dtype )
383416 return control
384417
def load_controlnet_hunyuandit(controlnet_data):
    """Build a ``ControlNet`` wrapping a HunYuan DiT control model.

    ``controlnet_data`` is the checkpoint state dict; it is passed both to
    ``controlnet_config`` (which supplies the operations, load device and
    dtypes used below) and to ``controlnet_load_state_dict`` to populate the
    freshly constructed ``HunYuanControlNet``.

    Returns a generic ``ControlNet`` configured with ``compression_ratio=1``,
    the SDXL latent format, the HunYuan-specific extra conditioning keys and
    ``StrengthType.LINEAR_UP`` strength scaling.
    """
    # model_config is returned by the helper but not needed here.
    model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data)

    hunyuan_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(
        operations=operations,
        device=load_device,
        dtype=unet_dtype,
    )
    hunyuan_model = controlnet_load_state_dict(hunyuan_model, controlnet_data)

    # Names of cond entries that ControlNet.get_control looks up in ``cond``
    # and copies (cast to the working dtype) into its ``extra`` dict.
    hunyuan_cond_keys = [
        'text_embedding_mask',
        'encoder_hidden_states_t5',
        'text_embedding_mask_t5',
        'image_meta_size',
        'style',
        'cos_cis_img',
        'sin_cis_img',
    ]

    return ControlNet(
        hunyuan_model,
        compression_ratio=1,
        latent_format=comfy.latent_formats.SDXL(),
        load_device=load_device,
        manual_cast_dtype=manual_cast_dtype,
        extra_conds=hunyuan_cond_keys,
        strength_type=StrengthType.LINEAR_UP,
    )
489428
490429def load_controlnet (ckpt_path , model = None ):
491430 controlnet_data = comfy .utils .load_torch_file (ckpt_path , safe_load = True )
492431 if 'after_proj_list.18.bias' in controlnet_data .keys (): #Hunyuan DiT
493432 return load_controlnet_hunyuandit (controlnet_data )
494-
433+
495434 if "lora_controlnet" in controlnet_data :
496435 return ControlLora (controlnet_data )
497436
0 commit comments