@@ -431,5 +431,99 @@ def test(self):
431431 self .run_test ()
432432
433433
434+ class TrtConvertReshapeZeroDimsTest (TrtLayerAutoScanTest ):
435+ def is_program_valid (self , program_config : ProgramConfig ) -> bool :
436+ return True
437+
438+ def sample_program_configs (self ):
439+ def generate_input1 (attrs : List [Dict [str , Any ]]):
440+ if self .dims > 0 :
441+ self .input_shape = [1 ] * self .dims
442+ return np .random .random (self .input_shape ).astype (np .float32 )
443+ elif self .dims == 0 :
444+ self .input_shape = []
445+ return np .random .random ([]).astype (np .float32 )
446+
447+ for dims in [0 , 1 , 2 , 3 ]:
448+ for shape in [
449+ [],
450+ [1 , 1 ],
451+ ]:
452+ dics = [
453+ {
454+ "shape" : shape ,
455+ },
456+ ]
457+ self .dims = dims
458+ dics_intput = [{"X" : ["reshape_input" ]}]
459+
460+ ops_config = [
461+ {
462+ "op_type" : "reshape" ,
463+ "op_inputs" : dics_intput [0 ],
464+ "op_outputs" : {"Out" : ["reshape_out" ]},
465+ "op_attrs" : dics [0 ],
466+ }
467+ ]
468+ ops = self .generate_op_config (ops_config )
469+ program_config = ProgramConfig (
470+ ops = ops ,
471+ weights = {},
472+ inputs = {
473+ "reshape_input" : TensorConfig (
474+ data_gen = partial (generate_input1 , dics )
475+ )
476+ },
477+ outputs = ["reshape_out" ],
478+ )
479+
480+ yield program_config
481+
482+ def sample_predictor_configs (
483+ self , program_config
484+ ) -> (paddle_infer .Config , List [int ], float ):
485+ def generate_dynamic_shape (attrs ):
486+ self .dynamic_shape .min_input_shape = {
487+ "reshape_input" : self .input_shape
488+ }
489+ self .dynamic_shape .max_input_shape = {
490+ "reshape_input" : self .input_shape
491+ }
492+ self .dynamic_shape .opt_input_shape = {
493+ "reshape_input" : self .input_shape
494+ }
495+
496+ def clear_dynamic_shape ():
497+ self .dynamic_shape .min_input_shape = {}
498+ self .dynamic_shape .max_input_shape = {}
499+ self .dynamic_shape .opt_input_shape = {}
500+
501+ def generate_trt_nodes_num (attrs , dynamic_shape ):
502+ # only test dynamic shape mode
503+ return 1 , 2
504+
505+ attrs = [
506+ program_config .ops [i ].attrs for i in range (len (program_config .ops ))
507+ ]
508+
509+ # for dynamic_shape
510+ generate_dynamic_shape (attrs )
511+ self .trt_param .precision = paddle_infer .PrecisionType .Float32
512+ yield self .create_inference_config (), generate_trt_nodes_num (
513+ attrs , True
514+ ), 1e-5
515+ self .trt_param .precision = paddle_infer .PrecisionType .Half
516+ yield self .create_inference_config (), generate_trt_nodes_num (
517+ attrs , True
518+ ), 1e-3
519+
520+ def add_skip_trt_case (self ):
521+ pass
522+
523+ def test (self ):
524+ self .add_skip_trt_case ()
525+ self .run_test ()
526+
527+
434528if __name__ == "__main__" :
435529 unittest .main ()
0 commit comments