Skip to content

Commit ca849a7

Browse files
sbinneeyiyixuxu
authored andcommitted
Allow DDPMPipeline half precision (huggingface#9222)
Co-authored-by: YiYi Xu <[email protected]>
1 parent 482e7a2 commit ca849a7

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/diffusers/pipelines/ddpm/pipeline_ddpm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,10 @@ def __call__(
101101

102102
if self.device.type == "mps":
103103
# randn does not work reproducibly on mps
104-
image = randn_tensor(image_shape, generator=generator)
104+
image = randn_tensor(image_shape, generator=generator, dtype=self.unet.dtype)
105105
image = image.to(self.device)
106106
else:
107-
image = randn_tensor(image_shape, generator=generator, device=self.device)
107+
image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype)
108108

109109
# set step values
110110
self.scheduler.set_timesteps(num_inference_steps)

0 commit comments

Comments
 (0)