@@ -495,6 +495,7 @@ def _cal_global_shape(local_shape, mesh, placements):
495495def moe_global_mesh_tensor (
496496 local_tensor_list , mesh , placements , local_mesh_dim = - 1
497497):
498+ placements = copy .deepcopy (placements )
498499 local_mesh_list , local_placements = _get_sub_meshes_and_local_placements (
499500 mesh , placements , local_mesh_dim
500501 )
@@ -548,16 +549,17 @@ def moe_global_mesh_tensor(
548549 global_dims = _cal_global_shape (
549550 local_tensor ._local_shape , mesh , placements
550551 )
551- return paddle .jit .dy2static .py_layer .StaticPyLayer (
552- _moe_global_mesh_tensor
553- ).apply (
552+ dist_tensor = paddle ._C_ops .moe_global_mesh_tensor (
554553 local_tensor_list ,
555554 local_mesh_list ,
556555 local_placements ,
557556 mesh ,
558557 placements ,
559558 global_dims ,
560559 )
560+ dist_tensor .stop_gradient = local_tensor_list [0 ].stop_gradient
561+ dist_tensor .persistable = local_tensor_list [0 ].persistable
562+ return dist_tensor
561563 else :
562564 raise NotImplementedError (
563565 "dtensor_from_local_list() are only supported in dynamic and pir mode."
@@ -691,6 +693,7 @@ def moe_sub_mesh_tensors(
691693 """
692694 Get the local part of the ``dist_tensor`` on the specific ``local_mesh_dim``.
693695 """
696+ global_placements = copy .deepcopy (global_placements )
694697 local_mesh_list , local_placements = _get_sub_meshes_and_local_placements (
695698 global_mesh , global_placements , local_mesh_dim
696699 )
@@ -705,17 +708,17 @@ def moe_sub_mesh_tensors(
705708 global_placements ,
706709 )
707710 elif paddle .framework .in_pir_mode ():
708-
709- return paddle .jit .dy2static .py_layer .StaticPyLayer (
710- _moe_sub_mesh_tensors
711- ).apply (
711+ local_tensors = paddle ._C_ops .moe_sub_mesh_tensors (
712712 dist_tensor ,
713713 local_mesh_list ,
714714 local_placements ,
715- local_mesh_dim ,
716715 global_mesh ,
717716 global_placements ,
718717 )
718+ for local_tensor in local_tensors :
719+ local_tensor .stop_gradient = dist_tensor .stop_gradient
720+ local_tensor .persistable = dist_tensor .persistable
721+ return local_tensors
719722 else :
720723 raise NotImplementedError (
721724 "moe_sub_mesh_tensors is only supported in dynamic mode."
0 commit comments