Skip to content

Commit 04325d2

Browse files
authored
Optest refactor (#40998)
* first version, maybe many errors * refactor op_test * fix compare list * fix bg * fix bugs * skip name
1 parent 45078d9 commit 04325d2

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

python/paddle/fluid/tests/unittests/op_test.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1398,7 +1398,7 @@ def compare_single_output_with_expect(self, name, expect):
13981398
# NOTE(zhiqiu): np.allclose([], [1.]) returns True
13991399
# see details: https://stackoverflow.com/questions/38331703/why-does-numpys-broadcasting-sometimes-allow-comparing-arrays-of-different-leng
14001400
if expect_np.size == 0:
1401-
self.op_test.assertTrue(actual_np.size == 0) # }}}
1401+
self.op_test.assertTrue(actual_np.size == 0)
14021402
self._compare_numpy(name, actual_np, expect_np)
14031403
if isinstance(expect, tuple):
14041404
self._compare_list(name, actual, expect)
@@ -1486,7 +1486,7 @@ def convert_uint16_to_float_ifneed(self, actual_np, expect_np):
14861486
if actual_np.dtype == np.uint16:
14871487
actual_np = convert_uint16_to_float(actual_np)
14881488
if expect_np.dtype == np.uint16:
1489-
expect_np = convert_uint16_to_float(expect_np) # }}}
1489+
expect_np = convert_uint16_to_float(expect_np)
14901490
return actual_np, expect_np
14911491

14921492
def _compare_list(self, name, actual, expect):
@@ -1519,11 +1519,13 @@ def _compare_numpy(self, name, actual_np, expect_np):
15191519
class EagerChecker(DygraphChecker):
15201520
def calculate_output(self):
15211521
# we only check end2end api when check_eager=True
1522+
self.is_python_api_test = True
15221523
with _test_eager_guard():
15231524
eager_dygraph_outs = self.op_test._calc_python_api_output(
15241525
place)
15251526
if eager_dygraph_outs is None:
15261527
# missing KernelSignature, fall back to eager middle output.
1528+
self.is_python_api_test = False
15271529
eager_dygraph_outs = self.op_test._calc_dygraph_output(
15281530
place, no_check_set=no_check_set)
15291531
self.outputs = eager_dygraph_outs
@@ -1547,9 +1549,16 @@ def _compare_list(self, name, actual, expect):
15471549
with _test_eager_guard():
15481550
super()._compare_list(name, actual, expect)
15491551

1550-
# set some flags by the combination of arguments.
1552+
def _is_skip_name(self, name):
1553+
# if in final state and kernel signature don't have name, then skip it.
1554+
if self.is_python_api_test and hasattr(
1555+
self.op_test, "python_out_sig"
1556+
) and name not in self.op_test.python_out_sig:
1557+
return True
1558+
return super()._is_skip_name(name)
15511559

1552-
self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs) # {{{
1560+
# set some flags by the combination of arguments.
1561+
self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
15531562
if self.dtype == np.float64 and \
15541563
self.op_type not in op_threshold_white_list.NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST:
15551564
atol = 0
@@ -1569,8 +1578,7 @@ def _compare_list(self, name, actual, expect):
15691578
if no_check_set is not None:
15701579
if self.op_type not in no_check_set_white_list.no_check_set_white_list:
15711580
raise AssertionError(
1572-
"no_check_set of op %s must be set to None." %
1573-
self.op_type) # }}}
1581+
"no_check_set of op %s must be set to None." % self.op_type)
15741582
static_checker = StaticChecker(self, self.outputs)
15751583
static_checker.check()
15761584
outs, fetch_list = static_checker.outputs, static_checker.fetch_list
@@ -1610,8 +1618,6 @@ def _compare_list(self, name, actual, expect):
16101618
else:
16111619
return outs, fetch_list
16121620

1613-
# }}}
1614-
16151621
def check_compile_vs_runtime(self, fetch_list, fetch_outs):
16161622
def find_fetch_index(target_name, fetch_list):
16171623
found = [

0 commit comments

Comments
 (0)