Skip to content

Conversation

@JunnYu
Copy link
Member

@JunnYu JunnYu commented Nov 15, 2022

PR types

New features

PR changes

Others

Description

https://mp.weixin.qq.com/s/vr5Pw6rc36PwQbP7j9vQYg

DPMSolverMultistepScheduler对齐

from ppdiffusers import DPMSolverMultistepScheduler as PPDPMSolverMultistepScheduler
from diffusers import DPMSolverMultistepScheduler as HFDPMSolverMultistepScheduler
import torch
import paddle
import random

init_kwargs = dict(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
timestep = 1000
device = "cuda"
batch_size = 16

ppeuler = PPDPMSolverMultistepScheduler(**init_kwargs)
hfeuler = HFDPMSolverMultistepScheduler(**init_kwargs)
def compare(a, b):
    dif = a.cpu().numpy() - b.cpu().numpy()
    meandif, maxdif = abs(dif).mean(), abs(dif).max()
    print(meandif, maxdif)

# test set_timesteps
ppeuler.set_timesteps(timestep)
hfeuler.set_timesteps(timestep)
compare(ppeuler.timesteps, hfeuler.timesteps)

# test add_noise
hflatents = torch.randn((batch_size, 4, 64, 64)).to(device)
hfnoise = torch.randn((batch_size, 4, 64, 64)).to(device)
hftimesteps = torch.randint(0, timestep, (batch_size, )).long().to(device)
pplatents = paddle.to_tensor(hflatents.cpu().numpy())
ppnoise = paddle.to_tensor(hfnoise.cpu().numpy())
pptimesteps = paddle.to_tensor(hftimesteps.cpu().numpy())

hfnoisy_latents = hfeuler.add_noise(hflatents, hfnoise, hftimesteps)
ppnoisy_latents = ppeuler.add_noise(pplatents, ppnoise, pptimesteps)

compare(hfnoisy_latents, ppnoisy_latents)

t = random.randint(0, 1000)
hft = torch.tensor(t, device=device)
ppt = paddle.to_tensor(t)
# test scale_model_input
hfscale_model_input = hfeuler.scale_model_input(hflatents, hft)
ppscale_model_input = ppeuler.scale_model_input(pplatents, ppt)

compare(hfscale_model_input, ppscale_model_input)

# test step
seed = random.randint(0, 2**32)
generator = torch.Generator(device).manual_seed(seed)
paddle.seed(seed)
hfstep = hfeuler.step(hfnoise, hft, hflatents)
ppstep = ppeuler.step(ppnoise, ppt, pplatents)
compare(hfstep[0], ppstep[0])
# 0.0 0.0
# 0.0 0.0
# 4.1596826e-07 7.6293945e-06
# 0.0 0.0
# 8.9845884e-08 7.1525574e-07
import torch
torch.set_grad_enabled(False)
from diffusers import StableDiffusionPipeline as HFStableDiffusionPipeline
import paddle
paddle.set_device("gpu")
paddle.set_grad_enabled(False)
from ppdiffusers import StableDiffusionPipeline as PPStableDiffusionPipeline
device = "cuda"
from ppdiffusers import DPMSolverMultistepScheduler as PPDPMSolverMultistepScheduler
from diffusers import DPMSolverMultistepScheduler as HFDPMSolverMultistepScheduler

hf_scheduler = HFDPMSolverMultistepScheduler.from_config(
    "CompVis/stable-diffusion-v1-4",  # or use the v1-5 version
    subfolder="scheduler",
    solver_order=2,
    predict_epsilon=True,
    thresholding=False,
    algorithm_type="dpmsolver++",
    solver_type="midpoint",
    denoise_final=True,  # the influence of this trick is effective for small (e.g. <=10) steps
)
pp_scheduler = PPDPMSolverMultistepScheduler.from_config(
    "CompVis/stable-diffusion-v1-4",  # or use the v1-5 version
    subfolder="scheduler",
    solver_order=2,
    predict_epsilon=True,
    thresholding=False,
    algorithm_type="dpmsolver++",
    solver_type="midpoint",
    denoise_final=True,  # the influence of this trick is effective for small (e.g. <=10) steps
)
hf_pipe = HFStableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=hf_scheduler).to(device)
pp_pipe = PPStableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=pp_scheduler)

seed = 42
generator = torch.Generator(device).manual_seed(seed)
img = hf_pipe("a pig", generator=generator, num_inference_steps=20)[0][0]
img.save("hf.png")

seed = 42
img = pp_pipe("a pig", seed=seed, num_inference_steps=20)[0][0]
img.save("pp.png")

hf

pp

@westfish
Copy link
Contributor

可不可以把DPMSolverMultistepScheduler的使用方法案例加入ppdiffusers,最好能有和StableDiffusion使用的普通PNDMScheduler对比的效果,或和其他scheduler对比的效果,我看这里只有和hf DPMSolverMultistepScheduler对比的效果

@JunnYu JunnYu merged commit 7a4355c into PaddlePaddle:develop Nov 16, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants