1313# See the License for the specific language governing permissions and 
1414# limitations under the License. 
1515
16+ import  inspect 
1617from  typing  import  Callable , List , Optional , Union 
1718
1819import  paddle 
2829)
2930from  ppdiffusers  import  (
3031    AutoencoderKL ,
32+     DDIMScheduler ,
3133    DiffusionPipeline ,
3234    LMSDiscreteScheduler ,
3335    PNDMScheduler ,
@@ -84,7 +86,7 @@ def __init__(
8486        clip_model : CLIPModel ,
8587        tokenizer : CLIPTokenizer ,
8688        unet : UNet2DConditionModel ,
87-         scheduler : Union [PNDMScheduler , LMSDiscreteScheduler ],
89+         scheduler : Union [PNDMScheduler , LMSDiscreteScheduler ,  DDIMScheduler ],
8890        feature_extractor : CLIPFeatureExtractor ,
8991    ):
9092        super ().__init__ ()
@@ -99,7 +101,12 @@ def __init__(
99101        )
100102
101103        self .normalize  =  transforms .Normalize (mean = feature_extractor .image_mean , std = feature_extractor .image_std )
102-         self .make_cutouts  =  MakeCutouts (feature_extractor .size )
104+         self .cut_out_size  =  (
105+             feature_extractor .size 
106+             if  isinstance (feature_extractor .size , int )
107+             else  feature_extractor .size ["shortest_edge" ]
108+         )
109+         self .make_cutouts  =  MakeCutouts (self .cut_out_size )
103110
104111        set_stop_gradient (self .text_encoder , True )
105112        set_stop_gradient (self .clip_model , True )
@@ -152,7 +159,7 @@ def cond_fn(
152159        # predict the noise residual 
153160        noise_pred  =  self .unet (latent_model_input , timestep , encoder_hidden_states = text_embeddings ).sample 
154161
155-         if  isinstance (self .scheduler , PNDMScheduler ):
162+         if  isinstance (self .scheduler , ( PNDMScheduler ,  DDIMScheduler ) ):
156163            alpha_prod_t  =  self .scheduler .alphas_cumprod [timestep ]
157164            beta_prod_t  =  1  -  alpha_prod_t 
158165            # compute predicted original sample from predicted noise also called 
@@ -174,7 +181,7 @@ def cond_fn(
174181        if  use_cutouts :
175182            image  =  self .make_cutouts (image , num_cutouts )
176183        else :
177-             resize_transform  =  transforms .Resize (self .feature_extractor . size )
184+             resize_transform  =  transforms .Resize (self .cut_out_size )
178185            image  =  paddle .stack ([resize_transform (img ) for  img  in  image ], axis = 0 )
179186        image  =  self .normalize (image ).astype (latents .dtype )
180187
@@ -207,11 +214,12 @@ def __call__(
207214        guidance_scale : Optional [float ] =  7.5 ,
208215        negative_prompt : Optional [Union [str , List [str ]]] =  None ,
209216        num_images_per_prompt : Optional [int ] =  1 ,
217+         eta : float  =  0.0 ,
210218        clip_guidance_scale : Optional [float ] =  100 ,
211219        clip_prompt : Optional [Union [str , List [str ]]] =  None ,
212220        num_cutouts : Optional [int ] =  4 ,
213221        use_cutouts : Optional [bool ] =  True ,
214-         seed : Optional [int ] =  None ,
222+         generator : Optional [paddle . Generator ] =  None ,
215223        latents : Optional [paddle .Tensor ] =  None ,
216224        output_type : Optional [str ] =  "pil" ,
217225        return_dict : bool  =  True ,
@@ -277,9 +285,9 @@ def __call__(
277285            text_embeddings_clip  =  self .clip_model .get_text_features (clip_text_input_ids )
278286            text_embeddings_clip  =  text_embeddings_clip  /  text_embeddings_clip .norm (p = 2 , axis = - 1 , keepdim = True )
279287            # duplicate text embeddings clip for each generation per prompt 
280-             bs_embed , seq_len ,  _  =  text_embeddings .shape 
281-             text_embeddings_clip  =  text_embeddings_clip .tile ([1 , num_images_per_prompt ,  1 ])
282-             text_embeddings_clip  =  text_embeddings_clip .reshape ([bs_embed  *  num_images_per_prompt , seq_len ,  - 1 ])
288+             bs_embed , _  =  text_embeddings_clip .shape 
289+             text_embeddings_clip  =  text_embeddings_clip .tile ([1 , num_images_per_prompt ])
290+             text_embeddings_clip  =  text_embeddings_clip .reshape ([bs_embed  *  num_images_per_prompt , - 1 ])
283291
284292        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 
285293        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 
@@ -334,8 +342,7 @@ def __call__(
334342        # However this currently doesn't work in `mps`. 
335343        latents_shape  =  [batch_size  *  num_images_per_prompt , self .unet .in_channels , height  //  8 , width  //  8 ]
336344        if  latents  is  None :
337-             paddle .seed (seed )
338-             latents  =  paddle .randn (latents_shape , dtype = text_embeddings .dtype )
345+             latents  =  paddle .randn (latents_shape , generator = generator , dtype = text_embeddings .dtype )
339346        else :
340347            if  latents .shape  !=  latents_shape :
341348                raise  ValueError (f"Unexpected latents shape, got { latents .shape }  , expected { latents_shape }  " )
@@ -350,6 +357,20 @@ def __call__(
350357        # scale the initial noise by the standard deviation required by the scheduler 
351358        latents  =  latents  *  self .scheduler .init_noise_sigma 
352359
360+         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 
361+         # eta (η) is only passed to schedulers whose `step` accepts it (e.g. the DDIMScheduler); other schedulers never receive it. 
362+         # eta corresponds to η in the DDIM paper: https://arxiv.org/abs/2010.02502 
363+         # and should be in the range [0, 1] 
364+         accepts_eta  =  "eta"  in  set (inspect .signature (self .scheduler .step ).parameters .keys ())
365+         extra_step_kwargs  =  {}
366+         if  accepts_eta :
367+             extra_step_kwargs ["eta" ] =  eta 
368+ 
369+         # check if the scheduler accepts generator 
370+         accepts_generator  =  "generator"  in  set (inspect .signature (self .scheduler .step ).parameters .keys ())
371+         if  accepts_generator :
372+             extra_step_kwargs ["generator" ] =  generator 
373+ 
353374        for  i , t  in  enumerate (self .progress_bar (timesteps_tensor )):
354375            # expand the latents if we are doing classifier free guidance 
355376            latent_model_input  =  paddle .concat ([latents ] *  2 ) if  do_classifier_free_guidance  else  latents 
@@ -381,7 +402,7 @@ def __call__(
381402                )
382403
383404            # compute the previous noisy sample x_t -> x_t-1 
384-             latents  =  self .scheduler .step (noise_pred , t , latents ).prev_sample 
405+             latents  =  self .scheduler .step (noise_pred , t , latents ,  ** extra_step_kwargs ).prev_sample 
385406
386407            # call the callback, if provided 
387408            if  callback  is  not   None  and  i  %  callback_steps  ==  0 :
0 commit comments