76 changes: 50 additions & 26 deletions python/paddle/base/framework.py
@@ -28,7 +28,7 @@
import warnings
from collections.abc import Iterable
from types import FunctionType, MethodType
from typing import TYPE_CHECKING, Callable, TypeVar
from typing import TYPE_CHECKING, Callable, TypeVar, overload

import numpy as np
from typing_extensions import ParamSpec
@@ -49,6 +49,8 @@
_RetT = TypeVar("_RetT")

if TYPE_CHECKING:
from typing import Generator, Sequence

from paddle.static.amp.fp16_utils import AmpOptions

__all__ = []
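For context, `Generator` and `Sequence` are imported under `if TYPE_CHECKING:` so they are only pulled in while a static type checker analyses the module; at runtime the block is skipped. A minimal standalone sketch of that pattern (not Paddle code):

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by type checkers (mypy, pyright); skipped at runtime,
    # so the import adds no runtime cost.
    from collections.abc import Sequence


def first(values: Sequence[int]) -> int:
    # `Sequence` appears only in the annotation, which the __future__ import
    # keeps as a string, so the name is never needed at runtime.
    return values[0]
```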
@@ -106,7 +108,7 @@ def _global_flags():
return _global_flags_


def set_flags(flags):
def set_flags(flags: dict[str, bool | str | float]) -> None:
"""
This function sets the GFlags value in Paddle.
For FLAGS please refer to :ref:`en_guides_flags_flags`
@@ -131,7 +133,7 @@ def set_flags(flags):
)


def get_flags(flags):
def get_flags(flags: str | Sequence[str]) -> dict[str, bool | str | float]:
"""
This function gets the GFlags value in Paddle.
For FLAGS please refer to :ref:`en_guides_flags_flags`
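A hedged usage sketch of the two annotated entry points (the flag name below is illustrative; consult the FLAGS guide for the flags available in your build):

```python
import paddle

# set_flags takes a dict mapping flag names to bool | str | float values.
paddle.set_flags({"FLAGS_eager_delete_tensor_gb": 1.0})

# get_flags accepts one flag name or a sequence of names and returns a dict
# of the current values.
print(paddle.get_flags(["FLAGS_eager_delete_tensor_gb"]))
```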
@@ -404,7 +406,9 @@ def in_cinn_mode() -> bool:


@signature_safe_contextmanager
def ipu_shard_guard(index=-1, stage=-1):
def ipu_shard_guard(
index: int = -1, stage: int = -1
) -> Generator[None, None, None]:
"""
Used to shard the graph on IPUs. Set each Op run on which IPU in the sharding and which stage in the pipelining.

@@ -456,6 +460,20 @@ def ipu_shard_guard(index=-1, stage=-1):
global_ipu_stage = prev_ipu_stage


@overload
def set_ipu_shard(
call_func: Callable[_InputT, _RetT], index: int = ..., stage: int = ...
) -> Callable[_InputT, _RetT]:
...


@overload
def set_ipu_shard(
call_func: paddle.nn.Layer, index: int = ..., stage: int = ...
) -> paddle.nn.Layer:
...


def set_ipu_shard(call_func, index=-1, stage=-1):
"""
Shard the ipu with the given call function. Set every ops in call function to the given ipu sharding.
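The two `@overload` stubs above let a type checker report that wrapping a plain function gives back a function with the same signature, while wrapping a `paddle.nn.Layer` gives back a `Layer`. A self-contained sketch of the same pattern (illustrative names, not the Paddle implementation):

```python
from __future__ import annotations

from typing import Any, Callable, TypeVar, overload

_RetT = TypeVar("_RetT")


class Layer:  # stand-in for paddle.nn.Layer in this sketch
    pass


@overload
def shard(obj: Callable[..., _RetT]) -> Callable[..., _RetT]: ...
@overload
def shard(obj: Layer) -> Layer: ...


def shard(obj: Any) -> Any:
    # One runtime implementation; the overload stubs only guide type checkers.
    return obj
```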
@@ -467,9 +485,9 @@ def set_ipu_shard(call_func, index=-1, stage=-1):

Args:
call_func(Layer|function): Specify the call function to be wrapped.
index(int, optional): Specify which ipu the Tensor is computed on, (such as 0, 1, 2, 3).
index(int, optional): Specify which ipu the Tensor is computed on, (such as '0, 1, 2, 3').
The default value is -1, which means the Op only run on IPU 0.
stage(int, optional): Specify the computation order of the sharded model(such as 0, 1, 2, 3).
stage(int, optional): Specify the computation order of the sharded model(such as '0, 1, 2, 3').
The sharded model will be computed from small to large. The default value is -1,
which means no pipelining computation order and run Ops in terms of graph.

@@ -489,8 +507,8 @@ def set_ipu_shard(call_func, index=-1, stage=-1):
>>> relu(a)
"""

def decorate(func):
def wrapper(*args, **kwargs):
def decorate(func: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]:
def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
with ipu_shard_guard(index=index, stage=stage):
return func(*args, **kwargs)
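Because `_InputT` is a `ParamSpec`, the `*args: _InputT.args, **kwargs: _InputT.kwargs` annotation makes the wrapper expose exactly the wrapped function's parameter list and return type to type checkers. A minimal standalone sketch of that decorator pattern:

```python
from __future__ import annotations

from typing import Callable, TypeVar

from typing_extensions import ParamSpec

_P = ParamSpec("_P")
_R = TypeVar("_R")


def traced(func: Callable[_P, _R]) -> Callable[_P, _R]:
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        # Forwards the call unchanged; checkers still see func's real signature.
        print(f"calling {func.__name__}")
        return func(*args, **kwargs)

    return wrapper


@traced
def add(x: int, y: int) -> int:
    return x + y


add(1, 2)        # OK
# add("1", "2")  # flagged by a type checker thanks to ParamSpec
```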

@@ -843,7 +861,7 @@ def _custom_device_ids(device_type):
return device_ids


def is_compiled_with_xpu():
def is_compiled_with_xpu() -> bool:
"""
Whether this whl package can be used to run the model on XPU.

@@ -858,7 +876,7 @@ def is_compiled_with_xpu():
return core.is_compiled_with_xpu()


def disable_signal_handler():
def disable_signal_handler() -> None:
"""
Reset signal handler registered by Paddle.

@@ -884,7 +902,7 @@ def disable_signal_handler():
core.disable_signal_handler()


def is_compiled_with_cinn():
def is_compiled_with_cinn() -> bool:
"""
Whether this whl package can be used to run the model on CINN.

@@ -900,7 +918,7 @@ def is_compiled_with_cinn():
return core.is_compiled_with_cinn()


def is_compiled_with_cuda():
def is_compiled_with_cuda() -> bool:
"""
Whether this whl package can be used to run the model on GPU.

@@ -916,7 +934,7 @@ def is_compiled_with_cuda():
return core.is_compiled_with_cuda()


def is_compiled_with_distribute():
def is_compiled_with_distribute() -> bool:
"""
Whether this whl package can be used to run the model with distribute.

@@ -932,7 +950,7 @@ def is_compiled_with_distribute():
return core.is_compiled_with_distribute()


def is_compiled_with_rocm():
def is_compiled_with_rocm() -> bool:
"""
Whether this whl package can be used to run the model on AMD or Hygon GPU(ROCm).

@@ -948,7 +966,7 @@ def is_compiled_with_rocm():
return core.is_compiled_with_rocm()


def cuda_places(device_ids=None):
def cuda_places(
device_ids: Sequence[int] | None = None,
) -> list[core.CUDAPlace]:
"""
Note:
For multi-card tasks, please use `FLAGS_selected_gpus` environment variable to set the visible GPU device.
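A hedged usage sketch of the annotated signature (assumes a CUDA-enabled build; the device ids are illustrative):

```python
import paddle

# Explicit device ids; with device_ids=None the visible GPUs
# (FLAGS_selected_gpus or all detected devices) are used instead.
places = paddle.static.cuda_places([0, 1])
# -> a list of paddle.CUDAPlace objects, one per requested device
```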
@@ -996,7 +1016,7 @@ def cuda_places(device_ids=None):
return [core.CUDAPlace(dev_id) for dev_id in device_ids]


def xpu_places(device_ids=None):
def xpu_places(device_ids: Sequence[int] | None = None) -> list[core.XPUPlace]:
"""
**Note**:
For multi-card tasks, please use `FLAGS_selected_xpus` environment variable to set the visible XPU device.
@@ -1035,7 +1055,7 @@ def xpu_places(device_ids=None):
return [core.XPUPlace(dev_id) for dev_id in device_ids]


def cpu_places(device_count=None):
def cpu_places(device_count: int | None = None) -> list[core.CPUPlace]:
"""
This function creates a list of :code:`paddle.CPUPlace` objects, and returns the created list.

@@ -1069,7 +1089,7 @@ def cpu_places(device_count=None):
return [core.CPUPlace()] * device_count


def cuda_pinned_places(device_count=None):
def cuda_pinned_places(
device_count: int | None = None,
) -> list[core.CUDAPinnedPlace]:
"""
This function creates a list of :code:`base.CUDAPinnedPlace` objects.

@@ -1130,7 +1152,7 @@ def name(self):


@signature_safe_contextmanager
def name_scope(prefix=None):
def name_scope(prefix: str | None = None) -> Generator[None, None, None]:
"""

Generate hierarchical name prefix for the operators in Static Graph.
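`Generator[None, None, None]` is the conventional annotation for a context-manager generator that yields once with no value; a generic standalone sketch (not the Paddle implementation):

```python
from __future__ import annotations

from contextlib import contextmanager
from typing import Generator


@contextmanager
def name_prefix(prefix: str | None = None) -> Generator[None, None, None]:
    # Yields None, is never sent a value, and returns None -- hence the
    # three None arguments in the Generator annotation.
    print(f"entering scope {prefix!r}")
    try:
        yield
    finally:
        print(f"leaving scope {prefix!r}")


with name_prefix("block1"):
    pass
```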
@@ -7784,7 +7806,7 @@ def _copy_to(self, device, blocking):
_startup_program_._is_start_up_program_ = True


def default_startup_program():
def default_startup_program() -> Program:
"""
Get default/global startup program.

@@ -7813,7 +7835,7 @@ def default_startup_program():
return _startup_program_


def default_main_program():
def default_main_program() -> Program:
"""
This API can be used to get ``default main program`` which store the
descriptions of Ops and tensors.
@@ -7850,7 +7872,7 @@ def default_main_program():
return _main_program_


def switch_main_program(program):
def switch_main_program(program: Program) -> Program:
"""
Switch the main program to a new program.

@@ -7866,7 +7888,7 @@ def switch_main_program(program):
return prev_program


def switch_startup_program(program):
def switch_startup_program(program: Program) -> Program:
"""
Switch the startup program to a new program
Args:
@@ -7882,7 +7904,7 @@ def switch_startup_program(program):


@signature_safe_contextmanager
def program_guard(main_program, startup_program=None):
def program_guard(
main_program: Program, startup_program: Program | None = None
) -> Generator[None, None, None]:
"""
:api_attr: Static Graph

@@ -8016,7 +8040,7 @@ def switch_device(device):


@signature_safe_contextmanager
def device_guard(device=None):
def device_guard(device: str | None = None) -> Generator[None, None, None]:
"""

Note:
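And a short sketch of `device_guard` in static-graph mode (the device string and op are illustrative):

```python
import paddle

paddle.enable_static()

# Ops created inside the guard are pinned to the given device, e.g. "cpu",
# regardless of where the rest of the program runs.
with paddle.static.device_guard("cpu"):
    y = paddle.full(shape=[1], fill_value=0.5, dtype="float32")
```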
@@ -8233,7 +8257,7 @@ def dtype_to_str(in_dtype):
elif in_dtype == core.VarDesc.VarType.COMPLEX128:
return "complex128"
else:
raise TypeError(f"got unspport data type for promotion: {in_dtype}.")
raise TypeError(f"got unsupport data type for promotion: {in_dtype}.")


def add_cast_for_type_promotion(op, block, idx, var_name, out_dtype):