Skip to content

Commit 7d9b558

Browse files
ooooo-createSigureMo
authored andcommitted
[Typing][A-12] Add type annotations for paddle/tensor/stat.py (PaddlePaddle#65337)
--------- Co-authored-by: Nyakku Shigure <[email protected]>
1 parent bb12578 commit 7d9b558

File tree

1 file changed

+122
-22
lines changed

1 file changed

+122
-22
lines changed

python/paddle/tensor/stat.py

Lines changed: 122 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
# TODO: define statistical functions of a tensor
15+
from __future__ import annotations
16+
17+
from typing import TYPE_CHECKING, Literal, Sequence, overload
18+
19+
from typing_extensions import TypeAlias
1620

1721
import paddle
1822
from paddle import _C_ops
@@ -27,16 +31,27 @@
2731
from .math import _get_reduce_axis_with_tensor
2832
from .search import where
2933

34+
if TYPE_CHECKING:
35+
from paddle import Tensor
36+
37+
_Interpolation: TypeAlias = Literal[
38+
'linear', 'higher', 'lower', 'midpoint', 'nearest'
39+
]
3040
__all__ = []
3141

3242

33-
def mean(x, axis=None, keepdim=False, name=None):
43+
def mean(
44+
x: Tensor,
45+
axis: int | Sequence[int] | None = None,
46+
keepdim: bool = False,
47+
name: str | None = None,
48+
) -> Tensor:
3449
"""
3550
Computes the mean of the input tensor's elements along ``axis``.
3651
3752
Args:
3853
x (Tensor): The input Tensor with data type float32, float64.
39-
axis (int|list|tuple, optional): The axis along which to perform mean
54+
axis (int|list|tuple|None, optional): The axis along which to perform mean
4055
calculations. ``axis`` should be int, list(int) or tuple(int). If
4156
``axis`` is a list/tuple of dimension(s), mean is calculated along
4257
all element(s) of ``axis`` . ``axis`` or element(s) of ``axis``
@@ -49,7 +64,7 @@ def mean(x, axis=None, keepdim=False, name=None):
4964
the output Tensor is the same as ``x`` except in the reduced
5065
dimensions(it is of size 1 in this case). Otherwise, the shape of
5166
the output Tensor is squeezed in ``axis`` . Default is False.
52-
name (str, optional): Name for the operation (optional, default is None).
67+
name (str|None, optional): Name for the operation (optional, default is None).
5368
For more information, please refer to :ref:`api_guide_Name`.
5469
5570
Returns:
@@ -121,21 +136,27 @@ def mean(x, axis=None, keepdim=False, name=None):
121136
return out
122137

123138

124-
def var(x, axis=None, unbiased=True, keepdim=False, name=None):
139+
def var(
140+
x: Tensor,
141+
axis: int | Sequence[int] | None = None,
142+
unbiased: bool = True,
143+
keepdim: bool = False,
144+
name: str | None = None,
145+
) -> Tensor:
125146
"""
126147
Computes the variance of ``x`` along ``axis`` .
127148
128149
Args:
129150
x (Tensor): The input Tensor with data type float16, float32, float64.
130-
axis (int|list|tuple, optional): The axis along which to perform variance calculations. ``axis`` should be int, list(int) or tuple(int).
151+
axis (int|list|tuple|None, optional): The axis along which to perform variance calculations. ``axis`` should be int, list(int) or tuple(int).
131152
132153
- If ``axis`` is a list/tuple of dimension(s), variance is calculated along all element(s) of ``axis`` . ``axis`` or element(s) of ``axis`` should be in range [-D, D), where D is the dimensions of ``x`` .
133154
- If ``axis`` or element(s) of ``axis`` is less than 0, it works the same way as :math:`axis + D` .
134155
- If ``axis`` is None, variance is calculated over all elements of ``x``. Default is None.
135156
136157
unbiased (bool, optional): Whether to use the unbiased estimation. If ``unbiased`` is True, the divisor used in the computation is :math:`N - 1`, where :math:`N` represents the number of elements along ``axis`` , otherwise the divisor is :math:`N`. Default is True.
137158
keep_dim (bool, optional): Whether to reserve the reduced dimension in the output Tensor. The result tensor will have one fewer dimension than the input unless keep_dim is true. Default is False.
138-
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
159+
name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
139160
140161
Returns:
141162
Tensor, results of variance along ``axis`` of ``x``, with the same data type as ``x``.
@@ -174,13 +195,19 @@ def var(x, axis=None, unbiased=True, keepdim=False, name=None):
174195
return out
175196

176197

177-
def std(x, axis=None, unbiased=True, keepdim=False, name=None):
198+
def std(
199+
x: Tensor,
200+
axis: int | Sequence[int] | None = None,
201+
unbiased: bool = True,
202+
keepdim: bool = False,
203+
name: str | None = None,
204+
) -> Tensor:
178205
"""
179206
Computes the standard-deviation of ``x`` along ``axis`` .
180207
181208
Args:
182209
x (Tensor): The input Tensor with data type float16, float32, float64.
183-
axis (int|list|tuple, optional): The axis along which to perform
210+
axis (int|list|tuple|None, optional): The axis along which to perform
184211
standard-deviation calculations. ``axis`` should be int, list(int)
185212
or tuple(int). If ``axis`` is a list/tuple of dimension(s),
186213
standard-deviation is calculated along all element(s) of ``axis`` .
@@ -200,7 +227,7 @@ def std(x, axis=None, unbiased=True, keepdim=False, name=None):
200227
the output Tensor is the same as ``x`` except in the reduced
201228
dimensions(it is of size 1 in this case). Otherwise, the shape of
202229
the output Tensor is squeezed in ``axis`` . Default is False.
203-
name (str, optional): Name for the operation (optional, default is None).
230+
name (str|None, optional): Name for the operation (optional, default is None).
204231
For more information, please refer to :ref:`api_guide_Name`.
205232
206233
Returns:
@@ -232,13 +259,13 @@ def std(x, axis=None, unbiased=True, keepdim=False, name=None):
232259
return paddle.sqrt(out)
233260

234261

235-
def numel(x, name=None):
262+
def numel(x: Tensor, name: str | None = None) -> Tensor:
236263
"""
237264
Returns the number of elements for a tensor, which is a 0-D int64 Tensor with shape [].
238265
239266
Args:
240267
x (Tensor): The input Tensor, it's data type can be bool, float16, float32, float64, int32, int64, complex64, complex128.
241-
name (str, optional): Name for the operation (optional, default is None).
268+
name (str|None, optional): Name for the operation (optional, default is None).
242269
For more information, please refer to :ref:`api_guide_Name`.
243270
244271
Returns:
@@ -269,7 +296,35 @@ def numel(x, name=None):
269296
return out
270297

271298

272-
def nanmedian(x, axis=None, keepdim=False, mode='avg', name=None):
299+
@overload
300+
def nanmedian(
301+
x: Tensor,
302+
axis: int,
303+
keepdim: bool = ...,
304+
mode: Literal['min'] = ...,
305+
name: str | None = ...,
306+
) -> tuple[Tensor, Tensor]:
307+
...
308+
309+
310+
@overload
311+
def nanmedian(
312+
x: Tensor,
313+
axis: int | Sequence[int] | None = ...,
314+
keepdim: bool = ...,
315+
mode: Literal['avg', 'min'] = ...,
316+
name: str | None = ...,
317+
) -> Tensor:
318+
...
319+
320+
321+
def nanmedian(
322+
x,
323+
axis=None,
324+
keepdim=False,
325+
mode='avg',
326+
name=None,
327+
):
273328
r"""
274329
Compute the median along the specified axis, while ignoring NaNs.
275330
@@ -291,7 +346,7 @@ def nanmedian(x, axis=None, keepdim=False, mode='avg', name=None):
291346
mode (str, optional): Whether to use mean or min operation to calculate
292347
the nanmedian values when the input tensor has an even number of non-NaN elements
293348
along the dimension ``axis``. Support 'avg' and 'min'. Default is 'avg'.
294-
name (str, optional): Name for the operation (optional, default is None).
349+
name (str|None, optional): Name for the operation (optional, default is None).
295350
For more information, please refer to :ref:`api_guide_Name`.
296351
297352
Returns:
@@ -386,13 +441,41 @@ def nanmedian(x, axis=None, keepdim=False, mode='avg', name=None):
386441
return out
387442

388443

389-
def median(x, axis=None, keepdim=False, mode='avg', name=None):
444+
@overload
445+
def median(
446+
x: Tensor,
447+
axis: int = ...,
448+
keepdim: bool = ...,
449+
mode: Literal['min'] = ...,
450+
name: str | None = ...,
451+
) -> tuple[Tensor, Tensor]:
452+
...
453+
454+
455+
@overload
456+
def median(
457+
x: Tensor,
458+
axis: int | None = ...,
459+
keepdim: bool = ...,
460+
mode: Literal['avg', 'min'] = ...,
461+
name: str | None = ...,
462+
) -> Tensor:
463+
...
464+
465+
466+
def median(
467+
x,
468+
axis=None,
469+
keepdim=False,
470+
mode='avg',
471+
name=None,
472+
):
390473
"""
391474
Compute the median along the specified axis.
392475
393476
Args:
394477
x (Tensor): The input Tensor, it's data type can be float16, float32, float64, int32, int64.
395-
axis (int, optional): The axis along which to perform median calculations ``axis`` should be int.
478+
axis (int|None, optional): The axis along which to perform median calculations ``axis`` should be int.
396479
``axis`` should be in range [-D, D), where D is the dimensions of ``x`` .
397480
If ``axis`` is less than 0, it works the same way as :math:`axis + D`.
398481
If ``axis`` is None, median is calculated over all elements of ``x``. Default is None.
@@ -404,7 +487,7 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None):
404487
mode (str, optional): Whether to use mean or min operation to calculate
405488
the median values when the input tensor has an even number of elements
406489
in the dimension ``axis``. Support 'avg' and 'min'. Default is 'avg'.
407-
name (str, optional): Name for the operation (optional, default is None).
490+
name (str|None, optional): Name for the operation (optional, default is None).
408491
For more information, please refer to :ref:`api_guide_Name`.
409492
410493
Returns:
@@ -611,8 +694,13 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None):
611694

