Commit ab8d6b6
[AutoParallel] fix grad_merge bug (#68664)
1 parent f56e672 commit ab8d6b6

1 file changed: +12 additions, −0 deletions

python/paddle/distributed/passes/auto_parallel_gradient_merge.py

@@ -332,6 +332,18 @@ def _pir_append_gradient_merge_backward_op(
                 grad_defining_op.dist_attr.chunk_id,
             )
         )
+        # NOTE(zhangweilong): grad may be on a different device in auto_parallel,
+        # so the all_gather op on the gradient must be considered as well
+        for used_grad_op in grad.all_used_ops():
+            if used_grad_op.name() != "pd_op.all_gather":
+                continue
+            move_to_opt_block_flag = True
+            for all_gather_result in used_grad_op.results():
+                for used_op in all_gather_result.all_used_ops():
+                    if used_op.op_role != int(OpRole.Optimize):
+                        move_to_opt_block_flag = False
+                        break
+            if move_to_opt_block_flag:
+                used_grad_op.op_role = int(OpRole.Optimize)
 
     opt_ops_use_grad = [
         op
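
The intent of the added block: under auto parallelism a gradient may be consumed by a pd_op.all_gather that collects it from other devices, and if every consumer of every all_gather result already runs with the Optimize role, the all_gather itself can safely be moved into the optimizer block. Below is a minimal standalone sketch of that role-propagation check. The FakeValue/FakeOp classes, the OpRole values, and the "optimizer_update" op name are hypothetical stand-ins, not Paddle's real PIR types; only the names all_used_ops, results, name, op_role, and OpRole.Optimize come from the diff itself.

    from enum import IntEnum


    class OpRole(IntEnum):
        # Stand-in role values; Paddle's real enum differs.
        Backward = 1
        Optimize = 2


    class FakeValue:
        """Hypothetical stand-in for a PIR value: tracks the ops that use it."""

        def __init__(self):
            self.users = []

        def all_used_ops(self):
            return self.users


    class FakeOp:
        """Hypothetical stand-in for a PIR op: a name, a role, result values."""

        def __init__(self, name, op_role, results=()):
            self._name = name
            self.op_role = int(op_role)
            self._results = list(results)

        def name(self):
            return self._name

        def results(self):
            return self._results


    def move_all_gather_to_opt_block(grad_value):
        """Mirror of the added pass logic: promote an all_gather on the
        gradient to the Optimize role only if every consumer of every one
        of its results already runs with the Optimize role."""
        for used_grad_op in grad_value.all_used_ops():
            if used_grad_op.name() != "pd_op.all_gather":
                continue
            move_to_opt_block_flag = True
            for result in used_grad_op.results():
                for used_op in result.all_used_ops():
                    if used_op.op_role != int(OpRole.Optimize):
                        move_to_opt_block_flag = False
                        break
            if move_to_opt_block_flag:
                used_grad_op.op_role = int(OpRole.Optimize)


    # Usage: the all_gather's only consumer is an Optimize-role op, so its
    # own role is rewritten from Backward to Optimize.
    gathered = FakeValue()
    all_gather = FakeOp("pd_op.all_gather", OpRole.Backward, results=[gathered])
    gathered.users.append(FakeOp("optimizer_update", OpRole.Optimize))

    grad = FakeValue()
    grad.users.append(all_gather)

    move_all_gather_to_opt_block(grad)
    assert all_gather.op_role == int(OpRole.Optimize)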
