Skip to content

Commit e30cae5

Browse files
authored
Add convenience classes for the (multivariate) normal. (#15)
* Ignore the .vscode folder before I inevitably commit it at least once * Skeleton base class for distributions * Base class for distribution families * Distributions sample method is backend-agnostic * Fix return type of sample method * Immutable type for storing backend-compatibility info * Remove distributon family for now * Clean up imports and typevars * Basic docstrings * Give useful information in docstrings * Write tests for SampleCompatibility class * Write test for Distribution * tests now require multiple backends * Create a Translator class as you're going to need to do this a lot, Will * Fix the tests that I broke * Add some more detailed docstrings * Test the translation method of translators * Fix docstring linting issues * Module name to singular to be consistent with #5 * Skeleton for distribution family * Add test for distribution family builder method * Expose getter method for Distribution backends * Translator is optional * Fix docstring * Add note about how distribution families are created * Add jax to explicit package dependencies * Add normal distribution as a standard distribution * Create multivariate normal family, that can be attached to a Node * Update DistFam docstring to better highlight the purpose of a family
1 parent e2ef704 commit e30cae5

File tree

2 files changed

+89
-3
lines changed

2 files changed

+89
-3
lines changed

pyproject.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@ classifiers = [
1515
"Programming Language :: Python :: 3.13",
1616
"Typing :: Typed",
1717
]
18-
dependencies = [
19-
"networkx",
20-
]
18+
dependencies = ["jax", "networkx"]
2119
description = "A Python package for causal modelling and inference with stochastic causal programming"
2220
dynamic = ["version"]
2321
keywords = []

src/causalprog/distribution/normal.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
"""(Multivariate) normal distribution, implemented via ``jax.random`` backend."""
2+
3+
from typing import TypeAlias, TypeVar
4+
5+
import jax.numpy as jnp
6+
import jax.random as jrn
7+
from jax import Array as JaxArray
8+
from numpy.typing import ArrayLike
9+
10+
from .base import Distribution
11+
from .family import DistributionFamily
12+
13+
ArrayCompatible = TypeVar("ArrayCompatible", JaxArray, ArrayLike)
14+
RNGKey: TypeAlias = JaxArray
15+
16+
17+
class _Normal:
18+
mean: JaxArray
19+
cov: JaxArray
20+
21+
def __init__(self, mean: ArrayCompatible, cov: ArrayCompatible) -> None:
22+
self.mean = jnp.array(mean)
23+
self.cov = jnp.array(cov)
24+
25+
def sample(self, rng_key: RNGKey, sample_shape: ArrayLike) -> JaxArray:
26+
return jrn.multivariate_normal(rng_key, self.mean, self.cov, shape=sample_shape)
27+
28+
29+
class Normal(Distribution):
30+
r"""
31+
A (possibly multivaraiate) normal distribution, $\mathcal{N}(\mu, \Sigma)$.
32+
33+
The normal distribution is parametrised by its (vector of) mean value(s) $\mu$ and
34+
(matrix of) covariate(s) $\Sigma$. These values must be supplied to an instance at
35+
upon construction, and can be accessed via the ``mean`` ($\mu$) and ``cov``
36+
($\Sigma$) attributes, respectively.
37+
38+
"""
39+
40+
_dist: _Normal
41+
42+
@property
43+
def mean(self) -> JaxArray:
44+
r"""Mean of the distribution, $\mu$."""
45+
return self._dist.mean
46+
47+
@property
48+
def cov(self) -> JaxArray:
49+
r"""Covariate matrix of the distribution, $\Sigma$."""
50+
return self._dist.cov
51+
52+
def __init__(self, mean: ArrayCompatible, cov: ArrayCompatible) -> None:
53+
r"""
54+
Create a new normal distribution.
55+
56+
Args:
57+
mean (ArrayCompatible): Vector of mean values, $\mu$.
58+
cov (ArrayCompatible): Matrix of covariates, $\Sigma$.
59+
60+
"""
61+
super().__init__(_Normal(mean, cov))
62+
63+
64+
class NormalFamily(DistributionFamily):
65+
r"""
66+
Constructor class for (possibly multivariate) normal distributions.
67+
68+
The multivariate normal distribution is parametrised by a (vector of) mean values
69+
$\mu$, and (matrix of) covariates $\Sigma$. A ``NormalFamily`` represents this
70+
family of distributions, $\mathcal{N}(\mu, \Sigma)$. The ``.construct`` method can
71+
be used to construct a ``Normal`` distribution with a fixed mean and covariate
72+
matrix.
73+
"""
74+
75+
def __init__(self) -> None:
76+
"""Create a family of normal distributions."""
77+
super().__init__(Normal)
78+
79+
def construct(self, mean: ArrayCompatible, cov: ArrayCompatible) -> Normal:
80+
r"""
81+
Construct a normal distribution with the given mean and covariates.
82+
83+
Args:
84+
mean (ArrayCompatible): Vector of mean values, $\mu$.
85+
cov (ArrayCompatible): Matrix of covariates, $\Sigma$.
86+
87+
"""
88+
return super().construct(mean, cov)

0 commit comments

Comments
 (0)