@@ -1398,7 +1398,7 @@ def compare_single_output_with_expect(self, name, expect):
13981398 # NOTE(zhiqiu): np.allclose([], [1.]) returns True
13991399 # see details: https://stackoverflow.com/questions/38331703/why-does-numpys-broadcasting-sometimes-allow-comparing-arrays-of-different-leng
14001400 if expect_np .size == 0 :
1401- self .op_test .assertTrue (actual_np .size == 0 ) # }}}
1401+ self .op_test .assertTrue (actual_np .size == 0 )
14021402 self ._compare_numpy (name , actual_np , expect_np )
14031403 if isinstance (expect , tuple ):
14041404 self ._compare_list (name , actual , expect )
@@ -1486,7 +1486,7 @@ def convert_uint16_to_float_ifneed(self, actual_np, expect_np):
14861486 if actual_np .dtype == np .uint16 :
14871487 actual_np = convert_uint16_to_float (actual_np )
14881488 if expect_np .dtype == np .uint16 :
1489- expect_np = convert_uint16_to_float (expect_np ) # }}}
1489+ expect_np = convert_uint16_to_float (expect_np )
14901490 return actual_np , expect_np
14911491
14921492 def _compare_list (self , name , actual , expect ):
@@ -1519,11 +1519,13 @@ def _compare_numpy(self, name, actual_np, expect_np):
15191519 class EagerChecker (DygraphChecker ):
15201520 def calculate_output (self ):
15211521 # we only check end2end api when check_eager=True
1522+ self .is_python_api_test = True
15221523 with _test_eager_guard ():
15231524 eager_dygraph_outs = self .op_test ._calc_python_api_output (
15241525 place )
15251526 if eager_dygraph_outs is None :
15261527 # missing KernelSignature, fall back to eager middle output.
1528+ self .is_python_api_test = False
15271529 eager_dygraph_outs = self .op_test ._calc_dygraph_output (
15281530 place , no_check_set = no_check_set )
15291531 self .outputs = eager_dygraph_outs
@@ -1547,9 +1549,16 @@ def _compare_list(self, name, actual, expect):
15471549 with _test_eager_guard ():
15481550 super ()._compare_list (name , actual , expect )
15491551
1550- # set some flags by the combination of arguments.
1552+ def _is_skip_name (self , name ):
1553+ # if in final state and kernel signature don't have name, then skip it.
1554+ if self .is_python_api_test and hasattr (
1555+ self .op_test , "python_out_sig"
1556+ ) and name not in self .op_test .python_out_sig :
1557+ return True
1558+ return super ()._is_skip_name (name )
15511559
1552- self .infer_dtype_from_inputs_outputs (self .inputs , self .outputs ) # {{{
1560+ # set some flags by the combination of arguments.
1561+ self .infer_dtype_from_inputs_outputs (self .inputs , self .outputs )
15531562 if self .dtype == np .float64 and \
15541563 self .op_type not in op_threshold_white_list .NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST :
15551564 atol = 0
@@ -1569,8 +1578,7 @@ def _compare_list(self, name, actual, expect):
15691578 if no_check_set is not None :
15701579 if self .op_type not in no_check_set_white_list .no_check_set_white_list :
15711580 raise AssertionError (
1572- "no_check_set of op %s must be set to None." %
1573- self .op_type ) # }}}
1581+ "no_check_set of op %s must be set to None." % self .op_type )
15741582 static_checker = StaticChecker (self , self .outputs )
15751583 static_checker .check ()
15761584 outs , fetch_list = static_checker .outputs , static_checker .fetch_list
@@ -1610,8 +1618,6 @@ def _compare_list(self, name, actual, expect):
16101618 else :
16111619 return outs , fetch_list
16121620
1613- # }}}
1614-
16151621 def check_compile_vs_runtime (self , fetch_list , fetch_outs ):
16161622 def find_fetch_index (target_name , fetch_list ):
16171623 found = [
0 commit comments