Skip to content

Commit 4724f5f

Browse files
authored
[Typing][B-14] Add type annotations for python/paddle/distribution/geometric.py (#65773)
1 parent 3477632 commit 4724f5f

File tree

1 file changed

+21
-13
lines changed

1 file changed

+21
-13
lines changed

python/paddle/distribution/geometric.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,20 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
1517
import numbers
18+
from typing import TYPE_CHECKING, Sequence
1619

1720
import numpy as np
1821

1922
import paddle
2023
from paddle.base import framework
2124
from paddle.distribution import distribution
2225

26+
if TYPE_CHECKING:
27+
from paddle import Tensor
28+
2329

2430
class Geometric(distribution.Distribution):
2531
r"""
@@ -67,7 +73,9 @@ class Geometric(distribution.Distribution):
6773
1.41421354)
6874
"""
6975

70-
def __init__(self, probs):
76+
probs: Tensor
77+
78+
def __init__(self, probs: float | Tensor) -> None:
7179
if isinstance(
7280
probs,
7381
(numbers.Real, paddle.Tensor, framework.Variable, paddle.pir.Value),
@@ -101,24 +109,24 @@ def __init__(self, probs):
101109
super().__init__(batch_shape)
102110

103111
@property
104-
def mean(self):
112+
def mean(self) -> Tensor:
105113
"""Mean of geometric distribution."""
106114
return 1.0 / self.probs - 1.0
107115

108116
@property
109-
def variance(self):
117+
def variance(self) -> Tensor:
110118
"""Variance of geometric distribution."""
111119
return paddle.to_tensor(
112120
(1.0 / self.probs - 1.0) / self.probs,
113121
dtype=self.probs.dtype,
114122
)
115123

116124
@property
117-
def stddev(self):
125+
def stddev(self) -> Tensor:
118126
"""Standard deviation of Geometric distribution."""
119127
return paddle.sqrt(self.variance)
120128

121-
def pmf(self, k):
129+
def pmf(self, k: int | Tensor) -> Tensor:
122130
r"""Probability mass function evaluated at k.
123131
124132
.. math::
@@ -152,7 +160,7 @@ def pmf(self, k):
152160
f"Expected type of k is number.Real|framework.Variable|Value, but got {type(k)}"
153161
)
154162

155-
def log_pmf(self, k):
163+
def log_pmf(self, k: int | Tensor) -> Tensor:
156164
r"""Log probability mass function evaluated at k.
157165
158166
.. math::
@@ -185,11 +193,11 @@ def log_pmf(self, k):
185193
f"Expected type of k is number.Real|framework.Variable|Value, but got {type(k)}"
186194
)
187195

188-
def sample(self, shape=()):
196+
def sample(self, shape: Sequence[int] = ()) -> Tensor:
189197
"""Sample from Geometric distribution with sample shape.
190198
191199
Args:
192-
shape (tuple(int)): Sample shape.
200+
shape (Sequence[int]): Sample shape.
193201
194202
Returns:
195203
Sampled data with shape `sample_shape` + `batch_shape` + `event_shape`.
@@ -211,11 +219,11 @@ def sample(self, shape=()):
211219
with paddle.no_grad():
212220
return self.rsample(shape)
213221

214-
def rsample(self, shape=()):
222+
def rsample(self, shape: Sequence[int] = ()) -> Tensor:
215223
"""Generate samples of the specified shape.
216224
217225
Args:
218-
shape(tuple(int)): The shape of generated samples.
226+
shape(Sequence[int]): The shape of generated samples.
219227
220228
Returns:
221229
Tensor: A sample tensor that fits the Geometric distribution.
@@ -248,7 +256,7 @@ def rsample(self, shape=()):
248256

249257
return paddle.floor(paddle.log(uniform) / paddle.log1p(-(self.probs)))
250258

251-
def entropy(self):
259+
def entropy(self) -> Tensor:
252260
r"""Entropy of dirichlet distribution.
253261
254262
.. math::
@@ -275,7 +283,7 @@ def entropy(self):
275283

276284
return -(x + y) / self.probs
277285

278-
def cdf(self, k):
286+
def cdf(self, k: int | Tensor) -> Tensor:
279287
r"""Cdf of geometric distribution.
280288
281289
.. math::
@@ -309,7 +317,7 @@ def cdf(self, k):
309317
f"Expected type of k is number.Real|framework.Variable|Value, but got {type(k)}"
310318
)
311319

312-
def kl_divergence(self, other):
320+
def kl_divergence(self, other: Geometric) -> Tensor:
313321
r"""Calculate the KL divergence KL(self || other) with two Geometric instances.
314322
315323
.. math::

0 commit comments

Comments
 (0)