1 change: 1 addition & 0 deletions README.rst
@@ -475,3 +475,4 @@ There are a few tools to clean and check the source:
- :bash:`mypy`
- :bash:`isort .`
- :bash:`pylint efax tests`

4 changes: 4 additions & 0 deletions efax/__init__.py
@@ -17,6 +17,8 @@
from ._src.distributions.geometric import GeometricEP, GeometricNP
from ._src.distributions.inverse_gamma import InverseGammaEP, InverseGammaNP
from ._src.distributions.inverse_gaussian import InverseGaussianEP, InverseGaussianNP
from ._src.distributions.generalized_inverse_gaussian import (GeneralizedInverseGaussianEP,
GeneralizedInverseGaussianNP)
from ._src.distributions.log_normal.log_normal import LogNormalEP, LogNormalNP
from ._src.distributions.log_normal.unit_variance import (UnitVarianceLogNormalEP,
UnitVarianceLogNormalNP)
@@ -116,6 +118,8 @@
'InverseGammaNP',
'InverseGaussianEP',
'InverseGaussianNP',
'GeneralizedInverseGaussianEP',
'GeneralizedInverseGaussianNP',
'IsotropicNormalEP',
'IsotropicNormalNP',
'JointDistribution',
227 changes: 227 additions & 0 deletions efax/_src/distributions/generalized_inverse_gaussian.py
@@ -0,0 +1,227 @@
from __future__ import annotations

import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.random as jr
from array_api_compat import array_namespace
from tjax import JaxArray, JaxRealArray, KeyArray, Shape
from tjax.dataclasses import dataclass
from typing_extensions import Tuple, override
# from logbesselk.jax import log_bessel_k, bessel_kratio
from ..expectation_parametrization import ExpectationParametrization
from ..interfaces.samplable import Samplable
from ..mixins.exp_to_nat.exp_to_nat import ExpToNat
from ..mixins.has_entropy import HasEntropyEP, HasEntropyNP
from ..natural_parametrization import NaturalParametrization
from ..parameter import (RealField, ScalarSupport, Support, distribution_parameter,
negative_support, positive_support)
from ..parametrization import SimpleDistribution
from ..tools import log_kve


@dataclass
class GeneralizedInverseGaussianNP(HasEntropyNP['GeneralizedInverseGaussianEP'],
Samplable,
NaturalParametrization['GeneralizedInverseGaussianEP', JaxRealArray],
SimpleDistribution):
"""The natural parametrization of the generalized inverse Gaussian distribution.

Args:
p_minus_one: p - 1, where p is the shape parameter
negative_a_over_two: -a/2, where a is the first scale parameter
negative_b_over_two: -b/2, where b is the second scale parameter
"""
p_minus_one: JaxRealArray = distribution_parameter(ScalarSupport())
negative_a_over_two: JaxRealArray = distribution_parameter(ScalarSupport(ring=negative_support))
negative_b_over_two: JaxRealArray = distribution_parameter(ScalarSupport(ring=negative_support))

@property
@override
def shape(self) -> Shape:
return self.p_minus_one.shape

@override
@classmethod
def domain_support(cls) -> ScalarSupport:
return ScalarSupport(ring=positive_support)

@override
def log_normalizer(self) -> JaxRealArray:
xp = self.array_namespace()
# Convert natural parameters to original parameters
p = self.p_minus_one + 1.0
a = -2.0 * self.negative_a_over_two
b = -2.0 * self.negative_b_over_two

y = xp.sqrt(b / a)
z = xp.sqrt(a * b)

# if p.ndim == 0:
# log_2k = xp.log(2.0) + log_kve(p, z)
# else:
# log_k = jax.jit(jax.vmap(log_bessel_k, 0))
# log_2k = xp.log(2.0) + log_k(p, z)
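# log Z = log(2 * K_p(sqrt(a * b))) + (p / 2) * log(b / a).  log_kve is the
# exponentially scaled log(K_p(z)) + z, hence the "- z" correction below.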
return xp.log(2.0) + log_kve(p, z) - z + p * xp.log(y)

@override
def to_exp(self) -> GeneralizedInverseGaussianEP:
"""Convert from natural to expectation parameters."""
xp = self.array_namespace()

# Convert natural parameters to original parameters
p = self.p_minus_one + 1.0
a = -2.0 * self.negative_a_over_two
b = -2.0 * self.negative_b_over_two

y = xp.sqrt(b / a)
z = xp.sqrt(a * b)

# TBD: Use naive finite difference for dlogk_dp for now
# as jax.grad does not work with log_bessel_k function
# Issue: https://github.com/tk2lab/logbesselk/issues/33
# mean_log = xp.log(y) + dlogk_dp(p, z)
# eps = 1e-6 would cause numerical issue in log_bessel_k
# if p.ndim == 0:
# eps = 1e-5
# kratio = xp.exp(log_bessel_k(p + 1, z) - log_bessel_k(p, z))
# dlogk_dp = (log_bessel_k(p + eps, z) - log_bessel_k(p - eps, z)) / (2.0 * eps)
# else:
# eps = xp.ones_like(p) * 1e-5
# logk = jax.jit(jax.vmap(log_bessel_k, 0))
# kratio = xp.exp(logk(p + xp.ones_like(p), z) - logk(p, z))
# dlogk_dp = (logk(p + eps, z) - logk(p - eps, z)) / (2.0 * eps)

eps = 1e-10
dlogk_dp = (log_kve(p + eps, z) - log_kve(p - eps, z)) / (2.0 * eps)
kratio = xp.exp(log_kve(p + 1, z) - log_kve(p, z))

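# Standard GIG moments with z = sqrt(a * b) and y = sqrt(b / a):
#   E[ln x] = log(y) + d/dp log K_p(z)
#   E[x]    = y * K_{p+1}(z) / K_p(z)
#   E[1/x]  = K_{p+1}(z) / (y * K_p(z)) - 2 * p / b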
mean_log = xp.log(y) + dlogk_dp
mean = kratio * y
mean_inv = (1.0 / y) * kratio - 2.0 * p / b
return GeneralizedInverseGaussianEP(mean_log, mean, mean_inv)

@override
def carrier_measure(self, x: JaxRealArray) -> JaxRealArray:
"""The carrier measure for GIG is h(x) = 1, so log(h(x)) = 0."""
xp = self.array_namespace(x)
return xp.zeros(x.shape)

@override
@classmethod
def sufficient_statistics(cls, x: JaxRealArray, **fixed_parameters: JaxArray
) -> GeneralizedInverseGaussianEP:
"""Compute the sufficient statistics T(x) = (ln x, x, 1/x)."""
xp = array_namespace(x)
log_x = xp.log(x)
inv_x = xp.reciprocal(x)
return GeneralizedInverseGaussianEP(log_x, x, inv_x)

@override
def sample(self, key: KeyArray, shape: Shape | None = None, burnin: int = 100) -> JaxRealArray:
"""Sample from the Generalized Inverse Gaussian distribution using Gibbs sampling.

References:
Peña, V., & Jauch, M. (2025). Properties of the generalized inverse Gaussian with applications to Monte Carlo simulation and distribution function evaluation. Statistics & Probability Letters, 110359.
https://github.com/michaeljauch/gig/blob/main/GIG_Gibbs_Benchmark.R
"""
xp = self.array_namespace()

# Convert natural parameters to original parameters
p = self.p_minus_one + 1.0
a = -2.0 * self.negative_a_over_two
b = -2.0 * self.negative_b_over_two

# Set up shape dimensions
full_shape = self.shape if shape is None else shape + self.shape

def gibbs_update(carry: Tuple[JaxRealArray, KeyArray], _):
x_prev, subkey = carry
subkey, gamma_key, ig_key = jr.split(subkey, 3)
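# Two-block Gibbs step (Peña & Jauch, 2025): draw an auxiliary gamma variate
# conditioned on the previous x, then draw x from an inverse Gaussian whose
# parameters depend on that variate.  Both branches below are computed and the
# appropriate one is selected with xp.where, keeping the update scan/jit friendly.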

# For p < -0.5 case
alpha_neg = -(p + 0.5)
beta_neg = x_prev
y_neg = jr.gamma(gamma_key, alpha_neg, shape=x_prev.shape)
y_neg = y_neg / beta_neg
mu_neg = xp.sqrt(b / (a + 2.0 * y_neg))
lambda_neg = b

# For p >= -0.5 case
alpha_pos = p + 0.5
beta_pos = 1.0 / x_prev
y_pos = jr.gamma(gamma_key, alpha_pos, shape=x_prev.shape)
y_pos = y_pos / beta_pos
b2y_pos = b + 2.0 * y_pos
mu_pos = xp.sqrt(b2y_pos / a)
lambda_pos = b2y_pos

# Select parameters based on p value
mu = xp.where(p < -0.5, mu_neg, mu_pos)
lambda_ = xp.where(p < -0.5, lambda_neg, lambda_pos)

# Sample from inverse Gaussian using selected parameters
# Based on InverseGaussianNP.sample implementation
nu = jr.normal(ig_key, x_prev.shape)
u = jr.uniform(jr.split(ig_key)[0], x_prev.shape)
y = xp.square(nu)
x_ig = mu + 0.5 * xp.square(mu) / lambda_ * y - (
mu / (2.0 * lambda_) * xp.sqrt(4.0 * mu * lambda_ * y +
xp.square(mu) * xp.square(y)))

# Apply the acceptance criterion
x_new = xp.where(u <= mu / (mu + x_ig), x_ig, xp.square(mu) / x_ig)
return (x_new, subkey), x_new

# Run burn-in iterations
carry = (xp.ones(self.shape), key)
(x, key), _ = lax.scan(gibbs_update, carry, xp.arange(burnin))

# Run actual sampling iterations
num_samples = jnp.prod(jnp.array(shape)) if shape else 1
_, samples = lax.scan(gibbs_update, (x, key), xp.arange(num_samples))
return samples.reshape(full_shape)


@dataclass
class GeneralizedInverseGaussianEP(HasEntropyEP[GeneralizedInverseGaussianNP],
ExpToNat[GeneralizedInverseGaussianNP],
ExpectationParametrization[GeneralizedInverseGaussianNP],
Samplable,
SimpleDistribution):
"""The expectation parametrization of the generalized inverse Gaussian distribution.

Args:
mean_log: E[ln x]
mean: E[x]
mean_inv: E[1/x]
"""
mean_log: JaxRealArray = distribution_parameter(ScalarSupport())
mean: JaxRealArray = distribution_parameter(ScalarSupport(ring=positive_support))
mean_inv: JaxRealArray = distribution_parameter(ScalarSupport(ring=positive_support))

@property
@override
def shape(self) -> Shape:
return self.mean.shape

@override
@classmethod
def domain_support(cls) -> ScalarSupport:
return ScalarSupport(ring=positive_support)

@classmethod
@override
def natural_parametrization_cls(cls) -> type[GeneralizedInverseGaussianNP]:
return GeneralizedInverseGaussianNP

@override
def expected_carrier_measure(self) -> JaxRealArray:
xp = self.array_namespace()
return xp.zeros(self.shape)

@override
def sample(self, key: KeyArray, shape: Shape | None = None) -> JaxRealArray:
return self.to_nat().sample(key, shape)


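A minimal usage sketch for the new distribution (parameter values, PRNG seed, and sample size are illustrative and not part of the PR; the class name follows the export added to efax/__init__.py above):

import jax.numpy as jnp
import jax.random as jr
from efax import GeneralizedInverseGaussianNP

# GIG(p=2, a=3, b=4) written in the natural coordinates (p - 1, -a/2, -b/2).
q = GeneralizedInverseGaussianNP(p_minus_one=jnp.asarray(1.0),
                                 negative_a_over_two=jnp.asarray(-1.5),
                                 negative_b_over_two=jnp.asarray(-2.0))
ep = q.to_exp()  # expectation parameters (E[ln x], E[x], E[1/x])
samples = q.sample(jr.PRNGKey(0), (10_000,))
print(ep.mean, samples.mean())  # the sample mean should roughly match E[x]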
27 changes: 26 additions & 1 deletion efax/_src/tools.py
@@ -7,7 +7,7 @@
from typing import TYPE_CHECKING, Any, TypeVar

from array_api_compat import array_namespace
from jax import jit
from jax import custom_jvp, grad, jit
from tensorflow_probability.substrates import jax as tfp
from tjax import JaxComplexArray, JaxRealArray

@@ -85,6 +85,31 @@ def join_mappings(**field_to_map: Mapping[_T, _V]) -> dict[_T, dict[str, _V]]:
log_ive = tfp.math.log_bessel_ive


# TensorFlow Probability's log_bessel_kve is not differentiable with respect to v,
# so a custom_jvp rule is used to supply that derivative.
@custom_jvp
def log_kve(v: JaxRealArray, z: JaxRealArray) -> JaxRealArray:
return tfp.math.log_bessel_kve(v, z)

# Define the custom JVP (forward-mode) rule
@log_kve.defjvp
def log_kve_jvp(primals: tuple[JaxRealArray, JaxRealArray], tangents: tuple[JaxRealArray, JaxRealArray]) -> tuple[JaxRealArray, JaxRealArray]:
v, z = primals
v_dot, z_dot = tangents

# Use finite difference for the v derivative
eps = 1e-10
grad_v = (tfp.math.log_bessel_kve(v + eps, z) - tfp.math.log_bessel_kve(v - eps, z)) / (2.0 * eps)

# Use automatic differentiation for the z derivative
grad_z = grad(lambda z: tfp.math.log_bessel_kve(v, z))(z)

# Compute the tangent
tangent_out = grad_v * v_dot + grad_z * z_dot

return log_kve(v, z), tangent_out


# Private functions --------------------------------------------------------------------------------
def _parameter_dot_product(x: JaxComplexArray, y: JaxComplexArray, n_axes: int) -> JaxRealArray:
"""Returns the real component of the dot product of the final n_axes axes of two arrays."""
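A quick differentiation check for the new helper (a sketch assuming log_kve stays importable from efax._src.tools as in the diff; the test values are illustrative):

import jax
import jax.numpy as jnp
from efax._src.tools import log_kve

v = jnp.asarray(1.5)
z = jnp.asarray(2.0)
# d/dv comes from the finite difference in the custom JVP rule; d/dz is obtained by
# autodiff through tfp.math.log_bessel_kve.  Both should now be finite (not nan).
print(jax.grad(log_kve, argnums=0)(v, z), jax.grad(log_kve, argnums=1)(v, z))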
2 changes: 1 addition & 1 deletion examples/maximum_likelihood_estimation.py
@@ -9,9 +9,9 @@
"""
import jax.numpy as jnp
import jax.random as jr
from tjax import print_generic

from efax import DirichletEP, DirichletNP, MaximumLikelihoodEstimator, parameter_mean
from tjax import print_generic

# Consider a Dirichlet distribution with a given alpha.
alpha = jnp.asarray([2.0, 3.0, 4.0])
3 changes: 1 addition & 2 deletions pyproject.toml
@@ -28,7 +28,7 @@ classifiers = [
dependencies = [
"array_api_compat>=1.10",
"array_api_extra>=0.7",
"jax>=0.4.34,<0.6",
"jax>=0.4.34",
"numpy>=1.25",
"optimistix>=0.0.9",
"optype>=0.8.0",
@@ -47,7 +47,6 @@ dev = [
"pre-commit>=4",
"pylint>=3.3",
"pyright>=0.0.13",
"pytest-ordering",
"pytest-xdist[psutil]>=3",
"pytest>=8",
"ruff>=0.9.10",
9 changes: 5 additions & 4 deletions tests/conftest.py
@@ -15,8 +15,9 @@
HasGeneralizedConjugatePrior, IntegralRing, JointDistribution, Samplable,
Structure)

from .create_info import (BetaInfo, ChiSquareInfo, DirichletInfo, GammaInfo,
GeneralizedDirichletInfo, InverseGammaInfo, JointInfo, create_infos)
from .create_info import (BetaInfo, ChiSquareInfo, ComplexCircularlySymmetricNormalInfo,
DirichletInfo, GammaInfo, GeneralizedDirichletInfo, InverseGammaInfo,
InverseGaussianInfo, JointInfo, create_infos)


@pytest.fixture(autouse=True, scope='session')
@@ -100,8 +101,8 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
Samplable)
if not any_integral_supports(structure)
if not isinstance(info,
BetaInfo | ChiSquareInfo | DirichletInfo | GammaInfo
| InverseGammaInfo | JointInfo)
ComplexCircularlySymmetricNormalInfo | BetaInfo | DirichletInfo
| ChiSquareInfo | GammaInfo | InverseGammaInfo | JointInfo)
if info.tests_selected(distribution_name_option)]
ids = [f"{info.name()}{'NP' if natural else 'EP'}" for info, natural in p]
metafunc.parametrize(("sampling_wc_distribution_info", "natural"), p,
22 changes: 21 additions & 1 deletion tests/create_info.py
@@ -16,7 +16,8 @@
ComplexNormalEP, ComplexNormalNP, ComplexUnitVarianceNormalEP,
ComplexUnitVarianceNormalNP, DirichletEP, DirichletNP, ExponentialEP,
ExponentialNP, GammaEP, GammaNP, GeneralizedDirichletEP, GeneralizedDirichletNP,
GeometricEP, GeometricNP, InverseGammaEP, InverseGammaNP, InverseGaussianEP,
GeometricEP, GeometricNP, GeneralizedInverseGaussianEP, GeneralizedInverseGaussianNP,
InverseGammaEP, InverseGammaNP, InverseGaussianEP,
InverseGaussianNP, IsotropicNormalEP, IsotropicNormalNP, JointDistributionE,
JointDistributionN, LogarithmicEP, LogarithmicNP, LogNormalEP, LogNormalNP,
MultivariateDiagonalNormalEP, MultivariateDiagonalNormalNP,
@@ -236,6 +237,24 @@ def nat_class(self) -> type[GeneralizedDirichletNP]:
return GeneralizedDirichletNP


class GeneralizedInverseGaussianInfo(DistributionInfo[GeneralizedInverseGaussianNP, GeneralizedInverseGaussianEP, NumpyRealArray]):
@override
def nat_to_scipy_distribution(self, q: GeneralizedInverseGaussianNP) -> Any:
# Convert natural parameters (p - 1, -a/2, -b/2) back to (p, a, b).  scipy's
# geninvgauss(p, b, scale=s) has density proportional to
# x**(p - 1) * exp(-b * (x / s + s / x) / 2), which matches GIG(p, a, b) when
# scipy's b is sqrt(a * b) and its scale is sqrt(b / a).
p = q.p_minus_one + 1.0
a = -2.0 * q.negative_a_over_two
b = -2.0 * q.negative_b_over_two
return ss.geninvgauss(p=p, b=(a * b) ** 0.5, scale=(b / a) ** 0.5)

@override
def exp_class(self) -> type[GeneralizedInverseGaussianEP]:
return GeneralizedInverseGaussianEP

@override
def nat_class(self) -> type[GeneralizedInverseGaussianNP]:
return GeneralizedInverseGaussianNP


class GeometricInfo(DistributionInfo[GeometricNP, GeometricEP, NumpyRealArray]):
@override
def exp_to_scipy_distribution(self, p: GeometricEP) -> Any:
@@ -652,6 +671,7 @@ def create_infos() -> list[DistributionInfo[Any, Any, Any]]:
ExponentialInfo(),
GammaInfo(),
GeneralizedDirichletInfo(dimensions=5),
GeneralizedInverseGaussianInfo(),
GeometricInfo(),
InverseGammaInfo(),
InverseGaussianInfo(),
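A hedged sanity check for the scipy mapping used in GeneralizedInverseGaussianInfo (the manual log-density below just expands theta . T(x) - log Z from the natural parametrization; the values are illustrative):

import jax.numpy as jnp
import scipy.stats as ss
from efax import GeneralizedInverseGaussianNP

# GIG(p=0.5, a=2, b=3) in natural coordinates (p - 1, -a/2, -b/2).
q = GeneralizedInverseGaussianNP(jnp.asarray(-0.5), jnp.asarray(-1.0), jnp.asarray(-1.5))
x = 1.3
log_pdf = float(q.p_minus_one * jnp.log(x) + q.negative_a_over_two * x
                + q.negative_b_over_two / x - q.log_normalizer())
scipy_log_pdf = ss.geninvgauss(p=0.5, b=(2.0 * 3.0) ** 0.5,
                               scale=(3.0 / 2.0) ** 0.5).logpdf(x)
print(log_pdf, scipy_log_pdf)  # the two values should agree up to floating-point error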