Skip to content

Commit b022ab4

Browse files
megeminilixcli
authored andcommitted
[Typing][C-12] Add type annotations for python/paddle/distributed/communication/all_gather.py (PaddlePaddle#66051)
1 parent a88809b commit b022ab4

File tree

2 files changed

+27
-3
lines changed

2 files changed

+27
-3
lines changed

python/paddle/base/core.pyi

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,9 @@ class iinfo:
9696

9797
def is_compiled_with_cuda() -> bool: ...
9898
def set_nan_inf_debug_path(arg0: str) -> None: ...
99+
100+
class task:
101+
def is_completed(self) -> bool: ...
102+
def is_sync(self) -> bool: ...
103+
def synchronize(self) -> None: ...
104+
def wait(self, timeout: int = 0) -> bool: ...

python/paddle/distributed/communication/all_gather.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
17+
from typing import TYPE_CHECKING, TypeVar
18+
1519
import numpy as np
1620

1721
import paddle
@@ -23,8 +27,20 @@
2327
convert_tensor_to_object,
2428
)
2529

30+
if TYPE_CHECKING:
31+
from paddle import Tensor
32+
from paddle.base.core import task
33+
from paddle.distributed.communication.group import Group
34+
35+
_T = TypeVar("_T")
36+
2637

27-
def all_gather(tensor_list, tensor, group=None, sync_op=True):
38+
def all_gather(
39+
tensor_list: list[Tensor],
40+
tensor: Tensor,
41+
group: Group | None = None,
42+
sync_op: bool = True,
43+
) -> task | None:
2844
"""
2945
3046
Gather tensors from all participators and all get the result. As shown
@@ -42,7 +58,7 @@ def all_gather(tensor_list, tensor, group=None, sync_op=True):
4258
should be float16, float32, float64, int32, int64, int8, uint8, bool, bfloat16, complex64 or complex128.
4359
tensor (Tensor): The Tensor to send. Its data type
4460
should be float16, float32, float64, int32, int64, int8, uint8, bool, bfloat16, complex64 or complex128.
45-
group (Group, optional): The group instance return by new_group or None for global default group.
61+
group (Group|None, optional): The group instance return by new_group or None for global default group.
4662
sync_op (bool, optional): Whether this op is a sync op. The default value is True.
4763
4864
Returns:
@@ -68,7 +84,9 @@ def all_gather(tensor_list, tensor, group=None, sync_op=True):
6884
return stream.all_gather(tensor_list, tensor, group, sync_op)
6985

7086

71-
def all_gather_object(object_list, obj, group=None):
87+
def all_gather_object(
88+
object_list: list[_T], obj: _T, group: Group = None
89+
) -> None:
7290
"""
7391
7492
Gather picklable objects from all participators and all get the result. Similar to all_gather(), but python object can be passed in.

0 commit comments

Comments
 (0)