Conversation

@xshi19 (Contributor) commented May 4, 2025

Hi Neil,

I am trying to add the generalized inverse Gaussian distribution. However, there is an issue with the log of the modified Bessel function of the second kind, which appears in the log normalizer and in the to_exp method: https://en.wikipedia.org/wiki/Generalized_inverse_Gaussian_distribution
There is no JAX implementation of this function (jax-ml/jax#17038), so I use this package: https://github.com/tk2lab/logbesselk
However, it prevents me from using jvp in the log normalizer. In addition, it does not support jax.grad (tk2lab/logbesselk#33), so in the current version I simply use the finite difference

dlogk_dp = (logk(p + eps, z) - logk(p - eps, z)) / (2.0 * eps)

which is not accurate.
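Concretely, the workaround looks roughly like this (a sketch rather than the exact code in this PR; the import path is my best guess at logbesselk's JAX API):

# Wrap the external log K_v(z) implementation with jax.custom_jvp so that the
# Newton iterations in ExpToNat can differentiate the log normalizer. Both
# partial derivatives are central differences, which is the source of the
# inaccuracy.
import jax
from logbesselk.jax import log_bessel_k as logk  # assumed import path

@jax.custom_jvp
def log_k(p, z):
    return logk(p, z)

@log_k.defjvp
def _log_k_jvp(primals, tangents):
    p, z = primals
    p_dot, z_dot = tangents
    eps = 1e-6
    out = logk(p, z)
    dlogk_dp = (logk(p + eps, z) - logk(p - eps, z)) / (2.0 * eps)
    dlogk_dz = (logk(p, z + eps) - logk(p, z - eps)) / (2.0 * eps)
    return out, dlogk_dp * p_dot + dlogk_dz * z_dot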
Please let me know if you have any suggestions on this issue. TensorFlow has an implementation of the log Bessel function (https://www.tensorflow.org/probability/api_docs/python/tfp/math/log_bessel_kve), but I am not sure whether there is a way to use it in the current framework.
Besides that, the pdf and sampling should work well.
[image attached]

NeilGirdhar and others added 18 commits April 16, 2025 00:22
Use conditional expression with shape

Correct typo in test

Update RealField to work with arrays

Add the inverse Gaussian distribution

Bump

Split structure classes into 3 files

Polish parameter_supports mechanism

Factor out triangular number functions

Skip sampling test for inverse Gaussian

Correct issue with scalar support bound shapes

Split parameter package

Correct docstrings

Bump

Eliminate Jax initialization on import

Update examples

Bump

change sample using NP directly

resolve conflicts

Remove commented-out code related to Inverse Gaussian in pytest_generate_tests

resolve conflicts

resolve conflicts
@NeilGirdhar (Owner) commented

I am trying to add the generalized inverse Gaussian distribution.

Awesome!! Very cool.

Please let me know if you have any suggestions on this issue.

Have you looked in TensorFlow Probability? I don't know much about Bessel functions, but is this close to what you're looking for? If not, I suggest you request it there. Also, this issue is worth taking a look at, and maybe asking there.

What do you think?

@xshi19 (Contributor, Author) commented May 5, 2025

Yes, tfp.math.bessel_kve is the one I am looking for. But can I use it directly in the log normalizer without breaking ExpToNat? I am new to JAX, so I am not sure how to incorporate it into the JAX framework.

@NeilGirdhar (Owner) commented

But can I use it directly in the log normalizer without breaking ExpToNat?

Yes, I think so. You can see how I imported other Bessel functions into _src.tools.
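Something like this should be enough (an untested sketch, assuming TFP's JAX substrate exposes log_bessel_kve, which I believe it does):

# Re-export TFP's exponentially scaled log K_v through the JAX substrate, next
# to the other Bessel functions in _src/tools.py. Note that kve is the scaled
# function: log K_v(z) = log_bessel_kve(v, z) - z.
from tensorflow_probability.substrates import jax as tfp

def log_kve(v, z):
    return tfp.math.log_bessel_kve(v, z)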

@xshi19 (Contributor, Author) commented May 7, 2025

Added log_kve to _src.tools.
However, test_distributions.py::test_conversion still fails, probably because:

  • the natural parameters negative_a_over_two and negative_b_over_two must be negative;
  • the Newton method used in ExpToNat does not support constraints;
  • the initial search parameters (initial_search_parameters) are zeros.

I wonder if there is any plan to:

  • implement a constrained optimizer;
  • make initial_search_parameters customizable for different distributions; in the case of GIG, we could use IG's to_nat to get the initial parameters (see the sketch at the end of this comment).

The other tests in test_distributions.py pass.
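To make the second point concrete, here is the kind of initialization I have in mind (a hypothetical sketch, not code in this PR; it uses the identity IG(mu, lam) = GIG(p=-1/2, a=lam/mu**2, b=lam) together with the IG maximum-likelihood estimates mu = mean, 1/lam = mean_inv - 1/mean, and the result would still have to be mapped into the flattened search space that ExpToNat actually uses):

import jax.numpy as jnp

def gig_initial_natural_parameters(mean, mean_inv):
    # Fit a plain inverse Gaussian to the mean statistics and express it in the
    # GIG natural parametrization used in this PR.
    mu = mean
    lam = 1.0 / (mean_inv - 1.0 / mean)   # positive by Jensen's inequality
    p_minus_one = jnp.asarray(-1.5)       # p = -1/2 for the inverse Gaussian
    negative_a_over_two = -lam / (2.0 * mu ** 2)
    negative_b_over_two = -lam / 2.0
    return p_minus_one, negative_a_over_two, negative_b_over_two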

@NeilGirdhar (Owner) commented May 7, 2025

the Newton method used in ExpToNat does not support constraints

Instead of adding constraints, the trick is to ensure that the flattened parametrization covers the entire plane. It may be a bit hard to understand, but the beta distribution also has constrained natural parameters, yet it has no problem with ExpToNat because its natural parameters have support over a constrained ring. This way, ExpToNat uses RealField.flattened(map_to_plane=True), which is unconstrained.

I see you used this in your distribution, so I wonder why it's not working. I can look at it later this week if you get stuck. Just let me know.
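For intuition, here is a toy version of the idea using the exponential distribution, whose natural parameter must be negative (just an illustration, not efax's actual RealField.flattened(map_to_plane=True) implementation):

import jax
import jax.numpy as jnp

def exp_to_nat_objective(y, mean):
    # Newton's method searches over the unconstrained variable y; the map
    # theta = -exp(y) puts the whole real line onto theta < 0, so the solver
    # never sees the constraint.
    theta = -jnp.exp(y)
    log_normalizer = -jnp.log(-theta)
    return log_normalizer - theta * mean   # minimized at theta = -1/mean

def newton_solve(mean, y=jnp.array(0.0), steps=20):
    grad = jax.grad(exp_to_nat_objective)
    hess = jax.grad(grad)
    for _ in range(steps):
        y = y - grad(y, mean) / hess(y, mean)
    return -jnp.exp(y)

print(newton_solve(jnp.array(2.0)))  # approx -0.5, i.e. theta = -1/mean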

@NeilGirdhar (Owner) commented

FYI, these are the errors I get when running this PR:

FAILED tests/test_distributions.py::test_conversion[GeneralizedInverseGaussian] - Failed: Conversion failure
FAILED tests/test_entropy_gradient.py::test_nat_entropy_gradient[GeneralizedInverseGaussian] - TypeError: Gradient only defined for scalar-output functions. Output had shape: (7, 13).
FAILED tests/test_entropy_gradient.py::test_exp_entropy_gradient[GeneralizedInverseGaussian] - Failed: Non-finite gradient found for distributions: list
FAILED tests/test_hessian.py::test_sampling_cotangents[GeneralizedInverseGaussianEP] - NotImplementedError: Differentiation rule for 'random_gamma_grad' not implemented
FAILED tests/test_hessian.py::test_sampling_cotangents[GeneralizedInverseGaussianNP] - TypeError: GeneralizedInverseGaussianNP.sample() missing 1 required positional argument: 'shape'
FAILED tests/test_match_scipy.py::test_nat_entropy[GeneralizedInverseGaussian] - AssertionError: 
FAILED tests/test_match_scipy.py::test_exp_entropy[GeneralizedInverseGaussian] - AssertionError: 
FAILED tests/test_match_scipy.py::test_pdf[GeneralizedInverseGaussian] - AssertionError: 
FAILED tests/test_match_scipy.py::test_maximum_likelihood_estimation[GeneralizedInverseGaussian] - AssertionError: 
FAILED tests/test_sampling.py::test_sampling_and_estimation[GeneralizedInverseGaussianEP] - jax.errors.KeyReuseError: In pjit, argument 0 is already consumed.
FAILED tests/test_sampling.py::test_sampling_and_estimation[GeneralizedInverseGaussianNP] - jax.errors.KeyReuseError: In pjit, argument 0 is already consumed.
FAILED tests/test_shapes.py::test_shapes[GeneralizedInverseGaussian] - ValueError: Domain error in arguments. The `scale` parameter must be positive for all distributions, and many distributions have restrictions on shape parameters. Please see the `scipy.stats.geninvgauss` documentation for details.

You may want to squash and rebase onto main since it's running with some old dependencies.

@xshi19 (Contributor, Author) commented May 8, 2025

Thanks! I will take a look at this over the weekend.

@NeilGirdhar (Owner) commented

Cool, just checked and it looks like you forgot the constraint on p_minus_one: JaxRealArray = distribution_parameter(ScalarSupport())?

@xshi19 (Contributor, Author) commented May 9, 2025

Cool, just checked and it looks like you forgot the constraint on p_minus_one: JaxRealArray = distribution_parameter(ScalarSupport())?

This value should not have a constraint; see https://en.wikipedia.org/wiki/Generalized_inverse_Gaussian_distribution

You may want to squash and rebase onto main since it's running with some old dependencies.

I should already have the latest updates from main, no? I see your last change (6a058e6) in my branch.

This seems to be a numerical issue with the log Bessel function when taking its derivative by finite differences. I changed eps from 1e-6 to 1e-10 and Newton's method converges. However, the natural parameters are not the same:

import jax.numpy as jnp  # GeneralizedInverseGaussianNP is the class added in this PR

p_minus_one = jnp.array(0.9891)
negative_a_over_two = jnp.array(-3.5979)
negative_b_over_two = jnp.array(-0.4638)

gig_np = GeneralizedInverseGaussianNP(
    p_minus_one=p_minus_one,
    negative_a_over_two=negative_a_over_two,
    negative_b_over_two=negative_b_over_two,
)

# Round-trip: natural -> expectation -> natural.
gig_ep = gig_np.to_exp()
gig_np_from_ep = gig_ep.to_nat()
print(gig_np_from_ep)
print(gig_np)
print(gig_np_from_ep.to_exp())
print(gig_ep)

Output:

GeneralizedInverseGaussianNP(p_minus_one=Array(-2.00653748, dtype=float64), negative_a_over_two=Array(-1.61145512, dtype=float64), negative_b_over_two=Array(-1.30899207, dtype=float64))
GeneralizedInverseGaussianNP(p_minus_one=Array(0.9891, dtype=float64, weak_type=True), negative_a_over_two=Array(-3.5979, dtype=float64, weak_type=True), negative_b_over_two=Array(-0.4638, dtype=float64, weak_type=True))
GeneralizedInverseGaussianEP(mean_log=Array(-0.40242353, dtype=float64), mean=Array(0.77495462, dtype=float64), mean_inv=Array(1.72296084, dtype=float64))
GeneralizedInverseGaussianEP(mean_log=Array(-0.39357565, dtype=float64), mean=Array(0.77495462, dtype=float64), mean_inv=Array(1.72296084, dtype=float64))

A small difference in mean_log ends up as a large difference in the natural parameters.

The pdfs are not the same. I compared the results to the scipy pdf to make sure each pdf is computed correctly:

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import geninvgauss

# Original natural parameters.
p = gig_np.p_minus_one + 1
a = -2 * gig_np.negative_a_over_two
b = -2 * gig_np.negative_b_over_two
gig_sp = geninvgauss(p=p, b=np.sqrt(a * b), scale=np.sqrt(b / a))

# Natural parameters recovered through to_exp / to_nat.
p = gig_np_from_ep.p_minus_one + 1
a = -2 * gig_np_from_ep.negative_a_over_two
b = -2 * gig_np_from_ep.negative_b_over_two
gig_sp_from_ep = geninvgauss(p=p, b=np.sqrt(a * b), scale=np.sqrt(b / a))

x_values = np.linspace(0.001, 10, 1000)
plt.figure(figsize=(8, 6))
plt.plot(x_values, gig_np.pdf(x_values), label='jax pdf original')
plt.plot(x_values, gig_sp.pdf(x_values), '--', label='scipy pdf original')
plt.plot(x_values, gig_np_from_ep.pdf(x_values), label='jax pdf NatToExp')
plt.plot(x_values, gig_sp_from_ep.pdf(x_values), '--', label='scipy pdf NatToExp')
plt.legend()

[image: pdf comparison plot]

When I sample random numbers from the two distributions, the empirical sufficient statistics are very similar:

n = 100000
x_sample = gig_sp.rvs(size=n)
x_sample_from_ep = gig_sp_from_ep.rvs(size=n)
print(f"mean log: {jnp.log(x_sample).mean()}, {jnp.log(x_sample_from_ep).mean()}")
print(f"mean: {x_sample.mean()}, {x_sample_from_ep.mean()}")
print(f"mean inv: {(1/x_sample).mean()}, {(1/x_sample_from_ep).mean()}")

Output:

mean log: -0.39378250590847924, -0.40225233681070754
mean: 0.774341182300226, 0.7746143088905203
mean inv: 1.723141011106831, 1.7219595091758464

I will do more research on this numerical instability.
