@@ -56,11 +56,21 @@ def parse_arguments():
5656 type = str ,
5757 default = "paddle" ,
5858 # Note(zhoushunjie): Will support 'tensorrt', 'paddle-tensorrt' soon.
59+ choices = ["onnx_runtime" , "paddle" , "paddlelite" ],
60+ help = "The inference runtime backend of unet model and text encoder model." ,
61+ )
62+ parser .add_argument (
63+ "--device" ,
64+ type = str ,
65+ default = "gpu" ,
66+ # Note(shentanyue): Will support more devices.
5967 choices = [
60- "onnx_runtime" ,
61- "paddle" ,
68+ "cpu" ,
69+ "gpu" ,
70+ "huawei_ascend_npu" ,
71+ "kunlunxin_xpu" ,
6272 ],
63- help = "The inference runtime backend of unet model and text encoder model ." ,
73+ help = "The inference runtime device of models ." ,
6474 )
6575 parser .add_argument (
6676 "--image_path" , default = "fd_astronaut_rides_horse.png" , help = "The model directory of diffusion_model."
@@ -123,6 +133,25 @@ def create_paddle_inference_runtime(
123133 return fd .Runtime (option )
124134
125135
def create_paddle_lite_runtime(model_dir, model_prefix, device="cpu", device_id=0):
    """Build a FastDeploy Runtime that uses the Paddle Lite backend.

    Args:
        model_dir: Root directory containing one sub-directory per model.
        model_prefix: Name of the model's sub-directory under ``model_dir``.
        device: Target device; "huawei_ascend_npu" gets NNAdapter wiring,
            "kunlunxin_xpu" is a placeholder, anything else (e.g. "cpu")
            uses the plain Lite backend.
        device_id: Ascend NPU card index; ignored for other devices.

    Returns:
        fd.Runtime: Runtime loaded from ``inference.pdmodel`` /
        ``inference.pdiparams`` inside the model sub-directory.
    """
    prefix_dir = os.path.join(model_dir, model_prefix)
    option = fd.RuntimeOption()
    option.use_lite_backend()
    if device == "huawei_ascend_npu":
        # Route execution to the Ascend NPU via the NNAdapter bridge and
        # cache the converted model next to the original one.
        option.use_cann()
        option.set_lite_nnadapter_device_names(["huawei_ascend_npu"])
        option.set_lite_nnadapter_model_cache_dir(prefix_dir)
        option.set_lite_nnadapter_context_properties(
            "HUAWEI_ASCEND_NPU_SELECTED_DEVICE_IDS={}".format(device_id)
        )
    elif device == "kunlunxin_xpu":
        # TODO(shentanyue): Add kunlunxin_xpu code
        pass
    option.set_model_path(
        os.path.join(prefix_dir, "inference.pdmodel"),
        os.path.join(prefix_dir, "inference.pdiparams"),
    )
    return fd.Runtime(option)
153+
154+
126155def create_trt_runtime (model_dir , model_prefix , model_format , workspace = (1 << 31 ), dynamic_shape = None , device_id = 0 ):
127156 option = fd .RuntimeOption ()
128157 option .use_trt_backend ()
@@ -210,42 +239,45 @@ def get_scheduler(args):
210239 }
211240
212241 # 4. Init runtime
242+ device_id = args .device_id
243+ if args .device == "cpu" :
244+ device_id = - 1
213245 if args .backend == "onnx_runtime" :
214246 text_encoder_runtime = create_ort_runtime (
215- args .model_dir , args .text_encoder_model_prefix , args .model_format , device_id = args . device_id
247+ args .model_dir , args .text_encoder_model_prefix , args .model_format , device_id = device_id
216248 )
217249 vae_decoder_runtime = create_ort_runtime (
218- args .model_dir , args .vae_decoder_model_prefix , args .model_format , device_id = args . device_id
250+ args .model_dir , args .vae_decoder_model_prefix , args .model_format , device_id = device_id
219251 )
220252 vae_encoder_runtime = create_ort_runtime (
221- args .model_dir , args .vae_encoder_model_prefix , args .model_format , device_id = args . device_id
253+ args .model_dir , args .vae_encoder_model_prefix , args .model_format , device_id = device_id
222254 )
223255 start = time .time ()
224256 unet_runtime = create_ort_runtime (
225- args .model_dir , args .unet_model_prefix , args .model_format , device_id = args . device_id
257+ args .model_dir , args .unet_model_prefix , args .model_format , device_id = device_id
226258 )
227259 print (f"Spend { time .time () - start : .2f} s to load unet model." )
228260 elif args .backend == "paddle" or args .backend == "paddle-tensorrt" :
229261 use_trt = True if args .backend == "paddle-tensorrt" else False
230262 # Note(zhoushunjie): Will change to paddle runtime later
231263 text_encoder_runtime = create_ort_runtime (
232- args .model_dir , args .text_encoder_model_prefix , args .model_format , device_id = args . device_id
264+ args .model_dir , args .text_encoder_model_prefix , args .model_format , device_id = device_id
233265 )
234266 vae_decoder_runtime = create_paddle_inference_runtime (
235267 args .model_dir ,
236268 args .vae_decoder_model_prefix ,
237269 use_trt ,
238270 vae_decoder_dynamic_shape ,
239271 use_fp16 = args .use_fp16 ,
240- device_id = args . device_id ,
272+ device_id = device_id ,
241273 )
242274 vae_encoder_runtime = create_paddle_inference_runtime (
243275 args .model_dir ,
244276 args .vae_encoder_model_prefix ,
245277 use_trt ,
246278 vae_encoder_dynamic_shape ,
247279 use_fp16 = args .use_fp16 ,
248- device_id = args . device_id ,
280+ device_id = device_id ,
249281 )
250282 start = time .time ()
251283 unet_runtime = create_paddle_inference_runtime (
@@ -254,7 +286,7 @@ def get_scheduler(args):
254286 use_trt ,
255287 unet_dynamic_shape ,
256288 use_fp16 = args .use_fp16 ,
257- device_id = args . device_id ,
289+ device_id = device_id ,
258290 )
259291 print (f"Spend { time .time () - start : .2f} s to load unet model." )
260292 elif args .backend == "tensorrt" :
@@ -265,23 +297,38 @@ def get_scheduler(args):
265297 args .model_format ,
266298 workspace = (1 << 30 ),
267299 dynamic_shape = vae_decoder_dynamic_shape ,
268- device_id = args . device_id ,
300+ device_id = device_id ,
269301 )
270302 vae_encoder_runtime = create_trt_runtime (
271303 args .model_dir ,
272304 args .vae_encoder_model_prefix ,
273305 args .model_format ,
274306 workspace = (1 << 30 ),
275307 dynamic_shape = vae_encoder_dynamic_shape ,
276- device_id = args . device_id ,
308+ device_id = device_id ,
277309 )
278310 start = time .time ()
279311 unet_runtime = create_trt_runtime (
280312 args .model_dir ,
281313 args .unet_model_prefix ,
282314 args .model_format ,
283315 dynamic_shape = unet_dynamic_shape ,
284- device_id = args .device_id ,
316+ device_id = device_id ,
317+ )
318+ print (f"Spend { time .time () - start : .2f} s to load unet model." )
319+ elif args .backend == "paddlelite" :
320+ text_encoder_runtime = create_paddle_lite_runtime (
321+ args .model_dir , args .text_encoder_model_prefix , device = args .device , device_id = device_id
322+ )
323+ vae_decoder_runtime = create_paddle_lite_runtime (
324+ args .model_dir , args .vae_decoder_model_prefix , device = args .device , device_id = device_id
325+ )
326+ vae_encoder_runtime = create_paddle_lite_runtime (
327+ args .model_dir , args .vae_encoder_model_prefix , device = args .device , device_id = device_id
328+ )
329+ start = time .time ()
330+ unet_runtime = create_paddle_lite_runtime (
331+ args .model_dir , args .unet_model_prefix , device = args .device , device_id = device_id
285332 )
286333 print (f"Spend { time .time () - start : .2f} s to load unet model." )
287334
0 commit comments