612695

613696
def _compute_quantile(
614-
x, q, axis=None, keepdim=False, interpolation="linear", ignore_nan=False
615-
):
697+
x: Tensor,
698+
q: float | Sequence[float] | Tensor | None,
699+
axis: int | list[int] | None = None,
700+
keepdim: bool = False,
701+
interpolation: _Interpolation = "linear",
702+
ignore_nan: bool = False,
703+
) -> Tensor:
616704
"""
617705
Compute the quantile of the input along the specified axis.
618706
@@ -787,7 +875,13 @@ def _compute_index(index):
787875
return outputs
788876

789877

790-
def quantile(x, q, axis=None, keepdim=False, interpolation="linear"):
878+
def quantile(
879+
x: Tensor,
880+
q: float | Sequence[float] | Tensor,
881+
axis: int | list[int] | None = None,
882+
keepdim: bool = False,
883+
interpolation: _Interpolation = "linear",
884+
) -> Tensor:
791885
"""
792886
Compute the quantile of the input along the specified axis.
793887
If any values in a reduced row are NaN, then the quantiles for that reduction will be NaN.
@@ -865,7 +959,13 @@ def quantile(x, q, axis=None, keepdim=False, interpolation="linear"):
865959
)
866960

867961

868-
def nanquantile(x, q, axis=None, keepdim=False, interpolation="linear"):
962+
def nanquantile(
963+
x: Tensor,
964+
q: float | Sequence[float] | Tensor,
965+
axis: list[int] | int = None,
966+
keepdim: bool = False,
967+
interpolation: _Interpolation = "linear",
968+
) -> Tensor:
869969
"""
870970
Compute the quantile of the input as if NaN values in input did not exist.
871971
If all values in a reduced row are NaN, then the quantiles for that reduction will be NaN.
@@ -888,7 +988,7 @@ def nanquantile(x, q, axis=None, keepdim=False, interpolation="linear"):
888988
interpolation (str, optional): The interpolation method to use
889989
when the desired quantile falls between two data points. Must be one of linear, higher,
890990
lower, midpoint and nearest. Default is linear.
891-
name (str, optional): Name for the operation (optional, default is None).
991+
name (str|None, optional): Name for the operation (optional, default is None).
892992
For more information, please refer to :ref:`api_guide_Name`.
893993
894994
Returns:

0 commit comments

Comments
 (0)