File tree Expand file tree Collapse file tree 1 file changed +12
-0
lines changed
python/paddle/distributed/passes Expand file tree Collapse file tree 1 file changed +12
-0
lines changed Original file line number Diff line number Diff line change @@ -332,6 +332,18 @@ def _pir_append_gradient_merge_backward_op(
332332 grad_defining_op .dist_attr .chunk_id ,
333333 )
334334 )
335+ # NOTE(zhangweilong): grad may in different device in auto_parallel, so need consider all_gather op
336+ for used_grad_op in grad .all_used_ops ():
337+ if used_grad_op .name () != "pd_op.all_gather" :
338+ continue
339+ move_to_opt_block_flag = True
340+ for all_gather_result in used_grad_op .results ():
341+ for used_op in all_gather_result .all_used_ops ():
342+ if used_op .op_role != int (OpRole .Optimize ):
343+ move_to_opt_block_flag = False
344+ break
345+ if move_to_opt_block_flag :
346+ used_grad_op .op_role = int (OpRole .Optimize )
335347
336348 opt_ops_use_grad = [
337349 op
You can’t perform that action at this time.
0 commit comments