
Commit 2997d65

add option to do cfg sequentially
1 parent e490745 commit 2997d65

File tree: 2 files changed, +76 -35 lines changed

hyvideo/diffusion/pipelines/pipeline_hunyuan_video.py

Lines changed: 68 additions & 34 deletions
@@ -402,6 +402,7 @@ def __call__(
         guidance_scale: float = 1.0,
         cfg_start_percent: float = 0.0,
         cfg_end_percent: float = 1.0,
+        batched_cfg: bool = True,
         num_videos_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
         denoise_strength: float = 1.0,
@@ -663,9 +664,9 @@ def __call__(
         callback = prepare_callback(self.comfy_model, num_inference_steps)

         #print(self.scheduler.sigmas)
-

         logger.info(f"Sampling {video_length} frames in {latents.shape[2]} latents at {width}x{height} with {len(timesteps)} inference steps")
+
         comfy_pbar = ProgressBar(len(timesteps))
         with self.progress_bar(total=len(timesteps)) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -771,39 +772,72 @@ def __call__(
                 with torch.autocast(
                     device_type="cuda", dtype=self.base_dtype, enabled=True
                 ):
-                    noise_pred = self.transformer( # For an input image (129, 192, 336) (1, 256, 256)
-                        latent_model_input, # [2, 16, 33, 24, 42]
-                        t_expand, # [2]
-                        text_states=input_prompt_embeds, # [2, 256, 4096]
-                        text_mask=input_prompt_mask, # [2, 256]
-                        text_states_2=input_prompt_embeds_2, # [2, 768]
-                        freqs_cos=freqs_cos, # [seqlen, head_dim]
-                        freqs_sin=freqs_sin, # [seqlen, head_dim]
-                        guidance=guidance_expand,
-                        stg_block_idx=stg_block_idx,
-                        stg_mode=stg_mode,
-                        return_dict=True,
-                    )["x"]
-
-                    # perform guidance
-                    if cfg_enabled and not self.do_spatio_temporal_guidance:
-                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                        noise_pred = noise_pred_uncond + self.guidance_scale * (
-                            noise_pred_text - noise_pred_uncond
-                        )
-                    elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
-                        raise NotImplementedError
-                        noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
-                        noise_pred = noise_pred_uncond + self.guidance_scale * (
-                            noise_pred_text - noise_pred_uncond
-                        ) + self._stg_scale * (
-                            noise_pred_text - noise_pred_perturb
-                        )
-                    elif self.do_spatio_temporal_guidance and stg_enabled:
-                        noise_pred_text, noise_pred_perturb = noise_pred.chunk(2)
-                        noise_pred = noise_pred_text + self._stg_scale * (
-                            noise_pred_text - noise_pred_perturb
-                        )
+                    if batched_cfg or not cfg_enabled:
+                        noise_pred = self.transformer( # For an input image (129, 192, 336) (1, 256, 256)
+                            latent_model_input, # [2, 16, 33, 24, 42]
+                            t_expand, # [2]
+                            text_states=input_prompt_embeds, # [2, 256, 4096]
+                            text_mask=input_prompt_mask, # [2, 256]
+                            text_states_2=input_prompt_embeds_2, # [2, 768]
+                            freqs_cos=freqs_cos, # [seqlen, head_dim]
+                            freqs_sin=freqs_sin, # [seqlen, head_dim]
+                            guidance=guidance_expand,
+                            stg_block_idx=stg_block_idx,
+                            stg_mode=stg_mode,
+                            return_dict=True,
+                        )["x"]
+                    else:
+                        uncond = self.transformer(
+                            latent_model_input[0].unsqueeze(0),
+                            t_expand[0].unsqueeze(0),
+                            text_states=input_prompt_embeds[0].unsqueeze(0),
+                            text_mask=input_prompt_mask[0].unsqueeze(0),
+                            text_states_2=input_prompt_embeds_2[0].unsqueeze(0),
+                            freqs_cos=freqs_cos,
+                            freqs_sin=freqs_sin,
+                            guidance=guidance_expand[0].unsqueeze(0),
+                            stg_block_idx=stg_block_idx,
+                            stg_mode=stg_mode,
+                            return_dict=True,
+                        )["x"]
+                        cond = self.transformer(
+                            latent_model_input[1].unsqueeze(0),
+                            t_expand[1].unsqueeze(0),
+                            text_states=input_prompt_embeds[1].unsqueeze(0),
+                            text_mask=input_prompt_mask[1].unsqueeze(0),
+                            text_states_2=input_prompt_embeds_2[1].unsqueeze(0),
+                            freqs_cos=freqs_cos,
+                            freqs_sin=freqs_sin,
+                            guidance=guidance_expand[1].unsqueeze(0),
+                            stg_block_idx=stg_block_idx,
+                            stg_mode=stg_mode,
+                            return_dict=True,
+                        )["x"]
+
+                    # perform guidance
+                    if cfg_enabled and not self.do_spatio_temporal_guidance:
+                        if batched_cfg:
+                            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                            noise_pred = noise_pred_uncond + self.guidance_scale * (
+                                noise_pred_text - noise_pred_uncond
+                            )
+                        else:
+                            noise_pred = uncond + self.guidance_scale * (cond - uncond)
+
+
+                    elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
+                        raise NotImplementedError
+                        noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
+                        noise_pred = noise_pred_uncond + self.guidance_scale * (
+                            noise_pred_text - noise_pred_uncond
+                        ) + self._stg_scale * (
+                            noise_pred_text - noise_pred_perturb
+                        )
+                    elif self.do_spatio_temporal_guidance and stg_enabled:
+                        noise_pred_text, noise_pred_perturb = noise_pred.chunk(2)
+                        noise_pred = noise_pred_text + self._stg_scale * (
+                            noise_pred_text - noise_pred_perturb
+                        )

                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(
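
For context, batched and sequential CFG produce the same guided prediction: the batched path runs the unconditional and conditional inputs through the transformer as one batch of two, while the new sequential path runs two batch-of-one passes, roughly halving peak activation memory per step at some cost in speed. A minimal sketch of the equivalence, assuming a generic model callable (cfg_step and model are hypothetical, standing in for the wrapped transformer):

import torch

def cfg_step(model, latents, t, uncond_emb, cond_emb, scale, batched=True):
    # Classifier-free guidance: pred = uncond + scale * (cond - uncond).
    if batched:
        # One forward pass over a batch of two: [uncond, cond].
        x = torch.cat([latents, latents], dim=0)
        emb = torch.cat([uncond_emb, cond_emb], dim=0)
        uncond_pred, cond_pred = model(x, t, emb).chunk(2)
    else:
        # Two batch-of-one passes: lower peak memory, usually a bit
        # slower since each pass underutilizes the GPU.
        uncond_pred = model(latents, t, uncond_emb)
        cond_pred = model(latents, t, cond_emb)
    return uncond_pred + scale * (cond_pred - uncond_pred)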

nodes.py

Lines changed: 8 additions & 1 deletion
@@ -965,6 +965,7 @@ def encode_prompt(self, prompt, negative_prompt, text_encoder, image_token_selec
             "cfg": torch.tensor(hyvid_cfg["cfg"]) if hyvid_cfg is not None else None,
             "start_percent": torch.tensor(hyvid_cfg["start_percent"]) if hyvid_cfg is not None else None,
             "end_percent": torch.tensor(hyvid_cfg["end_percent"]) if hyvid_cfg is not None else None,
+            "batched_cfg": torch.tensor(hyvid_cfg["batched_cfg"]) if hyvid_cfg is not None else None,
         }
         return (prompt_embeds_dict,)

@@ -1003,6 +1004,7 @@ def INPUT_TYPES(s):
                 "cfg": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 100.0, "step": 0.01, "tooltip": "guidance scale"} ),
                 "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percentage of the steps to apply CFG, rest of the steps use guidance_embeds"} ),
                 "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percentage of the steps to apply CFG, rest of the steps use guidance_embeds"} ),
+                "batched_cfg": ("BOOLEAN", {"default": True, "tooltip": "Calculate cond and uncond as a batch, increases memory usage but can be faster"}),
             },
         }

@@ -1012,12 +1014,13 @@ def INPUT_TYPES(s):
     CATEGORY = "HunyuanVideoWrapper"
     DESCRIPTION = "To use CFG with HunyuanVideo"

-    def process(self, negative_prompt, cfg, start_percent, end_percent):
+    def process(self, negative_prompt, cfg, start_percent, end_percent, batched_cfg):
         cfg_dict = {
             "negative_prompt": negative_prompt,
             "cfg": cfg,
             "start_percent": start_percent,
             "end_percent": end_percent,
+            "batched_cfg": batched_cfg
         }

         return (cfg_dict,)
@@ -1095,6 +1098,7 @@ def load(self, embeds):
             "cfg": loaded_tensors.get("cfg", None),
             "start_percent": loaded_tensors.get("start_percent", None),
             "end_percent": loaded_tensors.get("end_percent", None),
+            "batched_cfg": loaded_tensors.get("batched_cfg", None),
         }

         return (prompt_embeds_dict,)
@@ -1185,10 +1189,12 @@ def process(self, model, hyvid_embeds, flow_shift, steps, embedded_guidance_scal
             cfg = float(hyvid_embeds.get("cfg", 1.0))
             cfg_start_percent = float(hyvid_embeds.get("start_percent", 0.0))
             cfg_end_percent = float(hyvid_embeds.get("end_percent", 1.0))
+            batched_cfg = hyvid_embeds.get("batched_cfg", True)
         else:
             cfg = 1.0
             cfg_start_percent = 0.0
             cfg_end_percent = 1.0
+            batched_cfg = False

         if embedded_guidance_scale == 0.0:
             embedded_guidance_scale = None
@@ -1291,6 +1297,7 @@ def process(self, model, hyvid_embeds, flow_shift, steps, embedded_guidance_scal
             guidance_scale=cfg,
             cfg_start_percent=cfg_start_percent,
             cfg_end_percent=cfg_end_percent,
+            batched_cfg=batched_cfg,
             embedded_guidance_scale=embedded_guidance_scale,
             latents=input_latents,
             denoise_strength=denoise_strength,
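
A note on the defaults above: when a CFG dict is attached, batched_cfg falls back to True if the key is missing (for example, prompt embeds saved before this option existed), preserving the previous batched behavior; with no CFG dict it is set to False, which changes nothing in practice because the pipeline takes the single batched pass whenever CFG is disabled (the "batched_cfg or not cfg_enabled" branch).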
