2525class TestMultiplexOp (OpTest ):
2626 def setUp (self ):
2727 self .op_type = "multiplex"
28- self .python_api = paddle .multiplex
2928 rows = 4
3029 index = np .arange (0 , rows ).astype ('int32' )
3130 np .random .shuffle (index )
@@ -46,25 +45,19 @@ def setUp(self):
4645 self .outputs = {'Out' : output }
4746
4847 def test_check_output (self ):
49- self .check_output (check_eager = True )
48+ self .check_output ()
5049
5150 def test_check_grad (self ):
52- self .check_grad (['x1' , 'x2' , 'x3' , 'x4' ], 'Out' , check_eager = True )
51+ self .check_grad (['x1' , 'x2' , 'x3' , 'x4' ], 'Out' )
5352
5453 def test_check_grad_ignore_x1 (self ):
55- self .check_grad (
56- ['x2' , 'x3' , 'x4' ], 'Out' , no_grad_set = set ('x1' ), check_eager = True )
54+ self .check_grad (['x2' , 'x3' , 'x4' ], 'Out' , no_grad_set = set ('x1' ))
5755
5856 def test_check_grad_ignore_x1_x2 (self ):
59- self .check_grad (
60- ['x3' , 'x4' ],
61- 'Out' ,
62- no_grad_set = set (['x1' , 'x2' ]),
63- check_eager = True )
57+ self .check_grad (['x3' , 'x4' ], 'Out' , no_grad_set = set (['x1' , 'x2' ]))
6458
6559 def test_check_grad_ignore_x3 (self ):
66- self .check_grad (
67- ['x1' , 'x2' , 'x4' ], 'Out' , no_grad_set = set ('x3' ), check_eager = True )
60+ self .check_grad (['x1' , 'x2' , 'x4' ], 'Out' , no_grad_set = set ('x3' ))
6861
6962
7063class TestMultiplexOpError (unittest .TestCase ):
@@ -111,8 +104,28 @@ def test_multiplex_dygraph(self):
111104 paddle .enable_static ()
112105
113106 def test_dygraph_final_state_api (self ):
114- with _test_eager_guard ():
115- self .test_multiplex_dygraph ()
107+ with fluid .dygraph .guard ():
108+ img1 = np .array ([[1 , 2 ], [3 , 4 ]]).astype (np .float32 )
109+ img2 = np .array ([[5 , 6 ], [7 , 8 ]]).astype (np .float32 )
110+ inputs = [paddle .to_tensor (img1 ), paddle .to_tensor (img2 )]
111+ index = paddle .to_tensor (np .array ([[1 ], [0 ]]).astype (np .int32 ))
112+ inputs [0 ].stop_gradient = False
113+ inputs [1 ].stop_gradient = False
114+ res = paddle .multiplex (inputs , index )
115+ res .backward ()
116+ with _test_eager_guard ():
117+ inputs_eager = [paddle .to_tensor (img1 ), paddle .to_tensor (img2 )]
118+ index_eager = paddle .to_tensor (
119+ np .array ([[1 ], [0 ]]).astype (np .int32 ))
120+ inputs_eager [0 ].stop_gradient = False
121+ inputs_eager [1 ].stop_gradient = False
122+ res_eager = paddle .multiplex (inputs_eager , index_eager )
123+ res_eager .backward ()
124+ self .assertEqual ((res .numpy () == res_eager .numpy ()).all (), True )
125+ self .assertEqual ((inputs [0 ].grad .numpy () ==
126+ inputs_eager [0 ].grad .numpy ()).all (), True )
127+ self .assertEqual ((inputs [1 ].grad .numpy () ==
128+ inputs_eager [1 ].grad .numpy ()).all (), True )
116129
117130
118131if __name__ == '__main__' :
0 commit comments