@@ -38,14 +38,18 @@ def __init__(self, block):
         self._op_fp16_dict = {
         }  # op_id --> True/False. 'True' means that the current op is in fp16 mode.
         self._var_name_dict = {}  # fwd_op_id --> {old_name: cast_name}
+        self.is_train = False
 
     def _is_fp16_op(self, op_id):
         return self._op_fp16_dict.get(op_id, None)
 
-    def _build_stats(self, amp_lists, dist_context):
+    def _build_state(self, amp_lists, dist_context):
         ops = self._block.ops
         dist_op_context = dist_context.dist_op_context
         for op in ops:
+            if int(op.attr('op_role')) == 257:
+                self.is_train = True
+
             if int(op.attr('op_role')) == int(OpRole.Forward):
                 self._mark_black_white_ops(amp_lists)
             elif int(op.attr('op_role')) == int(OpRole.Backward):
@@ -59,6 +63,8 @@ def _build_stats(self, amp_lists, dist_context):
             elif int(op.attr('op_role')) == int(OpRole.Optimize):
                 break
 
+        return self.is_train
+
     def _mark_black_white_ops(self, amp_lists):
         """
         this function is modified from paddle.fluid.contrib.mixed_precision
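
Side note (not part of the patch): the hard-coded `257` in `_build_state` is the `op_role` bitmask `OpRole.Backward | OpRole.Loss` (0x0001 | 0x0100), i.e. the loss-gradient op that only exists once a backward pass has been appended, so spotting it marks the program as a training program. A minimal sketch of that check, assuming `OpRole` is exposed via `paddle.fluid.core` as in other static-graph passes:

```python
# Sketch only: illustrates the op_role bitmask behind the `== 257` check above.
from paddle.fluid import core

OpRole = core.op_proto_and_checker_maker.OpRole

# Backward = 0x0001, Loss = 0x0100; the op that starts back-propagation
# from the loss carries both flags.
LOSS_GRAD_ROLE = int(OpRole.Backward) | int(OpRole.Loss)  # == 257


def block_is_train(block):
    """Return True if `block` contains a loss-gradient op (training program)."""
    return any(int(op.attr('op_role')) == LOSS_GRAD_ROLE for op in block.ops)
```
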
@@ -546,23 +552,25 @@ def _apply_single_impl(self, main_program, startup_program, context):
             set(self.get_attr("custom_black_list")),
             set(self.get_attr("custom_black_varnames")))
 
-        amp_state = AMPState(main_program.global_block())
-        amp_state._build_stats(amp_lists, self.dist_context)
-
         with paddle.static.program_guard(main_program, startup_program):
+            amp_state = AMPState(main_program.global_block())
+            is_train = amp_state._build_state(amp_lists, self.dist_context)
+
             amp_state.cast_forward_program(self.dist_context)
-            amp_state.cast_backward_program(params_grads, self.dist_context)
-            # TODO (JZ-LIANG)support cast forward program only when inference
-            self._init_amp_var()
-            self._scale_loss()
-
-            if self.get_attr("use_dynamic_loss_scaling"
-                             ) or self.get_attr("init_loss_scaling") != 1.0:
-                grads, found_inf = _check_and_update_gradient(
-                    params_grads, self._loss_scaling, self.dist_context)
-
-            if self.get_attr("use_dynamic_loss_scaling"):
-                self._update_loss_scaling(grads, found_inf)
+
+        if is_train:
+            with paddle.static.program_guard(main_program, startup_program):
+                amp_state.cast_backward_program(params_grads, self.dist_context)
+                self._init_amp_var()
+                self._scale_loss()
+
+                if self.get_attr("use_dynamic_loss_scaling"
+                                 ) or self.get_attr("init_loss_scaling") != 1.0:
+                    grads, found_inf = _check_and_update_gradient(
+                        params_grads, self._loss_scaling, self.dist_context)
+
+                if self.get_attr("use_dynamic_loss_scaling"):
+                    self._update_loss_scaling(grads, found_inf)
 
     def _init_amp_var(self):
         self._loss_scaling = paddle.static.create_global_var(
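
The restructuring above resolves the old TODO: `cast_forward_program` always runs, while backward casting, the loss-scaling variables and the gradient inf/nan check are now applied only when `is_train` is set. A hedged usage sketch of driving this pass follows; the registered pass name `auto_parallel_amp` and the config keys are assumptions based on the `get_attr` calls above, not taken from this diff, and `dist_context`/`params_grads` are supplied by the auto-parallel engine in real use:

```python
# Sketch only: how this pass is typically applied. The pass name
# "auto_parallel_amp" and the config keys are assumptions based on the
# get_attr() calls above, not on this diff.
from paddle.distributed.passes import new_pass, PassContext


def apply_amp_pass(main_program, startup_program, dist_context, params_grads):
    """params_grads is empty for an inference program (no backward ops)."""
    amp_config = {
        "dist_context": dist_context,
        "params_grads": params_grads,
        "custom_white_list": [],
        "custom_black_list": [],
        "custom_black_varnames": [],
        "init_loss_scaling": 32768.0,
        "use_dynamic_loss_scaling": True,
    }
    amp_pass = new_pass("auto_parallel_amp", amp_config)
    amp_pass.apply([main_program], [startup_program], PassContext())
    # An inference program has no op with role Backward | Loss, so
    # _build_state() leaves is_train False and only cast_forward_program runs.
```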