1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ from __future__ import annotations
16+
17+ from typing import TYPE_CHECKING , TypeVar
18+
1519import numpy as np
1620
1721import paddle
2327 convert_tensor_to_object ,
2428)
2529
30+ if TYPE_CHECKING :
31+ from paddle import Tensor
32+ from paddle .base .core import task
33+ from paddle .distributed .communication .group import Group
34+
35+ _T = TypeVar ("_T" )
36+
2637
27- def all_gather (tensor_list , tensor , group = None , sync_op = True ):
38+ def all_gather (
39+ tensor_list : list [Tensor ],
40+ tensor : Tensor ,
41+ group : Group | None = None ,
42+ sync_op : bool = True ,
43+ ) -> task | None :
2844 """
2945
3046 Gather tensors from all participators and all get the result. As shown
@@ -42,7 +58,7 @@ def all_gather(tensor_list, tensor, group=None, sync_op=True):
4258 should be float16, float32, float64, int32, int64, int8, uint8, bool, bfloat16, complex64 or complex128.
4359 tensor (Tensor): The Tensor to send. Its data type
4460 should be float16, float32, float64, int32, int64, int8, uint8, bool, bfloat16, complex64 or complex128.
45- group (Group, optional): The group instance return by new_group or None for global default group.
61+ group (Group|None , optional): The group instance return by new_group or None for global default group.
4662 sync_op (bool, optional): Whether this op is a sync op. The default value is True.
4763
4864 Returns:
@@ -68,7 +84,9 @@ def all_gather(tensor_list, tensor, group=None, sync_op=True):
6884 return stream .all_gather (tensor_list , tensor , group , sync_op )
6985
7086
71- def all_gather_object (object_list , obj , group = None ):
87+ def all_gather_object (
88+ object_list : list [_T ], obj : _T , group : Group = None
89+ ) -> None :
7290 """
7391
7492 Gather picklable objects from all participators and all get the result. Similar to all_gather(), but python object can be passed in.
0 commit comments