1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- # TODO: define statistical functions of a tensor
15+ from __future__ import annotations
16+
17+ from typing import TYPE_CHECKING , Literal , Sequence , overload
18+
19+ from typing_extensions import TypeAlias
1620
1721import paddle
1822from paddle import _C_ops
2731from .math import _get_reduce_axis_with_tensor
2832from .search import where
2933
34+ if TYPE_CHECKING :
35+ from paddle import Tensor
36+
37+ _Interpolation : TypeAlias = Literal [
38+ 'linear' , 'higher' , 'lower' , 'midpoint' , 'nearest'
39+ ]
3040__all__ = []
3141
3242
33- def mean (x , axis = None , keepdim = False , name = None ):
43+ def mean (
44+ x : Tensor ,
45+ axis : int | Sequence [int ] | None = None ,
46+ keepdim : bool = False ,
47+ name : str | None = None ,
48+ ) -> Tensor :
3449 """
3550 Computes the mean of the input tensor's elements along ``axis``.
3651
3752 Args:
3853 x (Tensor): The input Tensor with data type float32, float64.
39- axis (int|list|tuple, optional): The axis along which to perform mean
54+ axis (int|list|tuple|None , optional): The axis along which to perform mean
4055 calculations. ``axis`` should be int, list(int) or tuple(int). If
4156 ``axis`` is a list/tuple of dimension(s), mean is calculated along
4257 all element(s) of ``axis`` . ``axis`` or element(s) of ``axis``
@@ -49,7 +64,7 @@ def mean(x, axis=None, keepdim=False, name=None):
4964 the output Tensor is the same as ``x`` except in the reduced
5065 dimensions(it is of size 1 in this case). Otherwise, the shape of
5166 the output Tensor is squeezed in ``axis`` . Default is False.
52- name (str, optional): Name for the operation (optional, default is None).
67+ name (str|None , optional): Name for the operation (optional, default is None).
5368 For more information, please refer to :ref:`api_guide_Name`.
5469
5570 Returns:
@@ -121,21 +136,27 @@ def mean(x, axis=None, keepdim=False, name=None):
121136 return out
122137
123138
124- def var (x , axis = None , unbiased = True , keepdim = False , name = None ):
139+ def var (
140+ x : Tensor ,
141+ axis : int | Sequence [int ] | None = None ,
142+ unbiased : bool = True ,
143+ keepdim : bool = False ,
144+ name : str | None = None ,
145+ ) -> Tensor :
125146 """
126147 Computes the variance of ``x`` along ``axis`` .
127148
128149 Args:
129150 x (Tensor): The input Tensor with data type float16, float32, float64.
130- axis (int|list|tuple, optional): The axis along which to perform variance calculations. ``axis`` should be int, list(int) or tuple(int).
151+ axis (int|list|tuple|None , optional): The axis along which to perform variance calculations. ``axis`` should be int, list(int) or tuple(int).
131152
132153 - If ``axis`` is a list/tuple of dimension(s), variance is calculated along all element(s) of ``axis`` . ``axis`` or element(s) of ``axis`` should be in range [-D, D), where D is the dimensions of ``x`` .
133154 - If ``axis`` or element(s) of ``axis`` is less than 0, it works the same way as :math:`axis + D` .
134155 - If ``axis`` is None, variance is calculated over all elements of ``x``. Default is None.
135156
136157 unbiased (bool, optional): Whether to use the unbiased estimation. If ``unbiased`` is True, the divisor used in the computation is :math:`N - 1`, where :math:`N` represents the number of elements along ``axis`` , otherwise the divisor is :math:`N`. Default is True.
137158 keep_dim (bool, optional): Whether to reserve the reduced dimension in the output Tensor. The result tensor will have one fewer dimension than the input unless keep_dim is true. Default is False.
138- name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
159+ name (str|None , optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
139160
140161 Returns:
141162 Tensor, results of variance along ``axis`` of ``x``, with the same data type as ``x``.
@@ -174,13 +195,19 @@ def var(x, axis=None, unbiased=True, keepdim=False, name=None):
174195 return out
175196
176197
177- def std (x , axis = None , unbiased = True , keepdim = False , name = None ):
198+ def std (
199+ x : Tensor ,
200+ axis : int | Sequence [int ] | None = None ,
201+ unbiased : bool = True ,
202+ keepdim : bool = False ,
203+ name : str | None = None ,
204+ ) -> Tensor :
178205 """
179206 Computes the standard-deviation of ``x`` along ``axis`` .
180207
181208 Args:
182209 x (Tensor): The input Tensor with data type float16, float32, float64.
183- axis (int|list|tuple, optional): The axis along which to perform
210+ axis (int|list|tuple|None , optional): The axis along which to perform
184211 standard-deviation calculations. ``axis`` should be int, list(int)
185212 or tuple(int). If ``axis`` is a list/tuple of dimension(s),
186213 standard-deviation is calculated along all element(s) of ``axis`` .
@@ -200,7 +227,7 @@ def std(x, axis=None, unbiased=True, keepdim=False, name=None):
200227 the output Tensor is the same as ``x`` except in the reduced
201228 dimensions(it is of size 1 in this case). Otherwise, the shape of
202229 the output Tensor is squeezed in ``axis`` . Default is False.
203- name (str, optional): Name for the operation (optional, default is None).
230+ name (str|None , optional): Name for the operation (optional, default is None).
204231 For more information, please refer to :ref:`api_guide_Name`.
205232
206233 Returns:
@@ -232,13 +259,13 @@ def std(x, axis=None, unbiased=True, keepdim=False, name=None):
232259 return paddle .sqrt (out )
233260
234261
235- def numel (x , name = None ):
262+ def numel (x : Tensor , name : str | None = None ) -> Tensor :
236263 """
237264 Returns the number of elements for a tensor, which is a 0-D int64 Tensor with shape [].
238265
239266 Args:
240267 x (Tensor): The input Tensor, it's data type can be bool, float16, float32, float64, int32, int64, complex64, complex128.
241- name (str, optional): Name for the operation (optional, default is None).
268+ name (str|None , optional): Name for the operation (optional, default is None).
242269 For more information, please refer to :ref:`api_guide_Name`.
243270
244271 Returns:
@@ -269,7 +296,35 @@ def numel(x, name=None):
269296 return out
270297
271298
272- def nanmedian (x , axis = None , keepdim = False , mode = 'avg' , name = None ):
299+ @overload
300+ def nanmedian (
301+ x : Tensor ,
302+ axis : int ,
303+ keepdim : bool = ...,
304+ mode : Literal ['min' ] = ...,
305+ name : str | None = ...,
306+ ) -> tuple [Tensor , Tensor ]:
307+ ...
308+
309+
310+ @overload
311+ def nanmedian (
312+ x : Tensor ,
313+ axis : int | Sequence [int ] | None = ...,
314+ keepdim : bool = ...,
315+ mode : Literal ['avg' , 'min' ] = ...,
316+ name : str | None = ...,
317+ ) -> Tensor :
318+ ...
319+
320+
321+ def nanmedian (
322+ x ,
323+ axis = None ,
324+ keepdim = False ,
325+ mode = 'avg' ,
326+ name = None ,
327+ ):
273328 r"""
274329 Compute the median along the specified axis, while ignoring NaNs.
275330
@@ -291,7 +346,7 @@ def nanmedian(x, axis=None, keepdim=False, mode='avg', name=None):
291346 mode (str, optional): Whether to use mean or min operation to calculate
292347 the nanmedian values when the input tensor has an even number of non-NaN elements
293348 along the dimension ``axis``. Support 'avg' and 'min'. Default is 'avg'.
294- name (str, optional): Name for the operation (optional, default is None).
349+ name (str|None , optional): Name for the operation (optional, default is None).
295350 For more information, please refer to :ref:`api_guide_Name`.
296351
297352 Returns:
@@ -386,13 +441,41 @@ def nanmedian(x, axis=None, keepdim=False, mode='avg', name=None):
386441 return out
387442
388443
389- def median (x , axis = None , keepdim = False , mode = 'avg' , name = None ):
444+ @overload
445+ def median (
446+ x : Tensor ,
447+ axis : int = ...,
448+ keepdim : bool = ...,
449+ mode : Literal ['min' ] = ...,
450+ name : str | None = ...,
451+ ) -> tuple [Tensor , Tensor ]:
452+ ...
453+
454+
455+ @overload
456+ def median (
457+ x : Tensor ,
458+ axis : int | None = ...,
459+ keepdim : bool = ...,
460+ mode : Literal ['avg' , 'min' ] = ...,
461+ name : str | None = ...,
462+ ) -> Tensor :
463+ ...
464+
465+
466+ def median (
467+ x ,
468+ axis = None ,
469+ keepdim = False ,
470+ mode = 'avg' ,
471+ name = None ,
472+ ):
390473 """
391474 Compute the median along the specified axis.
392475
393476 Args:
394477 x (Tensor): The input Tensor, it's data type can be float16, float32, float64, int32, int64.
395- axis (int, optional): The axis along which to perform median calculations ``axis`` should be int.
478+ axis (int|None , optional): The axis along which to perform median calculations ``axis`` should be int.
396479 ``axis`` should be in range [-D, D), where D is the dimensions of ``x`` .
397480 If ``axis`` is less than 0, it works the same way as :math:`axis + D`.
398481 If ``axis`` is None, median is calculated over all elements of ``x``. Default is None.
@@ -404,7 +487,7 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None):
404487 mode (str, optional): Whether to use mean or min operation to calculate
405488 the median values when the input tensor has an even number of elements
406489 in the dimension ``axis``. Support 'avg' and 'min'. Default is 'avg'.
407- name (str, optional): Name for the operation (optional, default is None).
490+ name (str|None , optional): Name for the operation (optional, default is None).
408491 For more information, please refer to :ref:`api_guide_Name`.
409492
410493 Returns:
@@ -611,8 +694,13 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None):
611694
612695
613696def _compute_quantile (
614- x , q , axis = None , keepdim = False , interpolation = "linear" , ignore_nan = False
615- ):
697+ x : Tensor ,
698+ q : float | Sequence [float ] | Tensor | None ,
699+ axis : int | list [int ] | None = None ,
700+ keepdim : bool = False ,
701+ interpolation : _Interpolation = "linear" ,
702+ ignore_nan : bool = False ,
703+ ) -> Tensor :
616704 """
617705 Compute the quantile of the input along the specified axis.
618706
@@ -787,7 +875,13 @@ def _compute_index(index):
787875 return outputs
788876
789877
790- def quantile (x , q , axis = None , keepdim = False , interpolation = "linear" ):
878+ def quantile (
879+ x : Tensor ,
880+ q : float | Sequence [float ] | Tensor ,
881+ axis : int | list [int ] | None = None ,
882+ keepdim : bool = False ,
883+ interpolation : _Interpolation = "linear" ,
884+ ) -> Tensor :
791885 """
792886 Compute the quantile of the input along the specified axis.
793887 If any values in a reduced row are NaN, then the quantiles for that reduction will be NaN.
@@ -865,7 +959,13 @@ def quantile(x, q, axis=None, keepdim=False, interpolation="linear"):
865959 )
866960
867961
868- def nanquantile (x , q , axis = None , keepdim = False , interpolation = "linear" ):
962+ def nanquantile (
963+ x : Tensor ,
964+ q : float | Sequence [float ] | Tensor ,
965+ axis : list [int ] | int = None ,
966+ keepdim : bool = False ,
967+ interpolation : _Interpolation = "linear" ,
968+ ) -> Tensor :
869969 """
870970 Compute the quantile of the input as if NaN values in input did not exist.
871971 If all values in a reduced row are NaN, then the quantiles for that reduction will be NaN.
@@ -888,7 +988,7 @@ def nanquantile(x, q, axis=None, keepdim=False, interpolation="linear"):
888988 interpolation (str, optional): The interpolation method to use
889989 when the desired quantile falls between two data points. Must be one of linear, higher,
890990 lower, midpoint and nearest. Default is linear.
891- name (str, optional): Name for the operation (optional, default is None).
991+ name (str|None , optional): Name for the operation (optional, default is None).
892992 For more information, please refer to :ref:`api_guide_Name`.
893993
894994 Returns:
0 commit comments