@@ -402,6 +402,7 @@ def __call__(
         guidance_scale: float = 1.0,
         cfg_start_percent: float = 0.0,
         cfg_end_percent: float = 1.0,
+        batched_cfg: bool = True,
         num_videos_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
         denoise_strength: float = 1.0,
@@ -663,9 +664,9 @@ def __call__(
         callback = prepare_callback(self.comfy_model, num_inference_steps)
 
         #print(self.scheduler.sigmas)
-
 
         logger.info(f"Sampling {video_length} frames in {latents.shape[2]} latents at {width}x{height} with {len(timesteps)} inference steps")
+
         comfy_pbar = ProgressBar(len(timesteps))
         with self.progress_bar(total=len(timesteps)) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -771,39 +772,72 @@ def __call__(
                 with torch.autocast(
                     device_type="cuda", dtype=self.base_dtype, enabled=True
                 ):
-                    noise_pred = self.transformer(  # For an input image (129, 192, 336) (1, 256, 256)
-                        latent_model_input,  # [2, 16, 33, 24, 42]
-                        t_expand,  # [2]
-                        text_states=input_prompt_embeds,  # [2, 256, 4096]
-                        text_mask=input_prompt_mask,  # [2, 256]
-                        text_states_2=input_prompt_embeds_2,  # [2, 768]
-                        freqs_cos=freqs_cos,  # [seqlen, head_dim]
-                        freqs_sin=freqs_sin,  # [seqlen, head_dim]
-                        guidance=guidance_expand,
-                        stg_block_idx=stg_block_idx,
-                        stg_mode=stg_mode,
-                        return_dict=True,
-                    )["x"]
-
-                    # perform guidance
-                    if cfg_enabled and not self.do_spatio_temporal_guidance:
-                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                        noise_pred = noise_pred_uncond + self.guidance_scale * (
-                            noise_pred_text - noise_pred_uncond
-                        )
-                    elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
-                        raise NotImplementedError
-                        noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
-                        noise_pred = noise_pred_uncond + self.guidance_scale * (
-                            noise_pred_text - noise_pred_uncond
-                        ) + self._stg_scale * (
-                            noise_pred_text - noise_pred_perturb
-                        )
-                    elif self.do_spatio_temporal_guidance and stg_enabled:
-                        noise_pred_text, noise_pred_perturb = noise_pred.chunk(2)
-                        noise_pred = noise_pred_text + self._stg_scale * (
-                            noise_pred_text - noise_pred_perturb
-                        )
+                    if batched_cfg or not cfg_enabled:
+                        noise_pred = self.transformer(  # For an input image (129, 192, 336) (1, 256, 256)
+                            latent_model_input,  # [2, 16, 33, 24, 42]
+                            t_expand,  # [2]
+                            text_states=input_prompt_embeds,  # [2, 256, 4096]
+                            text_mask=input_prompt_mask,  # [2, 256]
+                            text_states_2=input_prompt_embeds_2,  # [2, 768]
+                            freqs_cos=freqs_cos,  # [seqlen, head_dim]
+                            freqs_sin=freqs_sin,  # [seqlen, head_dim]
+                            guidance=guidance_expand,
+                            stg_block_idx=stg_block_idx,
+                            stg_mode=stg_mode,
+                            return_dict=True,
+                        )["x"]
+                    else:
+                        uncond = self.transformer(
+                            latent_model_input[0].unsqueeze(0),
+                            t_expand[0].unsqueeze(0),
+                            text_states=input_prompt_embeds[0].unsqueeze(0),
+                            text_mask=input_prompt_mask[0].unsqueeze(0),
+                            text_states_2=input_prompt_embeds_2[0].unsqueeze(0),
+                            freqs_cos=freqs_cos,
+                            freqs_sin=freqs_sin,
+                            guidance=guidance_expand[0].unsqueeze(0),
+                            stg_block_idx=stg_block_idx,
+                            stg_mode=stg_mode,
+                            return_dict=True,
+                        )["x"]
+                        cond = self.transformer(
+                            latent_model_input[1].unsqueeze(0),
+                            t_expand[1].unsqueeze(0),
+                            text_states=input_prompt_embeds[1].unsqueeze(0),
+                            text_mask=input_prompt_mask[1].unsqueeze(0),
+                            text_states_2=input_prompt_embeds_2[1].unsqueeze(0),
+                            freqs_cos=freqs_cos,
+                            freqs_sin=freqs_sin,
+                            guidance=guidance_expand[1].unsqueeze(0),
+                            stg_block_idx=stg_block_idx,
+                            stg_mode=stg_mode,
+                            return_dict=True,
+                        )["x"]
+
+                    # perform guidance
+                    if cfg_enabled and not self.do_spatio_temporal_guidance:
+                        if batched_cfg:
+                            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                            noise_pred = noise_pred_uncond + self.guidance_scale * (
+                                noise_pred_text - noise_pred_uncond
+                            )
+                        else:
+                            noise_pred = uncond + self.guidance_scale * (cond - uncond)
+
+
+                    elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
+                        raise NotImplementedError
+                        noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
+                        noise_pred = noise_pred_uncond + self.guidance_scale * (
+                            noise_pred_text - noise_pred_uncond
+                        ) + self._stg_scale * (
+                            noise_pred_text - noise_pred_perturb
+                        )
+                    elif self.do_spatio_temporal_guidance and stg_enabled:
+                        noise_pred_text, noise_pred_perturb = noise_pred.chunk(2)
+                        noise_pred = noise_pred_text + self._stg_scale * (
+                            noise_pred_text - noise_pred_perturb
+                        )
 
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(
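For context, a minimal self-contained sketch of the difference between the batched and split CFG paths this diff introduces; `cfg_denoise`, `dummy_model`, and the tensor shapes below are illustrative placeholders, not the wrapper's actual API:

```python
import torch

def cfg_denoise(model, latent_model_input, t_expand, guidance_scale, batched_cfg=True):
    """Combine unconditional and conditional predictions via classifier-free guidance."""
    if batched_cfg:
        # One forward pass over the stacked [uncond, cond] batch, then split the result.
        noise_pred = model(latent_model_input, t_expand)
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    else:
        # Two batch-size-1 passes: two model calls per step instead of one,
        # but with roughly half the peak activation memory of the batched path.
        noise_pred_uncond = model(latent_model_input[0].unsqueeze(0), t_expand[0].unsqueeze(0))
        noise_pred_text = model(latent_model_input[1].unsqueeze(0), t_expand[1].unsqueeze(0))
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# Both paths give the same result for a deterministic model:
dummy_model = lambda x, t: x * 0.5           # stand-in for the video transformer
x = torch.randn(2, 16, 33, 24, 42)           # [uncond, cond] latents
t = torch.tensor([999.0, 999.0])
assert torch.allclose(
    cfg_denoise(dummy_model, x, t, 6.0, batched_cfg=True),
    cfg_denoise(dummy_model, x, t, 6.0, batched_cfg=False),
)
```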