1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+ from __future__ import annotations
1415
1516from collections .abc import Sequence
17+ from typing import TYPE_CHECKING
1618
1719import paddle
20+ from paddle .base .data_feeder import convert_dtype
1821from paddle .distribution import distribution
1922
23+ if TYPE_CHECKING :
24+ from paddle import Tensor
25+ from paddle ._typing .dtype_like import _DTypeLiteral
26+
2027
2128class Poisson (distribution .Distribution ):
2229 r"""
@@ -72,7 +79,10 @@ class Poisson(distribution.Distribution):
7279 [0.06825157 , 1.53426421 ]])
7380 """
7481
75- def __init__ (self , rate ):
82+ rate : Tensor
83+ dtype : _DTypeLiteral
84+
85+ def __init__ (self , rate : float | Tensor ) -> None :
7686 self .dtype = paddle .get_default_dtype ()
7787 self .rate = self ._to_tensor (rate )
7888
@@ -82,7 +92,7 @@ def __init__(self, rate):
8292 batch_shape = self .rate .shape
8393 super ().__init__ (batch_shape )
8494
85- def _to_tensor (self , rate ) :
95+ def _to_tensor (self , rate : float | Tensor ) -> Tensor :
8696 """Convert the input parameters into tensors.
8797
8898 Returns:
@@ -92,11 +102,11 @@ def _to_tensor(self, rate):
92102 if isinstance (rate , (float , int )):
93103 rate = paddle .to_tensor ([rate ], dtype = self .dtype )
94104 else :
95- self .dtype = rate .dtype
105+ self .dtype = convert_dtype ( rate .dtype )
96106 return rate
97107
98108 @property
99- def mean (self ):
109+ def mean (self ) -> Tensor :
100110 """Mean of poisson distribution.
101111
102112 Returns:
@@ -105,15 +115,15 @@ def mean(self):
105115 return self .rate
106116
107117 @property
108- def variance (self ):
118+ def variance (self ) -> Tensor :
109119 """Variance of poisson distribution.
110120
111121 Returns:
112122 Tensor: variance value.
113123 """
114124 return self .rate
115125
116- def sample (self , shape = ()):
126+ def sample (self , shape : Sequence [ int ] = ()) -> Tensor :
117127 """Generate poisson samples of the specified shape. The final shape would be ``shape+batch_shape`` .
118128
119129 Args:
@@ -133,7 +143,7 @@ def sample(self, shape=()):
133143 with paddle .no_grad ():
134144 return paddle .poisson (output_rate )
135145
136- def entropy (self ):
146+ def entropy (self ) -> Tensor :
137147 r"""Shannon entropy in nats.
138148
139149 The entropy is
@@ -162,7 +172,7 @@ def entropy(self):
162172 )
163173 return paddle .multiply (proposed , mask )
164174
165- def _enumerate_bounded_support (self , rate ) :
175+ def _enumerate_bounded_support (self , rate : float | Tensor ) -> Tensor :
166176 """Generate a bounded approximation of the support. Approximately view Poisson r.v. as a
167177 Normal r.v. with mu = rate and sigma = sqrt(rate). Then by 30-sigma rule, generate a bounded
168178 approximation of the support.
@@ -203,7 +213,7 @@ def false_func():
203213 values = paddle .arange (0 , upper , dtype = self .dtype )
204214 return values
205215
206- def log_prob (self , value ) :
216+ def log_prob (self , value : Tensor ) -> Tensor :
207217 """Log probability density/mass function.
208218
209219 Args:
@@ -223,7 +233,7 @@ def log_prob(self, value):
223233 neginf = - eps ,
224234 )
225235
226- def prob (self , value ) :
236+ def prob (self , value : Tensor ) -> Tensor :
227237 """Probability density/mass function.
228238
229239 Args:
@@ -234,7 +244,7 @@ def prob(self, value):
234244 """
235245 return paddle .exp (self .log_prob (value ))
236246
237- def kl_divergence (self , other ) :
247+ def kl_divergence (self , other : Poisson ) -> Tensor :
238248 r"""The KL-divergence between two poisson distributions with the same `batch_shape`.
239249
240250 The probability density function (pdf) is
0 commit comments