@@ -386,10 +386,12 @@ def apply_to_static(net, use_cinn):
386386
387387
388388class  PrimeNet (paddle .nn .Layer ):
389-     def  __init__ (self , data_layout = 'NCHW' ):
389+     def  __init__ (self , data_layout = 'NCHW' ,  is_test = False ):
390390        super ().__init__ ()
391391        self .conv  =  nn .Conv2D (2 , 4 , (3 , 3 ), bias_attr = False )
392-         self .bn  =  BatchNorm (4 , act = "relu" , data_layout = data_layout )
392+         self .bn  =  BatchNorm (
393+             4 , act = "relu" , data_layout = data_layout , is_test = is_test 
394+         )
393395
394396    def  forward (self , x ):
395397        y  =  self .conv (x )
@@ -408,10 +410,10 @@ def setUp(self):
408410        self .x  =  paddle .randn ([4 , 2 , 6 , 6 ], dtype = "float32" )
409411        self .x .stop_gradient  =  False 
410412
411-     def  train (self , use_prim , data_layout = "NCHW" ):
413+     def  train (self , use_prim , data_layout = "NCHW" ,  is_test = False ):
412414        core ._set_prim_all_enabled (use_prim )
413415        paddle .seed (2022 )
414-         net  =  PrimeNet (data_layout )
416+         net  =  PrimeNet (data_layout = data_layout ,  is_test = is_test )
415417        sgd  =  paddle .optimizer .SGD (
416418            learning_rate = 0.1 , parameters = net .parameters ()
417419        )
@@ -429,8 +431,19 @@ def train(self, use_prim, data_layout="NCHW"):
429431
430432    def  test_amp_nchw (self ):
431433        if  not  isinstance (framework ._current_expected_place (), core .CPUPlace ):
432-             expected  =  self .train (False )
433-             actual  =  self .train (True )
434+             expected  =  self .train (use_prim = False )
435+             actual  =  self .train (use_prim = True )
436+             np .testing .assert_allclose (
437+                 expected ,
438+                 actual ,
439+                 rtol = 1e-3 ,
440+                 atol = 1e-3 ,
441+             )
442+ 
443+     def  test_amp_nchw_eval (self ):
444+         if  not  isinstance (framework ._current_expected_place (), core .CPUPlace ):
445+             expected  =  self .train (use_prim = False , is_test = True )
446+             actual  =  self .train (use_prim = True , is_test = True )
434447            np .testing .assert_allclose (
435448                expected ,
436449                actual ,
@@ -449,6 +462,19 @@ def test_amp_nhwc(self):
449462                atol = 1e-3 ,
450463            )
451464
465+     def  test_amp_nhwc_eval (self ):
466+         if  not  isinstance (framework ._current_expected_place (), core .CPUPlace ):
467+             expected  =  self .train (
468+                 use_prim = False , data_layout = "NHWC" , is_test = True 
469+             )
470+             actual  =  self .train (use_prim = True , data_layout = "NHWC" , is_test = True )
471+             np .testing .assert_allclose (
472+                 expected ,
473+                 actual ,
474+                 rtol = 1e-3 ,
475+                 atol = 1e-3 ,
476+             )
477+ 
452478
453479class  TestPrimEvalBranch (unittest .TestCase ):
454480    """ 
0 commit comments