Commit d84c06f

pkuzyc and zhangbo9674 authored
[Dist Dialect] Simple MoE training in PIR (#66750)
* refine and rename the MoE apis
* add error test in ut
* add pass and unit test for pir moe model; add sub_to_global_reshard function; the loss of the PIR MoE demo matches dygraph auto parallel; refine and rename the MoE apis
* adapt arg type in shard_dataloader
* adapt grad_clip and CBroadcastOp in pir
* refine moe api to support multi-dimension process mesh
* refine moe api
* add spmd rule for one_hot and cumsum in yaml
* revert shard_dataloader
* add unit test for reshape_grad with resharding x_grad
* fix bug in unit test
* revert shard_dataloader api args typing
* add
* add

---------

Co-authored-by: zhangbo9674 <[email protected]>
1 parent 2c3d98b commit d84c06f

15 files changed: +359 -64 lines changed


paddle/fluid/pir/dialect/operator/utils/utils.cc

Lines changed: 1 addition & 0 deletions
@@ -39,6 +39,7 @@ const std::unordered_set<std::string> LegacyOpList = {
     LoadCombineOp::name(),
     CConcatOp::name(),
     CBroadcast_Op::name(),
+    CBroadcastOp::name(),
     CSyncCommStream_Op::name(),
     DistributedPushSparseOp::name(),
     SendV2Op::name(),

paddle/phi/ops/yaml/ops.yaml

Lines changed: 2 additions & 0 deletions
@@ -1201,6 +1201,7 @@
   output : Tensor(out)
   infer_meta :
     func : CumScalarAxisInferMeta
+    spmd_rule : CumSumInferSpmdDynamic
   kernel :
     func : cumsum
     data_type : x
@@ -3511,6 +3512,7 @@
   output : Tensor(out)
   infer_meta :
     func : OneHotInferMeta
+    spmd_rule : OneHotInferSpmdDynamic
   kernel :
     func : one_hot
   traits : paddle::dialect::ForwardOnlyTrait
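Note: the two spmd_rule entries above register sharding-propagation rules so that cumsum and one_hot can take part in SPMD inference under auto parallel. The snippet below is only a rough, standalone illustration of what a scan-style rule has to decide (it is not Paddle's CumSumInferSpmdDynamic); the assumed behavior is that sharding on non-scan axes propagates unchanged, while the scanned axis falls back to replication:

# Illustrative toy rule, not Paddle code.
# dims_mapping[i] is the mesh axis that shards tensor dim i, or -1 if replicated.
def propagate_scan_sharding(in_dims_mapping, scan_axis):
    out_dims_mapping = list(in_dims_mapping)
    # a prefix sum along a sharded axis would need cross-shard partial results,
    # so the scanned axis is assumed to fall back to replicated here
    if out_dims_mapping[scan_axis] != -1:
        out_dims_mapping[scan_axis] = -1
    return out_dims_mapping

print(propagate_scan_sharding([0, -1], scan_axis=1))  # [0, -1]: batch sharding kept
print(propagate_scan_sharding([-1, 0], scan_axis=1))  # [-1, -1]: scanned dim replicated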

python/paddle/distributed/auto_parallel/api.py

Lines changed: 74 additions & 48 deletions
@@ -329,6 +329,7 @@ def forward(
         ctx,
         local_tensor_list,
         local_mesh_list,
+        local_placements,
         idx,
         global_dims,
         mesh,
@@ -338,17 +339,15 @@ def forward(
         if local_tensor.is_dist():
             local_mesh = local_tensor.process_mesh
             local_val = local_tensor._local_value()
-            local_placement = local_tensor.placements[0]
         else:
             local_val = local_tensor
             local_mesh = None
-            local_placement = dist.Replicate()

         ctx.global_mesh = copy.deepcopy(mesh)
         ctx.placements = placements
         ctx.local_dims = local_tensor.shape
         ctx.local_mesh_list = copy.deepcopy(local_mesh_list)
-        ctx.local_placement = local_placement
+        ctx.local_placements = local_placements

         place = paddle.framework._current_expected_place()
         place = paddle.framework._get_paddle_place(place)
@@ -360,7 +359,7 @@ def forward(
             placements=placements,
             place=place,
         )
-        global_tensor.stop_gradient = False
+        global_tensor.stop_gradient = local_tensor.stop_gradient
         return global_tensor

     @staticmethod
@@ -377,91 +376,111 @@ def backward(ctx, grad_tensor):
                     grad_tensor._local_value(),
                     dims=ctx.local_dims,
                     process_mesh=local_mesh,
-                    placements=[ctx.local_placement],
+                    placements=ctx.local_placements,
                     place=place,
                 )
             )
             out[-1].get_tensor()._unsafe_set_skip_check_mesh(True)
         return out


-def get_sub_meshes_from_global_mesh(
-    global_mesh, global_placements, local_mesh_dim
-):
-    if (
-        global_mesh is not None
-        and local_mesh_dim is not None
-        and global_placements is not None
+def split_mesh(global_mesh: dist.ProcessMesh, sub_mesh_dim: int):
+    mesh_shape = global_mesh.shape
+    mesh_ndim = len(mesh_shape)
+    if sub_mesh_dim >= mesh_ndim or (
+        sub_mesh_dim < 0 and -sub_mesh_dim > mesh_ndim
     ):
-        mesh_shape = global_mesh.shape
-        mesh_ndim = len(mesh_shape)
-        if local_mesh_dim >= mesh_ndim or (
-            local_mesh_dim < 0 and -local_mesh_dim > mesh_ndim
-        ):
-            raise ValueError(
-                f"The local_mesh_dim should between (-{mesh_ndim}, {mesh_ndim}]"
-            )
-        if local_mesh_dim < 0:
-            local_mesh_dim += mesh_ndim
-    else:
         raise ValueError(
-            "the args global_mesh, global_placements and local_mesh_dim should all be set."
+            f"The sub_mesh_dim should between (-{mesh_ndim}, {mesh_ndim}]"
         )
+    if sub_mesh_dim < 0:
+        sub_mesh_dim += mesh_ndim

     process_ids = np.array(global_mesh.process_ids).reshape(mesh_shape)
     splitted_process_ids = np.split(
-        process_ids, mesh_shape[local_mesh_dim], axis=local_mesh_dim
+        process_ids, mesh_shape[sub_mesh_dim], axis=sub_mesh_dim
     )
-    local_mesh_list = []
-    for process_ids in splitted_process_ids:
-        local_mesh_list.append(dist.ProcessMesh(process_ids))
+    sub_mesh_list = []
+    for sub_process_ids in splitted_process_ids:
+        sub_mesh_list.append(dist.ProcessMesh(sub_process_ids))
+
+    return sub_mesh_list
+
+
+def _get_sub_meshes_and_local_placements(
+    global_mesh, global_placements, sub_mesh_dim
+):
+    if global_mesh is None or sub_mesh_dim is None or global_placements is None:
+        raise ValueError(
+            "the args global_mesh, global_placements and local_mesh_dim should all be set."
+        )
+
+    sub_mesh_list = split_mesh(global_mesh, sub_mesh_dim)
+
     local_placements = list(global_placements)
-    local_placements.pop(local_mesh_dim)
-    if local_placements == []:
-        local_placements.append(dist.Replicate())
-    return local_mesh_list, local_placements
+    if sub_mesh_dim < len(local_placements):
+        local_placements[sub_mesh_dim] = dist.Replicate()
+
+    return sub_mesh_list, local_placements
+
+
+def cal_global_shape(local_shape, mesh, placements):
+    # assume the each rank has the same tensor shape for now,
+    # just use the local shape to calculate the global shape
+    global_shape = list(local_shape)
+    for idx, placement in enumerate(placements):
+        if placement.is_shard():
+            shard_dim = placement.get_dim()
+            local_dim_size = global_shape[shard_dim]
+            global_shape[shard_dim] = local_dim_size * mesh.shape[idx]
+    return global_shape


 def moe_global_mesh_tensor(
     local_tensor_list, mesh, placements, local_mesh_dim=-1
 ):
-    # assume the each rank has the same tensor shape for now, just use the local shape to calculate the global shape
-    local_mesh_list, local_placements = get_sub_meshes_from_global_mesh(
+    local_mesh_list, local_placements = _get_sub_meshes_and_local_placements(
         mesh, placements, local_mesh_dim
     )
-
-    local_tensor_idx = mesh.process_ids.index(dist.get_rank())
+    process_ids = np.array(mesh.process_ids).reshape(mesh.shape)
+    local_coord = np.where(process_ids == dist.get_rank())
+    local_tensor_idx = local_coord[local_mesh_dim][0]
+    # local_tensor_idx = mesh.process_ids.index(dist.get_rank())
     local_tensor = local_tensor_list[local_tensor_idx]
-    global_dims = list(local_tensor.shape)
-    for idx, placement in enumerate(placements):
-        if placement.is_shard():
-            shard_dim = placement.get_dim()
-            local_dim_size = global_dims[shard_dim]
-            global_dims[shard_dim] = local_dim_size * mesh.shape[idx]

     if paddle.in_dynamic_mode():
+        global_dims = cal_global_shape(
+            local_tensor._local_value().shape, mesh, placements
+        )
         resharded_local_tensor_list = []
         for i, tensor in enumerate(local_tensor_list):
             tensor.get_tensor()._unsafe_set_skip_check_mesh(True)
             if (
-                tensor.placements != local_placements
+                not check_placements_equal(tensor.placements, local_placements)
                 or tensor.process_mesh != local_mesh_list[i]
             ):
                 resharded_local_tensor_list.append(
                     reshard(tensor, local_mesh_list[i], local_placements)
                 )
+                resharded_local_tensor_list[
+                    -1
+                ].get_tensor()._unsafe_set_skip_check_mesh(True)
             else:
                 resharded_local_tensor_list.append(tensor)

         return _moe_global_mesh_tensor.apply(
             resharded_local_tensor_list,
             local_mesh_list,
+            local_placements,
             local_tensor_idx,
             global_dims,
             mesh,
             placements,
         )
     elif paddle.framework.in_pir_mode():
+        global_dims = cal_global_shape(
+            local_tensor._local_shape, mesh, placements
+        )
         dist_tensor = paddle._C_ops.moe_global_mesh_tensor(
             local_tensor_list,
             local_mesh_list,
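Note: the new split_mesh helper and the np.where lookup above decide which entry of local_tensor_list belongs to the current rank. A standalone NumPy sketch of the same arithmetic, using a made-up 2 x 4 mesh and rank (no Paddle involved):

import numpy as np

mesh_shape = [2, 4]                             # axis 0: expert groups, axis 1: ranks per group
process_ids = np.arange(8).reshape(mesh_shape)  # ranks 0..7

# split_mesh-style split along sub_mesh_dim=0: one sub-mesh per expert group
sub_mesh_dim = 0
sub_meshes = np.split(process_ids, mesh_shape[sub_mesh_dim], axis=sub_mesh_dim)
# sub_meshes[0] -> [[0, 1, 2, 3]], sub_meshes[1] -> [[4, 5, 6, 7]]

# coordinate lookup used for local_tensor_idx: the rank's index along
# sub_mesh_dim selects the matching entry of local_tensor_list
cur_rank = 5                                    # made-up rank
local_coord = np.where(process_ids == cur_rank)
local_tensor_idx = local_coord[sub_mesh_dim][0]
print(local_tensor_idx)                         # 1, i.e. rank 5 uses local_tensor_list[1]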
@@ -487,11 +506,13 @@ def forward(
         dist_tensor,
         local_mesh_list=None,
         local_placements=None,
+        local_mesh_dim=None,
         global_mesh=None,
         global_placements=None,
     ):
         ctx.local_mesh_list = copy.deepcopy(local_mesh_list)
         ctx.local_placements = local_placements
+        ctx.local_mesh_dim = local_mesh_dim
         ctx.global_mesh = copy.deepcopy(global_mesh)
         ctx.global_placements = global_placements
         ctx.global_shape = dist_tensor.shape
@@ -532,20 +553,24 @@ def forward(
                 place=place,
             )
             local_tensor.get_tensor()._unsafe_set_skip_check_mesh(True)
-            local_tensor.stop_gradient = False
+            local_tensor.stop_gradient = dist_tensor.stop_gradient
             local_tensor_list.append(local_tensor)
         return local_tensor_list

     @staticmethod
     def backward(ctx, *grad_tensor):
         place = paddle.framework._current_expected_place()
         place = paddle.framework._get_paddle_place(place)
-        idx = ctx.global_mesh.process_ids.index(dist.get_rank())
-        local_grad = grad_tensor[idx]
+        # idx = ctx.global_mesh.process_ids.index(dist.get_rank())
+        mesh = ctx.global_mesh
+        process_ids = np.array(mesh.process_ids).reshape(mesh.shape)
+        local_coord = np.where(process_ids == dist.get_rank())
+        local_tensor_idx = local_coord[ctx.local_mesh_dim][0]
+        local_grad = grad_tensor[local_tensor_idx]
         global_tensor = paddle.Tensor(
             local_grad._local_value(),
             dims=ctx.global_shape,
-            process_mesh=ctx.global_mesh,
+            process_mesh=mesh,
             placements=ctx.global_placements,
             place=place,
         )
@@ -558,7 +583,7 @@ def moe_sub_mesh_tensors(
     """
    Get the local part of the ``dist_tensor`` on the specific ``local_mesh_dim``.
     """
-    local_mesh_list, local_placements = get_sub_meshes_from_global_mesh(
+    local_mesh_list, local_placements = _get_sub_meshes_and_local_placements(
         global_mesh, global_placements, local_mesh_dim
     )

@@ -567,6 +592,7 @@ def moe_sub_mesh_tensors(
             dist_tensor,
             local_mesh_list,
             local_placements,
+            local_mesh_dim,
             global_mesh,
             global_placements,
         )
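Note: cal_global_shape, added above, assumes every rank holds an identically shaped shard and simply multiplies each sharded tensor dimension by the size of the mesh axis that shards it. A standalone sketch of the same arithmetic, with placements written as plain tuples instead of Paddle placement objects:

# Placements are modeled as ("shard", tensor_dim) or ("replicate",), one per mesh axis.
def global_shape_from_local(local_shape, mesh_shape, placements):
    global_shape = list(local_shape)
    for mesh_axis, placement in enumerate(placements):
        if placement[0] == "shard":
            global_shape[placement[1]] *= mesh_shape[mesh_axis]
    return global_shape

# Local shard [8, 16] on a 2 x 4 mesh, replicated over mesh axis 0 and sharded
# along tensor dim 0 by mesh axis 1 -> global shape [32, 16].
print(global_shape_from_local([8, 16], [2, 4], [("replicate",), ("shard", 0)]))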

python/paddle/distributed/auto_parallel/static/engine.py

Lines changed: 1 addition & 2 deletions
@@ -769,7 +769,6 @@ def _parallel_pir(self, mode):

         # re-run apply_mix2dist_pass to dist accumulator.
         apply_mix2dist_pass(dist_program)
-        # print('program', startup_program, dist_program, flush=1)

         # Part 2: Parallelism search (for full auto-parallel)
         # NOTE make all parallelis search logic work as Pass,
@@ -791,7 +790,7 @@ def _parallel_pir(self, mode):

         # Part 3: Graph partition
         # TODO(JZ-LIANG) Step 3.1: Partition Pass
-        # insert reshard op if operand tensor's placements if different from what the cumsumer op need.
+        # insert reshard op if operand tensor's placements is different from what the cumsumer op need.
         # Partition the computation graph into different pipeline stage if need.
         apply_partition_pass(dist_program)

python/paddle/distributed/auto_parallel/static/pir_pass.py

Lines changed: 88 additions & 0 deletions
@@ -284,6 +284,88 @@ def _remove_other_rank_params_grads(dist_params_grads):
             dist_params_grads.pop(idx)


+# Replace the specific MoE-related dist op with the
+# executable op in the dense program. In expert parallelism
+# of the MoE model, the process mesh of each expert is
+# different. Two specific apis are used to transform the
+# input tensor's global process mesh to the experts' local
+# process meshes, which will add two dist ops in the program.
+# The following two functions are used to replace the two
+# dist ops with the executable share_data_ ops.
+def replace_moe_sub_mesh_tensors(op):
+    cur_rank = paddle.distributed.get_rank()
+    in_value = op.operand_source(0)
+    out_value = None
+    out_idx = -1
+    for idx, val in enumerate(op.results()):
+        val_mesh = val.dist_attr().process_mesh
+        if cur_rank in val_mesh.process_ids:
+            assert (
+                out_value is None
+            ), f'{op} has more than one results on rank {cur_rank}'
+            out_value = val
+            out_idx = idx
+
+    paddle.pir.set_insertion_point(op)
+    local_value = paddle._C_ops.share_data_(in_value)
+    local_value_type = paddle.base.libpaddle.pir.cvt_to_dist_type(
+        out_value.type(), out_value.dist_attr()
+    )
+    local_value.set_type(local_value_type)
+    out_value.replace_all_uses_with(local_value)
+
+    op_dist_attr = op.dist_attr
+    share_data_op = local_value.get_defining_op()
+    share_data_op.dist_attr = (
+        paddle.base.libpaddle.pir.create_op_dist_attribute(
+            op_dist_attr.process_mesh,
+            [op_dist_attr.operand(0).as_tensor_dist_attr()],
+            [op_dist_attr.result(out_idx).as_tensor_dist_attr()],
+        )
+    )
+
+    assert all(val.use_empty() for val in op.results())
+    op.erase()
+
+
+def replace_moe_global_mesh_tensor(op):
+    cur_rank = paddle.distributed.get_rank()
+    out_value = op.result(0)
+    in_value = None
+    in_idx = -1
+    for idx, val in enumerate(op.operands_source()):
+        val_mesh = val.dist_attr().process_mesh
+        if cur_rank not in val_mesh.process_ids:
+            continue
+        assert (
+            in_value is None
+        ), f'{op} has more than one inputs on rank {cur_rank}'
+        in_value = val
+        in_idx = idx
+
+    paddle.pir.set_insertion_point(op)
+    local_value = paddle._C_ops.share_data_(in_value)
+    # local_value = paddle.assign(in_value)
+    local_value_type = paddle.base.libpaddle.pir.cvt_to_dist_type(
+        out_value.type(), out_value.dist_attr()
+    )
+    local_value.set_type(local_value_type)
+    out_value.replace_all_uses_with(local_value)
+
+    op_dist_attr = op.dist_attr
+    share_data_op = local_value.get_defining_op()
+    share_data_op.dist_attr = (
+        paddle.base.libpaddle.pir.create_op_dist_attribute(
+            op_dist_attr.process_mesh,
+            [op_dist_attr.operand(in_idx).as_tensor_dist_attr()],
+            [op_dist_attr.result(0).as_tensor_dist_attr()],
+        )
+    )
+
+    assert all(val.use_empty() for val in op.results())
+    op.erase()
+
+
 # pruning op and value not belong to cur rank
 def remove_other_rank_op_pass(dist_program, dist_params_grads):
     cur_rank = paddle.distributed.get_rank()
@@ -298,6 +380,12 @@ def remove_other_rank_op_pass(dist_program, dist_params_grads):
             if can_delete:
                 op.erase()
             continue
+        if op.name() == "dist_op.moe_sub_mesh_tensors":
+            replace_moe_sub_mesh_tensors(op)
+            continue
+        if op.name() == "dist_op.moe_global_mesh_tensor":
+            replace_moe_global_mesh_tensor(op)
+            continue
         if cur_rank not in op.dist_attr.process_mesh.process_ids:
             op.erase()
         elif op.name() == "dist_op.reshard":