Skip to content

Commit e2ef704

Browse files
authored
Skeleton DistributionFamilys (#12)
* Ignore the .vscode folder before I inevitably commit it at least once * Skeleton base class for distributions * Base class for distribution families * Distributions sample method is backend-agnostic * Fix return type of sample method * Immutable type for storing backend-compatibility info * Remove distributon family for now * Clean up imports and typevars * Basic docstrings * Give useful information in docstrings * Write tests for SampleCompatibility class * Write test for Distribution * tests now require multiple backends * Create a Translator class as you're going to need to do this a lot, Will * Fix the tests that I broke * Add some more detailed docstrings * Test the translation method of translators * Fix docstring linting issues * Module name to singular to be consistent with #5 * Skeleton for distribution family * Add test for distribution family builder method * Expose getter method for Distribution backends * Translator is optional * Fix docstring * Add note about how distribution families are created
1 parent 634dabf commit e2ef704

File tree

5 files changed

+133
-5
lines changed

5 files changed

+133
-5
lines changed

src/causalprog/distribution/base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ def __init__(
6464
)
6565
self._backend_translator.validate_compatible(backend_distribution)
6666

67+
def get_dist(self) -> SupportsSampling:
68+
"""Access to the backend distribution."""
69+
return self._dist
70+
6771
def sample(self, rng_key: SupportsRNG, sample_shape: ArrayLike = ()) -> ArrayLike:
6872
"""
6973
Draw samples from the distribution.

src/causalprog/distribution/family.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
"""Parametrised groups of ``Distribution``s."""
2+
3+
from collections.abc import Callable
4+
from typing import Generic, TypeVar
5+
6+
from numpy.typing import ArrayLike
7+
8+
from causalprog.distribution.base import Distribution, SupportsSampling
9+
from causalprog.utils.translator import Translator
10+
11+
CreatesDistribution = TypeVar(
12+
"CreatesDistribution", bound=Callable[..., SupportsSampling]
13+
)
14+
15+
16+
class DistributionFamily(Generic[CreatesDistribution]):
17+
r"""
18+
A family of ``Distributions``, that share the same parameters.
19+
20+
A ``DistributionFamily`` is essentially a ``Distribution`` that has not yet had its
21+
parameter values explicitly specified. Explicit values for the parameters can be
22+
passed to a ``DistributionFamily``'s ``construct`` method, which will then proceed
23+
to construct a ``Distribution`` with those parameter values.
24+
25+
As an explicit example, the (possibly multivariate) normal distribution is
26+
parametrised by two quantities - the (vector of) mean values $\mu$ and covariates
27+
$\Sigma$. A ``DistributionFamily`` represents this general
28+
$\mathcal{N}(\mu, \Sigma)$ parametrised form, however without explicit $\mu$ and
29+
$\Sigma$ values we cannot perform operations like drawing samples. Specifying, for
30+
example, $\mu = 0$ and $\Sigma = 1$ by invoking ``.construct(0., 1.)`` will return a
31+
``Distribution`` instance representing $\mathcal{N}(0., 1.)$, which can then have
32+
samples drawn from it.
33+
"""
34+
35+
_family: CreatesDistribution
36+
_family_translator: Translator | None
37+
38+
@property
39+
def _member(self) -> Callable[..., Distribution]:
40+
"""Constructor method for family members, given parameters."""
41+
return lambda *parameters: Distribution(
42+
self._family(*parameters), backend_translator=self._family_translator
43+
)
44+
45+
def __init__(
46+
self,
47+
backend_family: CreatesDistribution,
48+
backend_translator: Translator | None = None,
49+
) -> None:
50+
"""
51+
Create a new family of distributions.
52+
53+
Args:
54+
backend_family (CreatesDistribution): Backend callable that assembles the
55+
distribution, given explicit parameter values. Currently, this callable
56+
can only accept the parameters as a sequence of positional arguments.
57+
backend_translator (Translator): ``Translator`` instance that to be
58+
passed to the ``Distribution`` constructor.
59+
60+
"""
61+
self._family = backend_family
62+
self._family_translator = backend_translator
63+
64+
def construct(self, *parameters: ArrayLike) -> Distribution:
65+
"""
66+
Create a distribution from an explicit set of parameters.
67+
68+
Args:
69+
*parameters (ArrayLike): Parameters that define a member of this family,
70+
passed as sequential arguments.
71+
72+
"""
73+
return self._member(*parameters)

tests/test_distributions/conftest.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import jax.numpy as jnp
2+
import jax.random as jrn
3+
import pytest
4+
from jax._src.basearray import Array
5+
6+
7+
@pytest.fixture
8+
def seed() -> int:
9+
return 0
10+
11+
12+
@pytest.fixture
13+
def rng_key(seed: int):
14+
return jrn.key(seed)
15+
16+
17+
@pytest.fixture
18+
def n_dim_std_normal(request) -> tuple[Array, Array]:
19+
"""
20+
Mean and covariance matrix of the n-dimensional standard normal distribution.
21+
22+
``request.param`` should be an integer corresponding to the number of dimensions.
23+
"""
24+
n_dims = request.param
25+
mean = jnp.array([0.0] * n_dims)
26+
cov = jnp.diag(jnp.array([1.0] * n_dims))
27+
return mean, cov

tests/test_distributions/test_different_backends.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,12 @@
22

33
import distrax
44
import jax.numpy as jnp
5-
import jax.random as jrn
65
from numpyro.distributions.continuous import MultivariateNormal
76

87
from causalprog.distribution.base import Distribution, SampleTranslator
98

109

11-
def test_different_backends() -> None:
10+
def test_different_backends(rng_key) -> None:
1211
"""
1312
Test that ``Distribution`` can use different (but equivalent) backends.
1413
@@ -20,15 +19,14 @@ def test_different_backends() -> None:
2019
n_dims = 2
2120
mean = jnp.array([0.0] * n_dims)
2221
cov = jnp.diag(jnp.array([1.0] * n_dims))
23-
rng = jrn.key(0)
2422
sample_size = (10, 5)
2523

2624
distrax_normal = distrax.MultivariateNormalFullCovariance(mean, cov)
2725
distrax_dist = Distribution(distrax_normal, SampleTranslator(rng_key="seed"))
28-
distrax_samples = distrax_dist.sample(rng, sample_size)
26+
distrax_samples = distrax_dist.sample(rng_key, sample_size)
2927

3028
npyo_normal = MultivariateNormal(mean, cov)
3129
npyo_dist = Distribution(npyo_normal, SampleTranslator(rng_key="key"))
32-
npyo_samples = npyo_dist.sample(rng, sample_size)
30+
npyo_samples = npyo_dist.sample(rng_key, sample_size)
3331

3432
assert jnp.allclose(distrax_samples, npyo_samples)
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import distrax
2+
import pytest
3+
4+
from causalprog.distribution.base import SampleTranslator
5+
from causalprog.distribution.family import DistributionFamily
6+
7+
8+
@pytest.mark.parametrize(
9+
("n_dim_std_normal"),
10+
[pytest.param(2, id="2D normal")],
11+
indirect=["n_dim_std_normal"],
12+
)
13+
def test_builder_matches_backend(n_dim_std_normal) -> None:
14+
"""
15+
Test that building from a family is equivalent
16+
to building via the backend explicitly.
17+
18+
"""
19+
mnv = distrax.MultivariateNormalFullCovariance
20+
21+
mnv_family = DistributionFamily(mnv, SampleTranslator(rng_key="seed"))
22+
via_family = mnv_family.construct(*n_dim_std_normal)
23+
via_backend = mnv(*n_dim_std_normal)
24+
25+
assert via_backend.kl_divergence(via_family.get_dist()) == pytest.approx(0.0)
26+
assert via_family.get_dist().kl_divergence(via_backend) == pytest.approx(0.0)

0 commit comments

Comments
 (0)