Skip to content

Add convenience classes for the (multivariate) normal. #15

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 31 commits into from
Mar 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
fd205ef
Ignore the .vscode folder before I inevitably commit it at least once
willGraham01 Mar 6, 2025
449d597
Skeleton base class for distributions
willGraham01 Mar 6, 2025
f69e4f5
Base class for distribution families
willGraham01 Mar 6, 2025
91af7b7
Distributions sample method is backend-agnostic
willGraham01 Mar 6, 2025
2a57d56
Fix return type of sample method
willGraham01 Mar 6, 2025
bdd7be3
Immutable type for storing backend-compatibility info
willGraham01 Mar 6, 2025
2842c89
Remove distributon family for now
willGraham01 Mar 6, 2025
1327588
Clean up imports and typevars
willGraham01 Mar 6, 2025
e67d4b8
Basic docstrings
willGraham01 Mar 6, 2025
2cc3451
Give useful information in docstrings
willGraham01 Mar 6, 2025
2a7adf7
Write tests for SampleCompatibility class
willGraham01 Mar 6, 2025
db9a3ff
Write test for Distribution
willGraham01 Mar 6, 2025
609dfd8
tests now require multiple backends
willGraham01 Mar 6, 2025
e3fb4d5
Create a Translator class as you're going to need to do this a lot, Will
willGraham01 Mar 7, 2025
2cdb0f2
Fix the tests that I broke
willGraham01 Mar 7, 2025
d8ffaae
Add some more detailed docstrings
willGraham01 Mar 7, 2025
9c69264
Test the translation method of translators
willGraham01 Mar 7, 2025
46ac6e8
Fix docstring linting issues
willGraham01 Mar 10, 2025
b74a477
Module name to singular to be consistent with #5
willGraham01 Mar 10, 2025
ccc343a
Skeleton for distribution family
willGraham01 Mar 10, 2025
f513896
Add test for distribution family builder method
willGraham01 Mar 10, 2025
03d3845
Expose getter method for Distribution backends
willGraham01 Mar 10, 2025
7ceb2e0
Translator is optional
willGraham01 Mar 11, 2025
b5133b3
Merge branch 'main' into wgraham/distribution-family
willGraham01 Mar 13, 2025
8383d54
Fix docstring
willGraham01 Mar 13, 2025
c7af397
Add note about how distribution families are created
willGraham01 Mar 13, 2025
9469d1a
Add jax to explicit package dependencies
willGraham01 Mar 13, 2025
0137dfd
Add normal distribution as a standard distribution
willGraham01 Mar 13, 2025
bb01b6a
Create multivariate normal family, that can be attached to a Node
willGraham01 Mar 13, 2025
0b219ed
Update DistFam docstring to better highlight the purpose of a family
willGraham01 Mar 13, 2025
18d6108
Merge branch 'main' into wgraham/sample-normal-dists
willGraham01 Mar 13, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ classifiers = [
"Programming Language :: Python :: 3.13",
"Typing :: Typed",
]
dependencies = [
"networkx",
]
dependencies = ["jax", "networkx"]
description = "A Python package for causal modelling and inference with stochastic causal programming"
dynamic = ["version"]
keywords = []
Expand Down
88 changes: 88 additions & 0 deletions src/causalprog/distribution/normal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""(Multivariate) normal distribution, implemented via ``jax.random`` backend."""

from typing import TypeAlias, TypeVar

import jax.numpy as jnp
import jax.random as jrn
from jax import Array as JaxArray
from numpy.typing import ArrayLike

from .base import Distribution
from .family import DistributionFamily

ArrayCompatible = TypeVar("ArrayCompatible", JaxArray, ArrayLike)
RNGKey: TypeAlias = JaxArray


class _Normal:
mean: JaxArray
cov: JaxArray

def __init__(self, mean: ArrayCompatible, cov: ArrayCompatible) -> None:
self.mean = jnp.array(mean)
self.cov = jnp.array(cov)

def sample(self, rng_key: RNGKey, sample_shape: ArrayLike) -> JaxArray:
return jrn.multivariate_normal(rng_key, self.mean, self.cov, shape=sample_shape)


class Normal(Distribution):
r"""
A (possibly multivaraiate) normal distribution, $\mathcal{N}(\mu, \Sigma)$.

The normal distribution is parametrised by its (vector of) mean value(s) $\mu$ and
(matrix of) covariate(s) $\Sigma$. These values must be supplied to an instance at
upon construction, and can be accessed via the ``mean`` ($\mu$) and ``cov``
($\Sigma$) attributes, respectively.

"""

_dist: _Normal

@property
def mean(self) -> JaxArray:
r"""Mean of the distribution, $\mu$."""
return self._dist.mean

@property
def cov(self) -> JaxArray:
r"""Covariate matrix of the distribution, $\Sigma$."""
return self._dist.cov

def __init__(self, mean: ArrayCompatible, cov: ArrayCompatible) -> None:
r"""
Create a new normal distribution.

Args:
mean (ArrayCompatible): Vector of mean values, $\mu$.
cov (ArrayCompatible): Matrix of covariates, $\Sigma$.

"""
super().__init__(_Normal(mean, cov))


class NormalFamily(DistributionFamily):
r"""
Constructor class for (possibly multivariate) normal distributions.

The multivariate normal distribution is parametrised by a (vector of) mean values
$\mu$, and (matrix of) covariates $\Sigma$. A ``NormalFamily`` represents this
family of distributions, $\mathcal{N}(\mu, \Sigma)$. The ``.construct`` method can
be used to construct a ``Normal`` distribution with a fixed mean and covariate
matrix.
"""

def __init__(self) -> None:
"""Create a family of normal distributions."""
super().__init__(Normal)

def construct(self, mean: ArrayCompatible, cov: ArrayCompatible) -> Normal:
r"""
Construct a normal distribution with the given mean and covariates.

Args:
mean (ArrayCompatible): Vector of mean values, $\mu$.
cov (ArrayCompatible): Matrix of covariates, $\Sigma$.

"""
return super().construct(mean, cov)
Loading