Skip to content

Commit 060ef0e

Browse files
authored
Merge pull request #58 from dkirkby/sparse
Implement optional sparse Gaussian covariance
2 parents 99edb4f + 5fc49b6 commit 060ef0e

File tree

9 files changed

+1741
-1161
lines changed

9 files changed

+1741
-1161
lines changed

design.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,17 +150,18 @@ with code styling again. Here are the steps to follow:
150150
- Install `black` and `pre-commit`:
151151
```bash
152152
$ pip install --user black pre-commit reorder_python_imports
153+
$ pre-commit install
153154
```
154-
`pre-commit` will be tasked with automatically running `black` formatting
155-
whenever you commit some code.
155+
`pre-commit` will be tasked with automatically running `black` and `reorder_python_imports` formatting
156+
whenever you commit some code. The import guidelines are documented [here](https://github.com/asottile/reorder_python_imports#what-does-it-do).
156157

157158
- Manually running black formatting:
158159
```bash
159160
$ black .
160161
```
161162
from the root directory.
162163

163-
- Automatically running black at each commit: You actually have nothing
164+
- Automatically running `black` and `reorder_python_imports` at each commit: You actually have nothing
164165
else to do. If pre-commit is installed it will happen automatically for
165166
you.
166167

docs/notebooks/jax-cosmo-intro.ipynb

Lines changed: 1144 additions & 1140 deletions
Large diffs are not rendered by default.

jax_cosmo/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,4 @@
2222
import jax_cosmo.transfer as transfer
2323
from jax_cosmo.core import *
2424
from jax_cosmo.parameters import *
25+
import jax_cosmo.sparse as sparse

jax_cosmo/angular_cl.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -122,14 +122,19 @@ def get_noise_cl(inds):
122122
return lax.map(get_noise_cl, cl_index)
123123

124124

125-
def gaussian_cl_covariance(ell, probes, cl_signal, cl_noise, f_sky=0.25):
125+
def gaussian_cl_covariance(ell, probes, cl_signal, cl_noise, f_sky=0.25, sparse=True):
126126
"""
127127
Computes a Gaussian covariance for the angular cls of the provided probes
128128
129+
Set sparse True to return a sparse matrix representation that uses a factor
130+
of n_ell less memory and is compatible with the linear algebra operations
131+
in :mod:`jax_cosmo.sparse`.
132+
129133
return_cls: (returns covariance)
130134
"""
131135
ell = np.atleast_1d(ell)
132136
n_ell = len(ell)
137+
one = 1.0 if sparse else np.eye(n_ell)
133138

134139
# Adding noise to auto-spectra
135140
cl_obs = cl_signal + cl_noise
@@ -144,15 +149,22 @@ def gaussian_cl_covariance(ell, probes, cl_signal, cl_noise, f_sky=0.25):
144149
def get_cov_block(inds):
145150
a, b, c, d = inds
146151
cov = (cl_obs[a] * cl_obs[b] + cl_obs[c] * cl_obs[d]) / norm
147-
return cov * np.eye(n_ell)
152+
return cov * one
148153

154+
# Return a sparse representation of the matrix containing only the diagonals
155+
# for each of the n_cls x n_cls blocks of size n_ell x n_ell.
156+
# We could compress this further using the symmetry of the blocks, but
157+
# it is easier to invert this matrix with this redundancy included.
149158
cov_mat = lax.map(get_cov_block, cov_blocks)
150159

151160
# Reshape covariance matrix into proper matrix
152-
cov_mat = cov_mat.reshape((n_cls, n_cls, n_ell, n_ell))
153-
cov_mat = cov_mat.transpose(axes=(0, 2, 1, 3)).reshape(
154-
(n_ell * n_cls, n_ell * n_cls)
155-
)
161+
if sparse:
162+
cov_mat = cov_mat.reshape((n_cls, n_cls, n_ell))
163+
else:
164+
cov_mat = cov_mat.reshape((n_cls, n_cls, n_ell, n_ell))
165+
cov_mat = cov_mat.transpose(axes=(0, 2, 1, 3)).reshape(
166+
(n_ell * n_cls, n_ell * n_cls)
167+
)
156168
return cov_mat
157169

158170

@@ -163,10 +175,15 @@ def gaussian_cl_covariance_and_mean(
163175
transfer_fn=tklib.Eisenstein_Hu,
164176
nonlinear_fn=power.halofit,
165177
f_sky=0.25,
178+
sparse=False,
166179
):
167180
"""
168181
Computes a Gaussian covariance for the angular cls of the provided probes
169182
183+
Set sparse True to return a sparse matrix representation that uses a factor
184+
of n_ell less memory and is compatible with the linear algebra operations
185+
in :mod:`jax_cosmo.sparse`.
186+
170187
return_cls: (returns signal + noise cl, covariance)
171188
"""
172189
ell = np.atleast_1d(ell)
@@ -179,6 +196,6 @@ def gaussian_cl_covariance_and_mean(
179196
cl_noise = noise_cl(ell, probes)
180197

181198
# retrieve the covariance
182-
cov_mat = gaussian_cl_covariance(ell, probes, cl_signal, cl_noise, f_sky)
199+
cov_mat = gaussian_cl_covariance(ell, probes, cl_signal, cl_noise, f_sky, sparse)
183200

184201
return cl_signal.flatten(), cov_mat

jax_cosmo/likelihood.py

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,60 @@
66
import jax.numpy as np
77
import jax.scipy as sp
88

9+
import jax_cosmo.sparse as sparse
910
from jax_cosmo.angular_cl import gaussian_cl_covariance
1011

1112

1213
def gaussian_log_likelihood(data, mu, C, constant_cov=True, inverse_method="inverse"):
1314
"""
14-
Computes the likelihood for some cl
15+
Computes the log likelihood for a given data vector under a multivariate
16+
Gaussian distribution.
17+
18+
If the covariance C is sparse (according to :meth:`jax_cosmo.sparse.is_sparse`)
19+
use sparse inverse and determinant algorithms (and ignore ``inverse_method``).
20+
21+
Parameters
22+
----------
23+
data: array_like
24+
Data vector, with shape [N].
25+
26+
mu: array_like, 1d
27+
Mean of the Gaussian likelihood, with shape [N].
28+
29+
C: array_like or sparse matrix
30+
Covariance of Gaussian likelihood with shape [N,N]
31+
32+
constant_cov: boolean
33+
Whether to include the log determinant of the covariance matrix in the
34+
likelihood. If `constant_cov` is true, the log determinant is ignored
35+
(default: True)
36+
37+
inverse_method: string
38+
Methods for computing the precision matrix. Either "inverse", "cholesky".
39+
Note that this option is ignored when the covariance is sparse. (default: "inverse")
1540
"""
1641
# Computes residuals
1742
r = mu - data
1843

19-
# TODO: check what is the fastest and works the best between cholesky+solve
20-
# and just inversion
21-
if inverse_method == "inverse":
22-
y = np.dot(np.linalg.inv(C), r)
23-
elif inverse_method == "cholesky":
24-
y = sp.linalg.cho_solve(sp.linalg.cho_factor(C, lower=True), r)
44+
if sparse.is_sparse(C):
45+
r = r.reshape(-1, 1)
46+
rT_Cinv_r = sparse.dot(r.T, sparse.inv(C), r)[0, 0]
2547
else:
26-
raise NotImplementedError
48+
# TODO: check what is the fastest and works the best between cholesky+solve
49+
# and just inversion
50+
if inverse_method == "inverse":
51+
y = np.dot(np.linalg.inv(C), r)
52+
elif inverse_method == "cholesky":
53+
y = sp.linalg.cho_solve(sp.linalg.cho_factor(C, lower=True), r)
54+
else:
55+
raise NotImplementedError
56+
rT_Cinv_r = r.dot(y)
2757

2858
if constant_cov:
29-
return -0.5 * r.dot(y)
59+
return -0.5 * rT_Cinv_r
3060
else:
31-
_, logdet = np.linalg.slogdet(C)
32-
return -0.5 * r.dot(y) - 0.5 * logdet
61+
if sparse.is_sparse(C):
62+
_, logdet = sparse.slogdet(C)
63+
else:
64+
_, logdet = np.linalg.slogdet(C)
65+
return -0.5 * (rT_Cinv_r - logdet)

0 commit comments

Comments
 (0)