Skip to content

Commit dc04bb7

Browse files
authored
[Auto Parallel] do not fold global to sub reshard (#70023)
1 parent 2fc83be commit dc04bb7

File tree

1 file changed

+13
-5
lines changed
  • python/paddle/distributed/auto_parallel/static

1 file changed

+13
-5
lines changed

python/paddle/distributed/auto_parallel/static/pir_pass.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,19 @@ def reshard_single_value(program, op, operand, attr):
6363
# fold reshard
6464
if prev_var.get_defining_op().name() == 'dist_op.reshard':
6565
prev_reshard = prev_var.get_defining_op()
66-
prev_var = prev_reshard.operand_source(0)
67-
if prev_var.dist_attr() == operand_attr:
68-
return prev_var
69-
reshard_var = paddle._C_ops.reshard_v2(prev_var, operand_attr)
70-
return reshard_var
66+
prev_reshard_input = prev_reshard.operand_source(0)
67+
prev_reshard_result = prev_reshard.result(0)
68+
# skip global to sub mesh reshard
69+
if (
70+
prev_reshard_input.dist_attr().process_mesh.ndim
71+
== prev_reshard_result.dist_attr().process_mesh.ndim
72+
):
73+
if prev_reshard_input.dist_attr() == operand_attr:
74+
return prev_reshard_input
75+
reshard_var = paddle._C_ops.reshard_v2(
76+
prev_reshard_input, operand_attr
77+
)
78+
return reshard_var
7179
# insert reshard
7280
reshard_var = paddle._C_ops.reshard_v2(prev_var, operand_attr)
7381
return reshard_var

0 commit comments

Comments
 (0)