@@ -329,6 +329,7 @@ def forward(
         ctx,
         local_tensor_list,
         local_mesh_list,
+        local_placements,
         idx,
         global_dims,
         mesh,
@@ -338,17 +339,15 @@ def forward(
         if local_tensor.is_dist():
             local_mesh = local_tensor.process_mesh
             local_val = local_tensor._local_value()
-            local_placement = local_tensor.placements[0]
         else:
             local_val = local_tensor
             local_mesh = None
-            local_placement = dist.Replicate()

         ctx.global_mesh = copy.deepcopy(mesh)
         ctx.placements = placements
         ctx.local_dims = local_tensor.shape
         ctx.local_mesh_list = copy.deepcopy(local_mesh_list)
-        ctx.local_placement = local_placement
+        ctx.local_placements = local_placements

         place = paddle.framework._current_expected_place()
         place = paddle.framework._get_paddle_place(place)
@@ -360,7 +359,7 @@ def forward(
             placements=placements,
             place=place,
         )
-        global_tensor.stop_gradient = False
+        global_tensor.stop_gradient = local_tensor.stop_gradient
         return global_tensor

     @staticmethod
@@ -377,91 +376,111 @@ def backward(ctx, grad_tensor):
                 grad_tensor._local_value(),
                 dims=ctx.local_dims,
                 process_mesh=local_mesh,
-                placements=[ctx.local_placement],
+                placements=ctx.local_placements,
                 place=place,
             )
         )
         out[-1].get_tensor()._unsafe_set_skip_check_mesh(True)
         return out


-def get_sub_meshes_from_global_mesh(
-    global_mesh, global_placements, local_mesh_dim
-):
-    if (
-        global_mesh is not None
-        and local_mesh_dim is not None
-        and global_placements is not None
+def split_mesh(global_mesh: dist.ProcessMesh, sub_mesh_dim: int):
+    mesh_shape = global_mesh.shape
+    mesh_ndim = len(mesh_shape)
+    if sub_mesh_dim >= mesh_ndim or (
+        sub_mesh_dim < 0 and -sub_mesh_dim > mesh_ndim
     ):
-        mesh_shape = global_mesh.shape
-        mesh_ndim = len(mesh_shape)
-        if local_mesh_dim >= mesh_ndim or (
-            local_mesh_dim < 0 and -local_mesh_dim > mesh_ndim
-        ):
-            raise ValueError(
-                f"The local_mesh_dim should between (-{mesh_ndim}, {mesh_ndim}]"
-            )
-        if local_mesh_dim < 0:
-            local_mesh_dim += mesh_ndim
-    else:
         raise ValueError(
-            "the args global_mesh, global_placements and local_mesh_dim should all be set. "
+            f"The sub_mesh_dim should be in range [-{mesh_ndim}, {mesh_ndim})"
         )
+    if sub_mesh_dim < 0:
+        sub_mesh_dim += mesh_ndim

     process_ids = np.array(global_mesh.process_ids).reshape(mesh_shape)
     splitted_process_ids = np.split(
-        process_ids, mesh_shape[local_mesh_dim], axis=local_mesh_dim
+        process_ids, mesh_shape[sub_mesh_dim], axis=sub_mesh_dim
     )
-    local_mesh_list = []
-    for process_ids in splitted_process_ids:
-        local_mesh_list.append(dist.ProcessMesh(process_ids))
+    sub_mesh_list = []
+    for sub_process_ids in splitted_process_ids:
+        sub_mesh_list.append(dist.ProcessMesh(sub_process_ids))
+
+    return sub_mesh_list
+
+
+def _get_sub_meshes_and_local_placements(
+    global_mesh, global_placements, sub_mesh_dim
+):
+    if global_mesh is None or sub_mesh_dim is None or global_placements is None:
+        raise ValueError(
+            "the args global_mesh, global_placements and local_mesh_dim should all be set."
+        )
+
+    sub_mesh_list = split_mesh(global_mesh, sub_mesh_dim)
+
     local_placements = list(global_placements)
-    local_placements.pop(local_mesh_dim)
-    if local_placements == []:
-        local_placements.append(dist.Replicate())
-    return local_mesh_list, local_placements
+    if sub_mesh_dim < len(local_placements):
+        local_placements[sub_mesh_dim] = dist.Replicate()
+
+    return sub_mesh_list, local_placements
+
+
+def cal_global_shape(local_shape, mesh, placements):
+    # assume each rank has the same tensor shape for now,
+    # just use the local shape to calculate the global shape
+    global_shape = list(local_shape)
+    for idx, placement in enumerate(placements):
+        if placement.is_shard():
+            shard_dim = placement.get_dim()
+            local_dim_size = global_shape[shard_dim]
+            global_shape[shard_dim] = local_dim_size * mesh.shape[idx]
+    return global_shape


 def moe_global_mesh_tensor(
     local_tensor_list, mesh, placements, local_mesh_dim=-1
 ):
-    # assume the each rank has the same tensor shape for now, just use the local shape to calculate the global shape
-    local_mesh_list, local_placements = get_sub_meshes_from_global_mesh(
+    local_mesh_list, local_placements = _get_sub_meshes_and_local_placements(
         mesh, placements, local_mesh_dim
     )
-
-    local_tensor_idx = mesh.process_ids.index(dist.get_rank())
+    process_ids = np.array(mesh.process_ids).reshape(mesh.shape)
+    local_coord = np.where(process_ids == dist.get_rank())
+    local_tensor_idx = local_coord[local_mesh_dim][0]
+    # local_tensor_idx = mesh.process_ids.index(dist.get_rank())
     local_tensor = local_tensor_list[local_tensor_idx]
-    global_dims = list(local_tensor.shape)
-    for idx, placement in enumerate(placements):
-        if placement.is_shard():
-            shard_dim = placement.get_dim()
-            local_dim_size = global_dims[shard_dim]
-            global_dims[shard_dim] = local_dim_size * mesh.shape[idx]

     if paddle.in_dynamic_mode():
+        global_dims = cal_global_shape(
+            local_tensor._local_value().shape, mesh, placements
+        )
         resharded_local_tensor_list = []
         for i, tensor in enumerate(local_tensor_list):
             tensor.get_tensor()._unsafe_set_skip_check_mesh(True)
             if (
-                tensor.placements != local_placements
+                not check_placements_equal(tensor.placements, local_placements)
                 or tensor.process_mesh != local_mesh_list[i]
             ):
                 resharded_local_tensor_list.append(
                     reshard(tensor, local_mesh_list[i], local_placements)
                 )
+                resharded_local_tensor_list[
+                    -1
+                ].get_tensor()._unsafe_set_skip_check_mesh(True)
             else:
                 resharded_local_tensor_list.append(tensor)

         return _moe_global_mesh_tensor.apply(
             resharded_local_tensor_list,
             local_mesh_list,
+            local_placements,
             local_tensor_idx,
             global_dims,
             mesh,
             placements,
         )
     elif paddle.framework.in_pir_mode():
+        global_dims = cal_global_shape(
+            local_tensor._local_shape, mesh, placements
+        )
         dist_tensor = paddle._C_ops.moe_global_mesh_tensor(
             local_tensor_list,
             local_mesh_list,
@@ -487,11 +506,13 @@ def forward(
         dist_tensor,
         local_mesh_list=None,
         local_placements=None,
+        local_mesh_dim=None,
         global_mesh=None,
         global_placements=None,
     ):
         ctx.local_mesh_list = copy.deepcopy(local_mesh_list)
         ctx.local_placements = local_placements
+        ctx.local_mesh_dim = local_mesh_dim
         ctx.global_mesh = copy.deepcopy(global_mesh)
         ctx.global_placements = global_placements
         ctx.global_shape = dist_tensor.shape
@@ -532,20 +553,24 @@ def forward(
                 place=place,
             )
             local_tensor.get_tensor()._unsafe_set_skip_check_mesh(True)
-            local_tensor.stop_gradient = False
+            local_tensor.stop_gradient = dist_tensor.stop_gradient
             local_tensor_list.append(local_tensor)
         return local_tensor_list

     @staticmethod
     def backward(ctx, *grad_tensor):
         place = paddle.framework._current_expected_place()
         place = paddle.framework._get_paddle_place(place)
-        idx = ctx.global_mesh.process_ids.index(dist.get_rank())
-        local_grad = grad_tensor[idx]
+        # idx = ctx.global_mesh.process_ids.index(dist.get_rank())
+        mesh = ctx.global_mesh
+        process_ids = np.array(mesh.process_ids).reshape(mesh.shape)
+        local_coord = np.where(process_ids == dist.get_rank())
+        local_tensor_idx = local_coord[ctx.local_mesh_dim][0]
+        local_grad = grad_tensor[local_tensor_idx]
         global_tensor = paddle.Tensor(
             local_grad._local_value(),
             dims=ctx.global_shape,
-            process_mesh=ctx.global_mesh,
+            process_mesh=mesh,
             placements=ctx.global_placements,
             place=place,
         )
@@ -558,7 +583,7 @@ def moe_sub_mesh_tensors(
     """
     Get the local part of the ``dist_tensor`` on the specific ``local_mesh_dim``.
     """
-    local_mesh_list, local_placements = get_sub_meshes_from_global_mesh(
+    local_mesh_list, local_placements = _get_sub_meshes_and_local_placements(
         global_mesh, global_placements, local_mesh_dim
     )

@@ -567,6 +592,7 @@ def moe_sub_mesh_tensors(
         dist_tensor,
         local_mesh_list,
         local_placements,
+        local_mesh_dim,
         global_mesh,
         global_placements,
     )
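
Note: the mesh splitting introduced above boils down to reshaping the flat process id list and splitting it along one axis, and `moe_global_mesh_tensor` then locates the current rank's coordinate on that axis to pick its local tensor. A minimal numpy sketch of that index logic (the 2x4 mesh shape and the rank are made-up values for illustration; `dist.ProcessMesh` construction is omitted):

```python
import numpy as np

# Hypothetical 2x4 global mesh: 8 ranks, e.g. axes ("dp", "ep").
mesh_shape = [2, 4]
process_ids = np.arange(8).reshape(mesh_shape)

# Mirror split_mesh(global_mesh, sub_mesh_dim=1): split along axis 1,
# yielding mesh_shape[1] == 4 sub-meshes, one per index on that axis.
sub_meshes = np.split(process_ids, mesh_shape[1], axis=1)
print([m.tolist() for m in sub_meshes])
# [[[0], [4]], [[1], [5]], [[2], [6]], [[3], [7]]]

# The same coordinate lookup used in moe_global_mesh_tensor to pick the
# local tensor index for the current rank (rank 6 taken as an example):
rank = 6
local_coord = np.where(process_ids == rank)
print(local_coord[1][0])  # index along mesh axis 1 -> 2
```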
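Similarly, `cal_global_shape` applies one multiplication per `Shard` placement: each sharded tensor dim is scaled by the size of the mesh axis it is sharded over. A self-contained sketch of that arithmetic with stand-in mesh/placement stubs (assumed to expose only `shape`, `is_shard()`, and `get_dim()`, matching the calls in the diff; these stubs are not the real paddle.distributed classes):

```python
# Stand-in placement objects; only the methods consulted by the
# shape calculation are implemented.
class _Shard:
    def __init__(self, dim):
        self._dim = dim
    def is_shard(self):
        return True
    def get_dim(self):
        return self._dim

class _Replicate:
    def is_shard(self):
        return False

class _Mesh:
    def __init__(self, shape):
        self.shape = shape

def cal_global_shape(local_shape, mesh, placements):
    # Same logic as in the diff: scale each sharded dim by the
    # size of the mesh axis it is sharded over.
    global_shape = list(local_shape)
    for idx, placement in enumerate(placements):
        if placement.is_shard():
            global_shape[placement.get_dim()] *= mesh.shape[idx]
    return global_shape

# Local shard [8, 16] on a 2x4 mesh, replicated along mesh axis 0 and
# sharded on tensor dim 0 along mesh axis 1 (size 4): 8 * 4 = 32.
print(cal_global_shape([8, 16], _Mesh([2, 4]), [_Replicate(), _Shard(0)]))
# [32, 16]
```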