Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 22 additions & 13 deletions python/paddle/distribution/exponential.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING, Sequence

import numpy as np

Expand All @@ -22,6 +25,9 @@
from paddle.distribution import exponential_family
from paddle.framework import in_dynamic_mode

if TYPE_CHECKING:
from paddle import Tensor, dtype


class Exponential(exponential_family.ExponentialFamily):
r"""
Expand Down Expand Up @@ -59,7 +65,10 @@ class Exponential(exponential_family.ExponentialFamily):
[1.69314718])
"""

def __init__(self, rate):
rate: Tensor
dtype: dtype

def __init__(self, rate: float | Tensor) -> None:
if not in_dynamic_mode():
check_type(
rate,
Expand All @@ -79,7 +88,7 @@ def __init__(self, rate):
super().__init__(self.rate.shape)

@property
def mean(self):
def mean(self) -> Tensor:
"""Mean of exponential distribution.

Returns:
Expand All @@ -88,15 +97,15 @@ def mean(self):
return self.rate.reciprocal()

@property
def variance(self):
def variance(self) -> Tensor:
"""Variance of exponential distribution.

Returns:
Tensor: variance value.
"""
return self.rate.pow(-2)

def sample(self, shape=()):
def sample(self, shape: Sequence[int] = ()) -> Tensor:
"""Generate samples of the specified shape.

Args:
Expand All @@ -108,7 +117,7 @@ def sample(self, shape=()):
with paddle.no_grad():
return self.rsample(shape)

def rsample(self, shape=()):
def rsample(self, shape: Sequence[int] = ()) -> Tensor:
"""Generate reparameterized samples of the specified shape.

Args:
Expand All @@ -130,7 +139,7 @@ def rsample(self, shape=()):

return -paddle.log(uniform) / self.rate

def prob(self, value):
def prob(self, value: float | Tensor) -> Tensor:
r"""Probability density function evaluated at value.

.. math::
Expand All @@ -145,7 +154,7 @@ def prob(self, value):
"""
return self.rate * paddle.exp(-self.rate * value)

def log_prob(self, value):
def log_prob(self, value: float | Tensor) -> Tensor:
"""Log probability density function evaluated at value.

Args:
Expand All @@ -156,15 +165,15 @@ def log_prob(self, value):
"""
return paddle.log(self.rate) - self.rate * value

def entropy(self):
def entropy(self) -> Tensor:
"""Entropy of exponential distribution.

Returns:
Tensor: Entropy.
"""
return 1.0 - paddle.log(self.rate)

def cdf(self, value):
def cdf(self, value: float | Tensor) -> Tensor:
r"""Cumulative distribution function(CDF) evaluated at value.

.. math::
Expand All @@ -180,7 +189,7 @@ def cdf(self, value):
"""
return 1.0 - paddle.exp(-self.rate * value)

def icdf(self, value):
def icdf(self, value: float | Tensor) -> Tensor:
r"""Inverse cumulative distribution function(CDF) evaluated at value.

.. math::
Expand All @@ -196,7 +205,7 @@ def icdf(self, value):
"""
return -paddle.log1p(-value) / self.rate

def kl_divergence(self, other):
def kl_divergence(self, other: Exponential) -> Tensor:
"""The KL-divergence between two exponential distributions.

Args:
Expand All @@ -215,8 +224,8 @@ def kl_divergence(self, other):
return t1 + rate_ratio - 1

@property
def _natural_parameters(self):
def _natural_parameters(self) -> tuple[Tensor]:
return (-self.rate,)

def _log_normalizer(self, x):
def _log_normalizer(self, x: Tensor) -> Tensor:
return -paddle.log(-x)