1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/operator/utils/utils.cc
@@ -39,6 +39,7 @@ const std::unordered_set<std::string> LegacyOpList = {
LoadCombineOp::name(),
CConcatOp::name(),
CBroadcast_Op::name(),
CBroadcastOp::name(),
CSyncCommStream_Op::name(),
DistributedPushSparseOp::name(),
SendV2Op::name(),
2 changes: 2 additions & 0 deletions paddle/phi/ops/yaml/ops.yaml
@@ -1201,6 +1201,7 @@
output : Tensor(out)
infer_meta :
func : CumScalarAxisInferMeta
spmd_rule : CumSumInferSpmdDynamic
kernel :
func : cumsum
data_type : x
@@ -3511,6 +3512,7 @@
output : Tensor(out)
infer_meta :
func : OneHotInferMeta
spmd_rule : OneHotInferSpmdDynamic
kernel :
func : one_hot
traits : paddle::dialect::ForwardOnlyTrait
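These two yaml entries attach dynamic SPMD rules (CumSumInferSpmdDynamic, OneHotInferSpmdDynamic) to cumsum and one_hot, so placements can be inferred for these ops under dynamic-mode auto parallel. A minimal usage sketch, assuming a 1-D mesh of two ranks started via python -m paddle.distributed.launch (mesh size and shapes are illustrative only):

import paddle
import paddle.distributed as dist

# assumes a 2-rank job launched with `python -m paddle.distributed.launch`
mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
x = dist.shard_tensor(paddle.arange(8, dtype="float32"), mesh, [dist.Shard(0)])
labels = dist.shard_tensor(paddle.arange(8, dtype="int64"), mesh, [dist.Shard(0)])

# with the new spmd rules registered, these calls should return dist tensors
# whose output placements are inferred from the sharded inputs
y = paddle.cumsum(x, axis=0)
onehot = paddle.nn.functional.one_hot(labels, num_classes=8)
print(y.placements, onehot.placements)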
4 changes: 2 additions & 2 deletions python/paddle/base/executor.py
@@ -1232,8 +1232,8 @@ def _get_pir_program_and_executor(self, cached_data):
for job_type in cached_data.plan.job_types():
ir_program = cached_data.plan.ir_program(job_type)
value_map = pir.IrMapping()
program = ir_program.clone(value_map)
type_to_program[job_type] = program
program_tmp = ir_program.clone(value_map)
type_to_program[job_type] = program_tmp
value_map_list.append(value_map)

job_list = []
122 changes: 74 additions & 48 deletions python/paddle/distributed/auto_parallel/api.py
@@ -329,6 +329,7 @@ def forward(
ctx,
local_tensor_list,
local_mesh_list,
local_placements,
idx,
global_dims,
mesh,
@@ -338,17 +339,15 @@
if local_tensor.is_dist():
local_mesh = local_tensor.process_mesh
local_val = local_tensor._local_value()
local_placement = local_tensor.placements[0]
else:
local_val = local_tensor
local_mesh = None
local_placement = dist.Replicate()

ctx.global_mesh = copy.deepcopy(mesh)
ctx.placements = placements
ctx.local_dims = local_tensor.shape
ctx.local_mesh_list = copy.deepcopy(local_mesh_list)
ctx.local_placement = local_placement
ctx.local_placements = local_placements

place = paddle.framework._current_expected_place()
place = paddle.framework._get_paddle_place(place)
@@ -360,7 +359,7 @@
placements=placements,
place=place,
)
global_tensor.stop_gradient = False
global_tensor.stop_gradient = local_tensor.stop_gradient
return global_tensor

@staticmethod
@@ -377,91 +376,111 @@ def backward(ctx, grad_tensor):
grad_tensor._local_value(),
dims=ctx.local_dims,
process_mesh=local_mesh,
placements=[ctx.local_placement],
placements=ctx.local_placements,
place=place,
)
)
out[-1].get_tensor()._unsafe_set_skip_check_mesh(True)
return out


def get_sub_meshes_from_global_mesh(
global_mesh, global_placements, local_mesh_dim
):
if (
global_mesh is not None
and local_mesh_dim is not None
and global_placements is not None
def split_mesh(global_mesh: dist.ProcessMesh, sub_mesh_dim: int):
mesh_shape = global_mesh.shape
mesh_ndim = len(mesh_shape)
if sub_mesh_dim >= mesh_ndim or (
sub_mesh_dim < 0 and -sub_mesh_dim > mesh_ndim
):
mesh_shape = global_mesh.shape
mesh_ndim = len(mesh_shape)
if local_mesh_dim >= mesh_ndim or (
local_mesh_dim < 0 and -local_mesh_dim > mesh_ndim
):
raise ValueError(
f"The local_mesh_dim should between (-{mesh_ndim}, {mesh_ndim}]"
)
if local_mesh_dim < 0:
local_mesh_dim += mesh_ndim
else:
raise ValueError(
"the args global_mesh, global_placements and local_mesh_dim should all be set."
f"The sub_mesh_dim should be between (-{mesh_ndim}, {mesh_ndim}]"
)
if sub_mesh_dim < 0:
sub_mesh_dim += mesh_ndim

process_ids = np.array(global_mesh.process_ids).reshape(mesh_shape)
splitted_process_ids = np.split(
process_ids, mesh_shape[local_mesh_dim], axis=local_mesh_dim
process_ids, mesh_shape[sub_mesh_dim], axis=sub_mesh_dim
)
local_mesh_list = []
for process_ids in splitted_process_ids:
local_mesh_list.append(dist.ProcessMesh(process_ids))
sub_mesh_list = []
for sub_process_ids in splitted_process_ids:
sub_mesh_list.append(dist.ProcessMesh(sub_process_ids))

return sub_mesh_list
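As a cross-check, the index bookkeeping in split_mesh can be reproduced with plain numpy. This is an illustrative sketch assuming a hypothetical 2x4 mesh; the dist.ProcessMesh construction is omitted:

import numpy as np

# hypothetical 2x4 mesh: ranks 0..7 laid out as [[0,1,2,3],[4,5,6,7]]
mesh_shape = [2, 4]
process_ids = np.arange(8).reshape(mesh_shape)

sub_mesh_dim = 0
# one sub-mesh per index along sub_mesh_dim; splitting along dim 0 gives
# [[0, 1, 2, 3]] and [[4, 5, 6, 7]], each of shape (1, 4)
for sub in np.split(process_ids, mesh_shape[sub_mesh_dim], axis=sub_mesh_dim):
    print(sub.tolist())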


def _get_sub_meshes_and_local_placements(
global_mesh, global_placements, sub_mesh_dim
):
if global_mesh is None or sub_mesh_dim is None or global_placements is None:
raise ValueError(
"the args global_mesh, global_placements and local_mesh_dim should all be set."
)

sub_mesh_list = split_mesh(global_mesh, sub_mesh_dim)

local_placements = list(global_placements)
local_placements.pop(local_mesh_dim)
if local_placements == []:
local_placements.append(dist.Replicate())
return local_mesh_list, local_placements
if sub_mesh_dim < len(local_placements):
local_placements[sub_mesh_dim] = dist.Replicate()

return sub_mesh_list, local_placements
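The placement adjustment above can also be checked in isolation. In this sketch plain strings stand in for dist.Shard / dist.Replicate objects, and the two-entry placement list is an assumption:

# one placement per mesh dim (strings standing in for dist.Shard / dist.Replicate)
global_placements = ["Shard(0)", "Shard(1)"]
sub_mesh_dim = 0

local_placements = list(global_placements)
if sub_mesh_dim < len(local_placements):
    # the split mesh dim collapses in the sub-mesh view, so its placement becomes replicate
    local_placements[sub_mesh_dim] = "Replicate()"
print(local_placements)  # ['Replicate()', 'Shard(1)']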


def cal_global_shape(local_shape, mesh, placements):
# assume each rank has the same tensor shape for now,
# just use the local shape to calculate the global shape
global_shape = list(local_shape)
for idx, placement in enumerate(placements):
if placement.is_shard():
shard_dim = placement.get_dim()
local_dim_size = global_shape[shard_dim]
global_shape[shard_dim] = local_dim_size * mesh.shape[idx]
return global_shape
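The shape arithmetic in cal_global_shape is easy to verify standalone. The stub classes below only mirror the is_shard()/get_dim() calls used above and are not the real dist.Shard / dist.Replicate:

# minimal stand-ins for the placement API that cal_global_shape relies on
class _Shard:
    def __init__(self, dim):
        self.dim = dim

    def is_shard(self):
        return True

    def get_dim(self):
        return self.dim


class _Replicate:
    def is_shard(self):
        return False


def _cal_global_shape(local_shape, mesh_shape, placements):
    global_shape = list(local_shape)
    for idx, placement in enumerate(placements):
        if placement.is_shard():
            # a tensor dim sharded over mesh dim idx is mesh_shape[idx] times larger globally
            global_shape[placement.get_dim()] *= mesh_shape[idx]
    return global_shape


# local shape [4, 8] on a 2x4 mesh, tensor dim 0 sharded over mesh dim 1 -> global [16, 8]
print(_cal_global_shape([4, 8], [2, 4], [_Replicate(), _Shard(0)]))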


def moe_global_mesh_tensor(
local_tensor_list, mesh, placements, local_mesh_dim=-1
):
# assume the each rank has the same tensor shape for now, just use the local shape to calculate the global shape
local_mesh_list, local_placements = get_sub_meshes_from_global_mesh(
local_mesh_list, local_placements = _get_sub_meshes_and_local_placements(
mesh, placements, local_mesh_dim
)

local_tensor_idx = mesh.process_ids.index(dist.get_rank())
process_ids = np.array(mesh.process_ids).reshape(mesh.shape)
local_coord = np.where(process_ids == dist.get_rank())
local_tensor_idx = local_coord[local_mesh_dim][0]
# local_tensor_idx = mesh.process_ids.index(dist.get_rank())
local_tensor = local_tensor_list[local_tensor_idx]
global_dims = list(local_tensor.shape)
for idx, placement in enumerate(placements):
if placement.is_shard():
shard_dim = placement.get_dim()
local_dim_size = global_dims[shard_dim]
global_dims[shard_dim] = local_dim_size * mesh.shape[idx]

if paddle.in_dynamic_mode():
global_dims = cal_global_shape(
local_tensor._local_value().shape, mesh, placements
)
resharded_local_tensor_list = []
for i, tensor in enumerate(local_tensor_list):
tensor.get_tensor()._unsafe_set_skip_check_mesh(True)
if (
tensor.placements != local_placements
not check_placements_equal(tensor.placements, local_placements)
or tensor.process_mesh != local_mesh_list[i]
):
resharded_local_tensor_list.append(
reshard(tensor, local_mesh_list[i], local_placements)
)
resharded_local_tensor_list[
-1
].get_tensor()._unsafe_set_skip_check_mesh(True)
else:
resharded_local_tensor_list.append(tensor)

return _moe_global_mesh_tensor.apply(
resharded_local_tensor_list,
local_mesh_list,
local_placements,
local_tensor_idx,
global_dims,
mesh,
placements,
)
elif paddle.framework.in_pir_mode():
global_dims = cal_global_shape(
local_tensor._local_shape, mesh, placements
)
dist_tensor = paddle._C_ops.moe_global_mesh_tensor(
local_tensor_list,
local_mesh_list,
@@ -487,11 +506,13 @@ def forward(
dist_tensor,
local_mesh_list=None,
local_placements=None,
local_mesh_dim=None,
global_mesh=None,
global_placements=None,
):
ctx.local_mesh_list = copy.deepcopy(local_mesh_list)
ctx.local_placements = local_placements
ctx.local_mesh_dim = local_mesh_dim
ctx.global_mesh = copy.deepcopy(global_mesh)
ctx.global_placements = global_placements
ctx.global_shape = dist_tensor.shape
@@ -532,20 +553,24 @@ def forward(
place=place,
)
local_tensor.get_tensor()._unsafe_set_skip_check_mesh(True)
local_tensor.stop_gradient = False
local_tensor.stop_gradient = dist_tensor.stop_gradient
local_tensor_list.append(local_tensor)
return local_tensor_list

@staticmethod
def backward(ctx, *grad_tensor):
place = paddle.framework._current_expected_place()
place = paddle.framework._get_paddle_place(place)
idx = ctx.global_mesh.process_ids.index(dist.get_rank())
local_grad = grad_tensor[idx]
# idx = ctx.global_mesh.process_ids.index(dist.get_rank())
mesh = ctx.global_mesh
process_ids = np.array(mesh.process_ids).reshape(mesh.shape)
local_coord = np.where(process_ids == dist.get_rank())
local_tensor_idx = local_coord[ctx.local_mesh_dim][0]
local_grad = grad_tensor[local_tensor_idx]
global_tensor = paddle.Tensor(
local_grad._local_value(),
dims=ctx.global_shape,
process_mesh=ctx.global_mesh,
process_mesh=mesh,
placements=ctx.global_placements,
place=place,
)
@@ -558,7 +583,7 @@ def moe_sub_mesh_tensors(
"""
Get the local part of the ``dist_tensor`` on the specific ``local_mesh_dim``.
"""
local_mesh_list, local_placements = get_sub_meshes_from_global_mesh(
local_mesh_list, local_placements = _get_sub_meshes_and_local_placements(
global_mesh, global_placements, local_mesh_dim
)

Expand All @@ -567,6 +592,7 @@ def moe_sub_mesh_tensors(
dist_tensor,
local_mesh_list,
local_placements,
local_mesh_dim,
global_mesh,
global_placements,
)
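A behavioural note on this file: the local tensor index is now derived from the rank's coordinate along local_mesh_dim (via np.where on the reshaped process_ids) rather than from its flat position in mesh.process_ids, both in moe_global_mesh_tensor and in _moe_sub_mesh_tensors.backward. A numpy sketch of the lookup, assuming a hypothetical 2x4 mesh and current rank 6:

import numpy as np

mesh_shape = [2, 4]
process_ids = np.arange(8).reshape(mesh_shape)  # [[0,1,2,3],[4,5,6,7]]
cur_rank = 6                                    # stand-in for dist.get_rank()
local_mesh_dim = 0

# coordinate of the current rank in the mesh: (array([1]), array([2]))
local_coord = np.where(process_ids == cur_rank)
# the index along the split dim picks the matching sub-mesh / local tensor
local_tensor_idx = local_coord[local_mesh_dim][0]
print(local_tensor_idx)  # 1

The flat lookup kept as a comment in the diff (mesh.process_ids.index(dist.get_rank())) would return 6 here, which does not index a list that holds one tensor per sub-mesh.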
3 changes: 1 addition & 2 deletions python/paddle/distributed/auto_parallel/static/engine.py
@@ -769,7 +769,6 @@ def _parallel_pir(self, mode):

# re-run apply_mix2dist_pass to dist accumulator.
apply_mix2dist_pass(dist_program)
# print('program', startup_program, dist_program, flush=1)

# Part 2: Parallelism search (for full auto-parallel)
# NOTE make all parallelis search logic work as Pass,
@@ -791,7 +790,7 @@

# Part 3: Graph partition
# TODO(JZ-LIANG) Step 3.1: Partition Pass
# insert reshard op if operand tensor's placements if different from what the cumsumer op need.
# insert reshard op if operand tensor's placements are different from what the consumer op needs.
# Partition the computation graph into different pipeline stage if need.
apply_partition_pass(dist_program)

Expand Down