Skip to content

Commit 32dae48

Browse files
add unitest for reshpe 0 dims (#53685)
1 parent 4a97ba5 commit 32dae48

File tree

1 file changed

+94
-0
lines changed

1 file changed

+94
-0
lines changed

test/ir/inference/test_trt_convert_reshape.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,5 +431,99 @@ def test(self):
431431
self.run_test()
432432

433433

434+
class TrtConvertReshapeZeroDimsTest(TrtLayerAutoScanTest):
435+
def is_program_valid(self, program_config: ProgramConfig) -> bool:
436+
return True
437+
438+
def sample_program_configs(self):
439+
def generate_input1(attrs: List[Dict[str, Any]]):
440+
if self.dims > 0:
441+
self.input_shape = [1] * self.dims
442+
return np.random.random(self.input_shape).astype(np.float32)
443+
elif self.dims == 0:
444+
self.input_shape = []
445+
return np.random.random([]).astype(np.float32)
446+
447+
for dims in [0, 1, 2, 3]:
448+
for shape in [
449+
[],
450+
[1, 1],
451+
]:
452+
dics = [
453+
{
454+
"shape": shape,
455+
},
456+
]
457+
self.dims = dims
458+
dics_intput = [{"X": ["reshape_input"]}]
459+
460+
ops_config = [
461+
{
462+
"op_type": "reshape",
463+
"op_inputs": dics_intput[0],
464+
"op_outputs": {"Out": ["reshape_out"]},
465+
"op_attrs": dics[0],
466+
}
467+
]
468+
ops = self.generate_op_config(ops_config)
469+
program_config = ProgramConfig(
470+
ops=ops,
471+
weights={},
472+
inputs={
473+
"reshape_input": TensorConfig(
474+
data_gen=partial(generate_input1, dics)
475+
)
476+
},
477+
outputs=["reshape_out"],
478+
)
479+
480+
yield program_config
481+
482+
def sample_predictor_configs(
483+
self, program_config
484+
) -> (paddle_infer.Config, List[int], float):
485+
def generate_dynamic_shape(attrs):
486+
self.dynamic_shape.min_input_shape = {
487+
"reshape_input": self.input_shape
488+
}
489+
self.dynamic_shape.max_input_shape = {
490+
"reshape_input": self.input_shape
491+
}
492+
self.dynamic_shape.opt_input_shape = {
493+
"reshape_input": self.input_shape
494+
}
495+
496+
def clear_dynamic_shape():
497+
self.dynamic_shape.min_input_shape = {}
498+
self.dynamic_shape.max_input_shape = {}
499+
self.dynamic_shape.opt_input_shape = {}
500+
501+
def generate_trt_nodes_num(attrs, dynamic_shape):
502+
# only test dynamic shape mode
503+
return 1, 2
504+
505+
attrs = [
506+
program_config.ops[i].attrs for i in range(len(program_config.ops))
507+
]
508+
509+
# for dynamic_shape
510+
generate_dynamic_shape(attrs)
511+
self.trt_param.precision = paddle_infer.PrecisionType.Float32
512+
yield self.create_inference_config(), generate_trt_nodes_num(
513+
attrs, True
514+
), 1e-5
515+
self.trt_param.precision = paddle_infer.PrecisionType.Half
516+
yield self.create_inference_config(), generate_trt_nodes_num(
517+
attrs, True
518+
), 1e-3
519+
520+
def add_skip_trt_case(self):
521+
pass
522+
523+
def test(self):
524+
self.add_skip_trt_case()
525+
self.run_test()
526+
527+
434528
if __name__ == "__main__":
435529
unittest.main()

0 commit comments

Comments
 (0)