Skip to content

Commit f8a219f

Browse files
authored
[Paddle TensorRT] Modify the base class for enabling FP16 unit tests (#70139)
* fix * fix * fix
1 parent 5c2ded5 commit f8a219f

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

test/tensorrt/tensorrt_test_base.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ def __init__(self, methodName='runTest'):
4242
self.max_shape = None
4343
self.target_marker_op = ""
4444
self.dynamic_shape_data = {}
45-
self.enable_fp16 = None
4645

4746
def create_fake_program(self):
4847
if self.python_api is None:
@@ -155,7 +154,7 @@ def prepare_feed(self):
155154
new_list_args[sub_arg_name] = self.api_args[arg_name][i]
156155
self.api_args[arg_name] = new_list_args
157156

158-
def check_trt_result(self, rtol=1e-5, atol=1e-5):
157+
def check_trt_result(self, rtol=1e-5, atol=1e-5, precision_mode="fp32"):
159158
paddle.framework.set_flags({"FLAGS_trt_min_group_size": 1})
160159
with paddle.pir_utils.IrGuard():
161160
self.prepare_feed()
@@ -255,7 +254,7 @@ def check_trt_result(self, rtol=1e-5, atol=1e-5):
255254

256255
# run TRTConverter(would lower group_op into tensorrt_engine_op)
257256
trt_config = None
258-
if self.enable_fp16:
257+
if precision_mode == "fp16":
259258
input = Input(
260259
min_input_shape=self.min_shape,
261260
optim_input_shape=self.min_shape,

test/tensorrt/test_converter_conv.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,11 @@ def setUp(self):
4040
self.program_config = {"feed_list": ["x"]}
4141
self.min_shape = {"x": [1, 3, 8, 8]}
4242
self.max_shape = {"x": [10, 3, 8, 8]}
43-
self.enable_fp16 = True
4443

45-
def test_trt_result(self):
44+
def test_trt_result_fp16(self):
45+
self.check_trt_result(precision_mode="fp16")
46+
47+
def test_trt_result_fp32(self):
4648
self.check_trt_result()
4749

4850

0 commit comments

Comments
 (0)