Skip to content

Commit 4c9d12d

Browse files
zhangbo9674 authored and zmxdream committed
Fix multi tensor momentum regular bug (PaddlePaddle#38344)
* fix merged_momentum regular bug * fix bug
1 parent 41c0f48 commit 4c9d12d

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

python/paddle/optimizer/momentum.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def __init__(self,
192192

193193
def _update_regularization(self, weight_decay):
194194
reg_method = ""
195-
reg_coeff = 0
195+
reg_coeff = 0.0
196196

197197
if (isinstance(weight_decay, L2DecayRegularizer)):
198198
reg_method = "l2_decay"
@@ -306,7 +306,7 @@ def _append_optimize_op(self, block, param_and_grad):
306306
# the param's regularization has been done before, we avoid do l2decay in momentum.
307307
elif param.regularizer is not None:
308308
regularization_method = ""
309-
regularization_coeff = 0
309+
regularization_coeff = 0.0
310310

311311
find_master = self._multi_precision and param_and_grad[
312312
0].dtype == core.VarDesc.VarType.FP16
@@ -380,7 +380,7 @@ def _multi_tensor_init(self, target_block, parameters):
380380
if isinstance(param.regularizer, L2DecayRegularizer):
381381
regularization_method = "l2_decay"
382382
regularization_coeff = param.regularizer._regularization_coeff
383-
else:
383+
elif param.regularizer is not None:
384384
regularization_method = ""
385385
regularization_coeff = 0.0
386386
if param.dtype == paddle.float32:

0 commit comments

Comments (0)