2 changes: 1 addition & 1 deletion python/paddle/distributed/auto_parallel/static/engine.py
@@ -837,7 +837,7 @@ def _parallel_pir(self, mode):
         # TODO(hitywt) Step 3.2: Reshard Pass
         # resolve the reshard op into a special collective operation.
         # collect the communicators created during resolution.
-        ReshardPasses.apply_reshard_pass(dist_program)
+        ReshardPasses.apply_reshard_pass(dist_program, global_params_grads)

         # Note(luchang): When using the VPP pipeline pass, we need to split the whole graph into
         # multiple chunks and adjust the process mesh accordingly. Here, we need to store the
18 changes: 15 additions & 3 deletions python/paddle/distributed/auto_parallel/static/pir_pass.py
@@ -303,7 +303,7 @@ def fold_reshard_pass(dist_program):
             op.erase()

     @staticmethod
-    def reshard_op_pass(dist_program, block=None):
+    def reshard_op_pass(dist_program, global_params_grads=None, block=None):
         if block is None:
             block = dist_program.global_block()
         for op in block.ops:
@@ -322,6 +322,10 @@ def reshard_op_pass(dist_program, block=None):

                 if src_dist_attr == dst_dist_attr:
                     op.result(0).replace_all_uses_with(var)
+                    if global_params_grads is not None:
+                        for idx, (p, g) in enumerate(global_params_grads):
+                            if g is not None and g.is_same(op.result(0)):
+                                global_params_grads[idx] = (p, var)
                     op.erase()
                     continue

@@ -345,13 +349,21 @@ def reshard_op_pass(dist_program, block=None):
                     op.result(0).replace_all_uses_with(out_value)

                 if op.result(0).use_empty():
+                    if global_params_grads is not None:
+                        for idx, (p, g) in enumerate(global_params_grads):
+                            if g is not None and g.is_same(op.result(0)):
+                                global_params_grads[idx] = (
+                                    (p, out_value)
+                                    if out_value is not None
+                                    else (p, var)
+                                )
                     op.erase()

     @staticmethod
-    def apply_reshard_pass(dist_program):
+    def apply_reshard_pass(dist_program, global_params_grads=None):
         ReshardPasses.decompose_reshard_pass(dist_program)
         ReshardPasses.fold_reshard_pass(dist_program)
-        ReshardPasses.reshard_op_pass(dist_program)
+        ReshardPasses.reshard_op_pass(dist_program, global_params_grads)


 # Replace the specific MoE-related dist op with the
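The change above threads the training-time (parameter, gradient) list through the reshard pass so that, when a `dist_op.reshard` result is erased, any gradient entry aliasing that result is retargeted to its replacement: the original `var` when source and destination dist attrs already match, otherwise the resharded `out_value`. The sketch below illustrates only that bookkeeping step in isolation; `Value` and `retarget_grads` are hypothetical stand-ins for PIR values and the in-place update inside `reshard_op_pass`, not part of Paddle's API.

# A minimal, self-contained sketch (assumed names: Value, retarget_grads)
# of the bookkeeping reshard_op_pass performs on global_params_grads.

class Value:
    """Stand-in for a PIR SSA value; real PIR values expose is_same()."""

    def __init__(self, name):
        self.name = name

    def is_same(self, other):
        # Object identity is a fair stand-in for PIR's result-identity check.
        return self is other


def retarget_grads(global_params_grads, old_result, replacement):
    # Mirror the in-place update: rewrite pairs whose gradient aliases the
    # soon-to-be-erased reshard result, so callers holding a reference to
    # the same list observe live values afterwards.
    for idx, (p, g) in enumerate(global_params_grads):
        if g is not None and g.is_same(old_result):
            global_params_grads[idx] = (p, replacement)


# Usage: w's gradient was produced by a reshard op that is about to be
# erased and replaced by the resharded value.
w = Value("w")
old_grad = Value("w_grad")
resharded = Value("w_grad@resharded")
params_grads = [(w, old_grad)]
retarget_grads(params_grads, old_grad, resharded)
assert params_grads[0] == (w, resharded)

Mutating the list in place, rather than returning a new one, is what lets the engine keep using the same `global_params_grads` object in later stages after `apply_reshard_pass` returns.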