File tree Expand file tree Collapse file tree 2 files changed +13
-6
lines changed Expand file tree Collapse file tree 2 files changed +13
-6
lines changed Original file line number Diff line number Diff line change @@ -120,11 +120,12 @@ def get_parameter(self, name):
120120 raise ValueError ("no Parameter name %s found" % name )
121121 return param
122122
123- def create_tmp_variable (self , dtype ):
123+ def create_tmp_variable (self , dtype , stop_gradient = False ):
124124 return self .main_program .current_block ().create_var (
125125 name = unique_name ("." .join ([self .name , 'tmp' ])),
126126 dtype = dtype ,
127- persistable = False )
127+ persistable = False ,
128+ stop_gradient = stop_gradient )
128129
129130 def create_variable (self , * args , ** kwargs ):
130131 return self .main_program .current_block ().create_var (* args , ** kwargs )
Original file line number Diff line number Diff line change @@ -971,20 +971,26 @@ def batch_norm(input,
971971 attr = helper .param_attr , shape = param_shape , dtype = dtype , is_bias = True )
972972
973973 mean = helper .create_global_variable (
974- dtype = input .dtype , shape = param_shape , persistable = True )
974+ dtype = input .dtype ,
975+ shape = param_shape ,
976+ persistable = True ,
977+ stop_gradient = True )
975978 helper .set_variable_initializer (var = mean , initializer = Constant (0.0 ))
976979
977980 variance = helper .create_global_variable (
978- dtype = input .dtype , shape = param_shape , persistable = True )
981+ dtype = input .dtype ,
982+ shape = param_shape ,
983+ persistable = True ,
984+ stop_gradient = True )
979985 helper .set_variable_initializer (var = variance , initializer = Constant (1.0 ))
980986
981987 # create output
982988 # mean and mean_out share the same memory
983989 mean_out = mean
984990 # variance and variance out share the same memory
985991 variance_out = variance
986- saved_mean = helper .create_tmp_variable (dtype )
987- saved_variance = helper .create_tmp_variable (dtype )
992+ saved_mean = helper .create_tmp_variable (dtype = dtype , stop_gradient = True )
993+ saved_variance = helper .create_tmp_variable (dtype = dtype , stop_gradient = True )
988994
989995 batch_norm_out = helper .create_tmp_variable (dtype )
990996
You can’t perform that action at this time.
0 commit comments