Skip to content

Commit af4f1e7

Browse files
committed
revert shard_dataloader api args typing
1 parent 5aa6565 commit af4f1e7

File tree

1 file changed

+7
-7
lines changed
  • python/paddle/distributed/auto_parallel

1 file changed

+7
-7
lines changed

python/paddle/distributed/auto_parallel/api.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ def split_mesh(global_mesh: dist.ProcessMesh, sub_mesh_dim: int):
407407
return sub_mesh_list
408408

409409

410-
def get_sub_meshes_and_local_placements(
410+
def _get_sub_meshes_and_local_placements(
411411
global_mesh, global_placements, sub_mesh_dim
412412
):
413413
if global_mesh is None or sub_mesh_dim is None or global_placements is None:
@@ -439,7 +439,7 @@ def cal_global_shape(local_shape, mesh, placements):
439439
def moe_global_mesh_tensor(
440440
local_tensor_list, mesh, placements, local_mesh_dim=-1
441441
):
442-
local_mesh_list, local_placements = get_sub_meshes_and_local_placements(
442+
local_mesh_list, local_placements = _get_sub_meshes_and_local_placements(
443443
mesh, placements, local_mesh_dim
444444
)
445445
process_ids = np.array(mesh.process_ids).reshape(mesh.shape)
@@ -583,7 +583,7 @@ def moe_sub_mesh_tensors(
583583
"""
584584
Get the local part of the ``dist_tensor`` on the specific ``local_mesh_dim``.
585585
"""
586-
local_mesh_list, local_placements = get_sub_meshes_and_local_placements(
586+
local_mesh_list, local_placements = _get_sub_meshes_and_local_placements(
587587
global_mesh, global_placements, local_mesh_dim
588588
)
589589

@@ -3014,10 +3014,10 @@ def __call__(self):
30143014

30153015

30163016
def shard_dataloader(
3017-
dataloader: paddle.io.DataLoader,
3018-
meshes: ProcessMesh | list[ProcessMesh] | tuple[ProcessMesh],
3019-
input_keys: list[str] | tuple[str] | None = None,
3020-
shard_dims: list | tuple | str | int | None = None,
3017+
dataloader: DataLoader,
3018+
meshes: ProcessMesh | Sequence[ProcessMesh],
3019+
input_keys: Sequence[str] | None = None,
3020+
shard_dims: Sequence[str] | Sequence[int] | str | int | None = None,
30213021
is_dataset_splitted: bool = False,
30223022
) -> ShardDataloader:
30233023
"""

0 commit comments

Comments
 (0)