1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ from __future__ import annotations
16+
1517import numbers
18+ from typing import TYPE_CHECKING , Sequence
1619
1720import numpy as np
1821
1922import paddle
2023from paddle .base import framework
2124from paddle .distribution import distribution
2225
26+ if TYPE_CHECKING :
27+ from paddle import Tensor
28+
2329
2430class Geometric (distribution .Distribution ):
2531 r"""
@@ -67,7 +73,9 @@ class Geometric(distribution.Distribution):
6773 1.41421354)
6874 """
6975
70- def __init__ (self , probs ):
76+ probs : Tensor
77+
78+ def __init__ (self , probs : float | Tensor ) -> None :
7179 if isinstance (
7280 probs ,
7381 (numbers .Real , paddle .Tensor , framework .Variable , paddle .pir .Value ),
@@ -101,24 +109,24 @@ def __init__(self, probs):
101109 super ().__init__ (batch_shape )
102110
103111 @property
104- def mean (self ):
112+ def mean (self ) -> Tensor :
105113 """Mean of geometric distribution."""
106114 return 1.0 / self .probs - 1.0
107115
108116 @property
109- def variance (self ):
117+ def variance (self ) -> Tensor :
110118 """Variance of geometric distribution."""
111119 return paddle .to_tensor (
112120 (1.0 / self .probs - 1.0 ) / self .probs ,
113121 dtype = self .probs .dtype ,
114122 )
115123
116124 @property
117- def stddev (self ):
125+ def stddev (self ) -> Tensor :
118126 """Standard deviation of Geometric distribution."""
119127 return paddle .sqrt (self .variance )
120128
121- def pmf (self , k ) :
129+ def pmf (self , k : int | Tensor ) -> Tensor :
122130 r"""Probability mass function evaluated at k.
123131
124132 .. math::
@@ -152,7 +160,7 @@ def pmf(self, k):
152160 f"Expected type of k is number.Real|framework.Variable|Value, but got { type (k )} "
153161 )
154162
155- def log_pmf (self , k ) :
163+ def log_pmf (self , k : int | Tensor ) -> Tensor :
156164 r"""Log probability mass function evaluated at k.
157165
158166 .. math::
@@ -185,11 +193,11 @@ def log_pmf(self, k):
185193 f"Expected type of k is number.Real|framework.Variable|Value, but got { type (k )} "
186194 )
187195
188- def sample (self , shape = ()):
196+ def sample (self , shape : Sequence [ int ] = ()) -> Tensor :
189197 """Sample from Geometric distribution with sample shape.
190198
191199 Args:
192- shape (tuple( int) ): Sample shape.
200+ shape (Sequence[ int] ): Sample shape.
193201
194202 Returns:
195203 Sampled data with shape `sample_shape` + `batch_shape` + `event_shape`.
@@ -211,11 +219,11 @@ def sample(self, shape=()):
211219 with paddle .no_grad ():
212220 return self .rsample (shape )
213221
214- def rsample (self , shape = ()):
222+ def rsample (self , shape : Sequence [ int ] = ()) -> Tensor :
215223 """Generate samples of the specified shape.
216224
217225 Args:
218- shape(tuple( int) ): The shape of generated samples.
226+ shape(Sequence[ int] ): The shape of generated samples.
219227
220228 Returns:
221229 Tensor: A sample tensor that fits the Geometric distribution.
@@ -248,7 +256,7 @@ def rsample(self, shape=()):
248256
249257 return paddle .floor (paddle .log (uniform ) / paddle .log1p (- (self .probs )))
250258
251- def entropy (self ):
259+ def entropy (self ) -> Tensor :
252260 r"""Entropy of dirichlet distribution.
253261
254262 .. math::
@@ -275,7 +283,7 @@ def entropy(self):
275283
276284 return - (x + y ) / self .probs
277285
278- def cdf (self , k ) :
286+ def cdf (self , k : int | Tensor ) -> Tensor :
279287 r"""Cdf of geometric distribution.
280288
281289 .. math::
@@ -309,7 +317,7 @@ def cdf(self, k):
309317 f"Expected type of k is number.Real|framework.Variable|Value, but got { type (k )} "
310318 )
311319
312- def kl_divergence (self , other ) :
320+ def kl_divergence (self , other : Geometric ) -> Tensor :
313321 r"""Calculate the KL divergence KL(self || other) with two Geometric instances.
314322
315323 .. math::
0 commit comments