Skip to content

Commit 0bbbaf4

Browse files
enkileeSigureMo
authored andcommitted
[Typing][B-24] Add type annotations for python/paddle/distribution/poisson.py (PaddlePaddle#65852)
--------- Co-authored-by: SigureMo <[email protected]>
1 parent c00e827 commit 0bbbaf4

File tree

1 file changed

+21
-11
lines changed

1 file changed

+21
-11
lines changed

python/paddle/distribution/poisson.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,19 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from __future__ import annotations
1415

1516
from collections.abc import Sequence
17+
from typing import TYPE_CHECKING
1618

1719
import paddle
20+
from paddle.base.data_feeder import convert_dtype
1821
from paddle.distribution import distribution
1922

23+
if TYPE_CHECKING:
24+
from paddle import Tensor
25+
from paddle._typing.dtype_like import _DTypeLiteral
26+
2027

2128
class Poisson(distribution.Distribution):
2229
r"""
@@ -72,7 +79,10 @@ class Poisson(distribution.Distribution):
7279
[0.06825157 , 1.53426421 ]])
7380
"""
7481

75-
def __init__(self, rate):
82+
rate: Tensor
83+
dtype: _DTypeLiteral
84+
85+
def __init__(self, rate: float | Tensor) -> None:
7686
self.dtype = paddle.get_default_dtype()
7787
self.rate = self._to_tensor(rate)
7888

@@ -82,7 +92,7 @@ def __init__(self, rate):
8292
batch_shape = self.rate.shape
8393
super().__init__(batch_shape)
8494

85-
def _to_tensor(self, rate):
95+
def _to_tensor(self, rate: float | Tensor) -> Tensor:
8696
"""Convert the input parameters into tensors.
8797
8898
Returns:
@@ -92,11 +102,11 @@ def _to_tensor(self, rate):
92102
if isinstance(rate, (float, int)):
93103
rate = paddle.to_tensor([rate], dtype=self.dtype)
94104
else:
95-
self.dtype = rate.dtype
105+
self.dtype = convert_dtype(rate.dtype)
96106
return rate
97107

98108
@property
99-
def mean(self):
109+
def mean(self) -> Tensor:
100110
"""Mean of poisson distribution.
101111
102112
Returns:
@@ -105,15 +115,15 @@ def mean(self):
105115
return self.rate
106116

107117
@property
108-
def variance(self):
118+
def variance(self) -> Tensor:
109119
"""Variance of poisson distribution.
110120
111121
Returns:
112122
Tensor: variance value.
113123
"""
114124
return self.rate
115125

116-
def sample(self, shape=()):
126+
def sample(self, shape: Sequence[int] = ()) -> Tensor:
117127
"""Generate poisson samples of the specified shape. The final shape would be ``shape+batch_shape`` .
118128
119129
Args:
@@ -133,7 +143,7 @@ def sample(self, shape=()):
133143
with paddle.no_grad():
134144
return paddle.poisson(output_rate)
135145

136-
def entropy(self):
146+
def entropy(self) -> Tensor:
137147
r"""Shannon entropy in nats.
138148
139149
The entropy is
@@ -162,7 +172,7 @@ def entropy(self):
162172
)
163173
return paddle.multiply(proposed, mask)
164174

165-
def _enumerate_bounded_support(self, rate):
175+
def _enumerate_bounded_support(self, rate: float | Tensor) -> Tensor:
166176
"""Generate a bounded approximation of the support. Approximately view Poisson r.v. as a
167177
Normal r.v. with mu = rate and sigma = sqrt(rate). Then by 30-sigma rule, generate a bounded
168178
approximation of the support.
@@ -203,7 +213,7 @@ def false_func():
203213
values = paddle.arange(0, upper, dtype=self.dtype)
204214
return values
205215

206-
def log_prob(self, value):
216+
def log_prob(self, value: Tensor) -> Tensor:
207217
"""Log probability density/mass function.
208218
209219
Args:
@@ -223,7 +233,7 @@ def log_prob(self, value):
223233
neginf=-eps,
224234
)
225235

226-
def prob(self, value):
236+
def prob(self, value: Tensor) -> Tensor:
227237
"""Probability density/mass function.
228238
229239
Args:
@@ -234,7 +244,7 @@ def prob(self, value):
234244
"""
235245
return paddle.exp(self.log_prob(value))
236246

237-
def kl_divergence(self, other):
247+
def kl_divergence(self, other: Poisson) -> Tensor:
238248
r"""The KL-divergence between two poisson distributions with the same `batch_shape`.
239249
240250
The probability density function (pdf) is

0 commit comments

Comments
 (0)