@@ -407,7 +407,7 @@ def split_mesh(global_mesh: dist.ProcessMesh, sub_mesh_dim: int):
     return sub_mesh_list
 
 
-def get_sub_meshes_and_local_placements(
+def _get_sub_meshes_and_local_placements(
     global_mesh, global_placements, sub_mesh_dim
 ):
     if global_mesh is None or sub_mesh_dim is None or global_placements is None:
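The underscore prefix marks the helper as private to this module, so only the internal call sites (updated in the hunks below) need to change. For intuition, splitting a process mesh along one dimension amounts to slicing the process grid; the toy split_process_grid below is a hypothetical, self-contained sketch of that idea, not Paddle's implementation:

import numpy as np

def split_process_grid(process_ids, shape, sub_mesh_dim):
    # Toy sketch: reshape the flat process ids into a grid, then peel off
    # one sub-grid per index along the split dimension.
    grid = np.array(process_ids).reshape(shape)
    moved = np.moveaxis(grid, sub_mesh_dim, 0)
    return [sub.tolist() for sub in moved]

# A 2x2 mesh split along dim 0 yields two 1-D sub-meshes: [0, 1] and [2, 3].
print(split_process_grid([0, 1, 2, 3], (2, 2), 0))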
@@ -439,7 +439,7 @@ def cal_global_shape(local_shape, mesh, placements):
 def moe_global_mesh_tensor(
     local_tensor_list, mesh, placements, local_mesh_dim=-1
 ):
-    local_mesh_list, local_placements = get_sub_meshes_and_local_placements(
+    local_mesh_list, local_placements = _get_sub_meshes_and_local_placements(
         mesh, placements, local_mesh_dim
     )
     process_ids = np.array(mesh.process_ids).reshape(mesh.shape)
@@ -583,7 +583,7 @@ def moe_sub_mesh_tensors(
     """
     Get the local part of the ``dist_tensor`` on the specific ``local_mesh_dim``.
     """
-    local_mesh_list, local_placements = get_sub_meshes_and_local_placements(
+    local_mesh_list, local_placements = _get_sub_meshes_and_local_placements(
         global_mesh, global_placements, local_mesh_dim
     )
 
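Both MoE entry points now route through the private helper. A hedged round-trip sketch of how they relate follows; the import path and the keyword names for moe_sub_mesh_tensors are assumptions read off this diff, and the snippet only runs under a multi-process launch (e.g. python -m paddle.distributed.launch with two devices):

# Assumption: these helpers live in paddle.distributed.auto_parallel.api,
# the module this diff appears to patch, and a 2-process job is running.
import paddle
import paddle.distributed as dist
from paddle.distributed.auto_parallel.api import (
    moe_global_mesh_tensor,
    moe_sub_mesh_tensors,
)

mesh = dist.ProcessMesh([[0], [1]], dim_names=["expert", "mp"])
placements = [dist.Shard(0), dist.Replicate()]

# Each rank contributes its expert-local shard(s); the positional signature
# matches the moe_global_mesh_tensor hunk above.
local = paddle.rand([4, 8])
global_t = moe_global_mesh_tensor([local], mesh, placements, local_mesh_dim=0)

# Inverse view: recover the per-sub-mesh local parts. Keyword names are
# assumed from the helper call visible in the hunk.
locals_again = moe_sub_mesh_tensors(
    global_t, global_mesh=mesh, local_mesh_dim=0, global_placements=placements
)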
@@ -3014,10 +3014,10 @@ def __call__(self):
 
 
 def shard_dataloader(
-    dataloader: paddle.io.DataLoader,
-    meshes: ProcessMesh | list[ProcessMesh] | tuple[ProcessMesh],
-    input_keys: list[str] | tuple[str] | None = None,
-    shard_dims: list | tuple | str | int | None = None,
+    dataloader: DataLoader,
+    meshes: ProcessMesh | Sequence[ProcessMesh],
+    input_keys: Sequence[str] | None = None,
+    shard_dims: Sequence[str] | Sequence[int] | str | int | None = None,
     is_dataset_splitted: bool = False,
 ) -> ShardDataloader:
     """
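The tightened annotations (Sequence in place of the concrete list/tuple unions) match how the API is actually called. A minimal usage sketch under the new signature; the toy dataset is made up for illustration, and the loop relies on ShardDataloader being callable, per the __call__ context above (run under a two-process distributed launch):

import paddle
import paddle.distributed as dist
from paddle.io import DataLoader, Dataset

class RandomDataset(Dataset):
    # Hypothetical dataset that yields dicts keyed by input_keys.
    def __len__(self):
        return 16

    def __getitem__(self, idx):
        return {"image": paddle.rand([8]), "label": paddle.randint(0, 2, [1])}

mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])
loader = DataLoader(RandomDataset(), batch_size=4)

# meshes accepts one ProcessMesh or any Sequence of them; shard_dims may be
# a mesh dim name (str), an index (int), a Sequence of either, or None.
dist_loader = dist.shard_dataloader(
    loader, meshes=mesh, input_keys=["image", "label"], shard_dims="dp"
)
for batch in dist_loader():
    image, label = batch["image"], batch["label"]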