66import math
77import logging
88import comfy .sampler_helpers
9+ import scipy
10+ import numpy
911
1012def get_area_and_mult (conds , x_in , timestep_in ):
1113 dims = tuple (x_in .shape [2 :])
@@ -337,6 +339,18 @@ def normal_scheduler(model_sampling, steps, sgm=False, floor=False):
337339 sigs += [0.0 ]
338340 return torch .FloatTensor (sigs )
339341
342+ # Implemented based on: https://arxiv.org/abs/2407.12173
343+ def beta_scheduler (model_sampling , steps , alpha = 0.6 , beta = 0.6 ):
344+ total_timesteps = (len (model_sampling .sigmas ) - 1 )
345+ ts = 1 - numpy .linspace (0 , 1 , steps , endpoint = False )
346+ ts = numpy .rint (scipy .stats .beta .ppf (ts , alpha , beta ) * total_timesteps )
347+
348+ sigs = []
349+ for t in ts :
350+ sigs += [float (model_sampling .sigmas [int (t )])]
351+ sigs += [0.0 ]
352+ return torch .FloatTensor (sigs )
353+
340354def get_mask_aabb (masks ):
341355 if masks .numel () == 0 :
342356 return torch .zeros ((0 , 4 ), device = masks .device , dtype = torch .int )
@@ -703,7 +717,7 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
703717 return cfg_guider .sample (noise , latent_image , sampler , sigmas , denoise_mask , callback , disable_pbar , seed )
704718
705719
706- SCHEDULER_NAMES = ["normal" , "karras" , "exponential" , "sgm_uniform" , "simple" , "ddim_uniform" ]
720+ SCHEDULER_NAMES = ["normal" , "karras" , "exponential" , "sgm_uniform" , "simple" , "ddim_uniform" , "beta" ]
707721SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim" , "uni_pc" , "uni_pc_bh2" ]
708722
709723def calculate_sigmas (model_sampling , scheduler_name , steps ):
@@ -719,6 +733,8 @@ def calculate_sigmas(model_sampling, scheduler_name, steps):
719733 sigmas = ddim_scheduler (model_sampling , steps )
720734 elif scheduler_name == "sgm_uniform" :
721735 sigmas = normal_scheduler (model_sampling , steps , sgm = True )
736+ elif scheduler_name == "beta" :
737+ sigmas = beta_scheduler (model_sampling , steps )
722738 else :
723739 logging .error ("error invalid scheduler {}" .format (scheduler_name ))
724740 return sigmas
0 commit comments