Skip to content

Commit 9fee57d

Browse files
authored
[AutoParallel] use inplace multiply in grad_clip (PaddlePaddle#71421)
* use inplace multiply in grad_clip * use inplace multiply in grad_clip
1 parent 5911a3a commit 9fee57d

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

python/paddle/nn/clip.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -840,8 +840,12 @@ def async_add_n(var_list):
840840
clip_input = paddle.distributed.reshard(
841841
clip_input, g.process_mesh, clip_input.placements
842842
)
843-
new_grad = paddle.multiply(g, clip_input)
844-
params_and_grads.append((p, new_grad))
843+
if g.is_dist() or g.is_dense():
844+
g.multiply_(clip_input)
845+
params_and_grads.append((p, g))
846+
else:
847+
new_grad = paddle.multiply(g, clip_input)
848+
params_and_grads.append((p, new_grad))
845849
else:
846850
params_and_grads.append((p, g))
847851

0 commit comments

Comments
 (0)