3 changes: 3 additions & 0 deletions python/paddle/distributed/__init__.py
@@ -45,6 +45,7 @@
to_static,
unshard_dtensor,
)
from .auto_parallel.interface import get_mesh, set_mesh
from .auto_parallel.intermediate.parallelize import parallelize
from .auto_parallel.intermediate.pipeline_parallel import SplitPoint
from .auto_parallel.intermediate.tensor_parallel import (
@@ -199,4 +200,6 @@
"PrepareLayerOutput",
"PrepareLayerInput",
"SplitPoint",
"set_mesh",
"get_mesh",
]
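
With these exports, ``set_mesh`` and ``get_mesh`` become reachable directly from the top-level ``paddle.distributed`` namespace. A minimal round-trip sketch of that usage (the 2x2 mesh shape here is illustrative, not taken from this PR):

>>> import paddle.distributed as dist
>>> # Build a small 2-D mesh over 4 ranks: "dp" x "mp".
>>> mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])
>>> dist.set_mesh(mesh)             # store it as the process-global mesh
>>> assert dist.get_mesh() is mesh  # the same object is returned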
42 changes: 40 additions & 2 deletions python/paddle/distributed/auto_parallel/interface.py
@@ -322,12 +322,50 @@ def fetch(tensor, name=None, logging=False):
_g_mesh = None


def get_mesh():
def get_mesh() -> paddle.distributed.ProcessMesh:
"""
Get the global process mesh previously set by ``set_mesh``.

Returns:
mesh (paddle.distributed.ProcessMesh): the global mesh, or ``None`` if ``set_mesh`` has not been called.

Examples:
.. code-block:: python

>>> import paddle
>>> import paddle.distributed as dist
>>> mesh = dist.ProcessMesh([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], dim_names=["dp", "mp", "pp"])
>>> # doctest: +REQUIRES(env:DISTRIBUTED)
>>> dist.auto_parallel.set_mesh(mesh)
>>> mesh = dist.auto_parallel.get_mesh()
>>> # This example needs to be run in a multi-card environment, e.g.:
>>> # python -m paddle.distributed.launch --gpus=0,1,2,3,4,5,6,7 {test_case}.py
"""
global _g_mesh
return _g_mesh


def set_mesh(mesh):
def set_mesh(mesh: paddle.distributed.ProcessMesh) -> None:
"""
Set the global mesh.

Args:
mesh (paddle.distributed.ProcessMesh): global mesh to be set.

Returns:
None

Examples:
.. code-block:: python

>>> import paddle
>>> import paddle.distributed as dist
>>> mesh = dist.ProcessMesh([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], dim_names=["dp", "mp", "pp"])
>>> # doctest: +REQUIRES(env:DISTRIBUTED)
>>> dist.auto_parallel.set_mesh(mesh)
>>> # This example needs to be run in a multi-card environment, e.g.:
>>> # python -m paddle.distributed.launch --gpus=0,1,2,3,4,5,6,7 {test_case}.py
"""
global _g_mesh
_g_mesh = mesh

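
Since ``_g_mesh`` is a module-level global, ``get_mesh`` simply returns whatever ``set_mesh`` stored last, or ``None`` if it was never called. A short sketch of that behavior with the same 3-D mesh as the docstrings (real execution belongs under ``paddle.distributed.launch`` with 8 cards; the assertions below only read mesh metadata, assuming ``ProcessMesh.shape`` and ``ProcessMesh.dim_names`` report the layout as in current Paddle):

>>> import paddle.distributed as dist
>>> assert dist.auto_parallel.get_mesh() is None  # None in a fresh process
>>> mesh = dist.ProcessMesh([[[0, 1], [2, 3]], [[4, 5], [6, 7]]],
...                         dim_names=["dp", "mp", "pp"])
>>> dist.auto_parallel.set_mesh(mesh)
>>> got = dist.auto_parallel.get_mesh()
>>> assert got.shape == [2, 2, 2]
>>> assert got.dim_names == ["dp", "mp", "pp"]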