1818import numpy as np
1919import paddle
2020import paddle .static as static
21+ from paddle .fluid .framework import _test_eager_guard
2122
2223p_list_n_n = ("fro" , "nuc" , 1 , - 1 , np .inf , - np .inf )
2324p_list_m_n = (None , 2 , - 2 )
@@ -89,16 +90,21 @@ def test_out(self):
8990
9091
9192class API_TestDygraphCond (unittest .TestCase ):
92- def test_out (self ):
93+ def func_out (self ):
9394 paddle .disable_static ()
9495 # test calling results of 'cond' in dynamic mode
9596 x_list_n_n , x_list_m_n = gen_input ()
9697 test_dygraph_assert_true (self , x_list_n_n , p_list_n_n + p_list_m_n )
9798 test_dygraph_assert_true (self , x_list_m_n , p_list_m_n )
9899
100+ def test_out (self ):
101+ with _test_eager_guard ():
102+ self .func_out ()
103+ self .func_out ()
104+
99105
100106class TestCondAPIError (unittest .TestCase ):
101- def test_dygraph_api_error (self ):
107+ def func_dygraph_api_error (self ):
102108 paddle .disable_static ()
103109 # test raising errors when 'cond' is called in dygraph mode
104110 p_list_error = ('fro_' , '_nuc' , - 0.7 , 0 , 1.5 , 3 )
@@ -113,6 +119,11 @@ def test_dygraph_api_error(self):
113119 x_tensor = paddle .to_tensor (x )
114120 self .assertRaises (ValueError , paddle .linalg .cond , x_tensor , p )
115121
122+ def test_dygraph_api_error (self ):
123+ with _test_eager_guard ():
124+ self .func_dygraph_api_error ()
125+ self .func_dygraph_api_error ()
126+
116127 def test_static_api_error (self ):
117128 paddle .enable_static ()
118129 # test raising errors when 'cond' is called in static mode
@@ -149,13 +160,18 @@ def test_static_empty_input_error(self):
149160
150161
151162class TestCondEmptyTensorInput (unittest .TestCase ):
152- def test_dygraph_empty_tensor_input (self ):
163+ def func_dygraph_empty_tensor_input (self ):
153164 paddle .disable_static ()
154165 # test calling results of 'cond' when input is an empty tensor in dynamic mode
155166 x_list_n_n , x_list_m_n = gen_empty_input ()
156167 test_dygraph_assert_true (self , x_list_n_n , p_list_n_n + p_list_m_n )
157168 test_dygraph_assert_true (self , x_list_m_n , p_list_m_n )
158169
170+ def test_dygraph_empty_tensor_input (self ):
171+ with _test_eager_guard ():
172+ self .func_dygraph_empty_tensor_input ()
173+ self .func_dygraph_empty_tensor_input ()
174+
159175
160176if __name__ == "__main__" :
161177 paddle .enable_static ()
0 commit comments