@@ -76,6 +76,14 @@ def kl_optimal(n, sigma_min, sigma_max, device):
     sigmas = torch.tan(step_indices / n * alpha_min + (1.0 - step_indices / n) * alpha_max)
     return sigmas
 
+def simple_scheduler(n, sigma_min, sigma_max, inner_model, device):
+    sigs = []
+    ss = len(inner_model.sigmas) / n
+    for x in range(n):
+        sigs += [float(inner_model.sigmas[-(1 + int(x * ss))])]
+    sigs += [0.0]
+    return torch.FloatTensor(sigs).to(device)
+
 
 def vp(n, sigma_min, sigma_max, inner_model, device):
     beta_d = shared.opts.data.get("vp_beta_d", 19.9)
@@ -223,6 +231,7 @@ def loglinear_interp(t_steps, num_steps):
     Scheduler('uniform', 'Uniform', uniform, need_inner_model=True),
     Scheduler('sgm_uniform', 'SGM Uniform', sgm_uniform, need_inner_model=True, aliases=["SGMUniform"]),
     Scheduler('kl_optimal', 'KL Optimal', kl_optimal),
+    Scheduler('simple', 'Simple', simple_scheduler, need_inner_model=True),
     Scheduler('align_your_steps', 'Align Your Steps', get_align_your_steps_sigmas),
     Scheduler('align_your_steps_GITS', 'Align Your Steps GITS', get_align_your_steps_sigmas_GITS),
     Scheduler('align_your_steps_11', 'Align Your Steps 11', ays_11_sigmas),
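
For context, here is a minimal standalone sketch of what the added `simple_scheduler` produces. It assumes `inner_model.sigmas` is an ascending 1-D tensor of the model's trained sigmas, as exposed by the k-diffusion wrappers this module works with; `FakeInnerModel` and the sigma range below are hypothetical stand-ins for illustration only.

```python
import torch

class FakeInnerModel:
    # Hypothetical stand-in for the wrapped model: an ascending table of trained sigmas.
    sigmas = torch.linspace(0.03, 14.6, steps=1000)

def simple_scheduler(n, sigma_min, sigma_max, inner_model, device):
    sigs = []
    ss = len(inner_model.sigmas) / n
    for x in range(n):
        # Stride backwards through the sigma table in evenly spaced steps,
        # so the schedule starts at (near) the largest trained sigma.
        sigs += [float(inner_model.sigmas[-(1 + int(x * ss))])]
    sigs += [0.0]  # terminal sigma so sampling ends at zero noise
    return torch.FloatTensor(sigs).to(device)

print(simple_scheduler(10, 0.03, 14.6, FakeInnerModel(), "cpu"))
# -> 11 values: 10 sigmas descending from ~14.6, followed by 0.0
```

Note that `sigma_min` and `sigma_max` are unused here; the schedule is taken directly from the model's own sigma table, which is why the registration sets `need_inner_model=True`.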