Skip to content

Commit d86bf6b

Browse files
authored
Merge pull request #54 from DifferentiableUniverseInitiative/u/EiffL/spline_integration
Updates background to use new cubic splines
2 parents 850fae8 + 1f6f413 commit d86bf6b

File tree

4 files changed

+48
-39
lines changed

4 files changed

+48
-39
lines changed

jax_cosmo/angular_cl.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import jax_cosmo.power as power
1616
import jax_cosmo.transfer as tklib
1717
from jax_cosmo.scipy.integrate import simps
18+
from jax_cosmo.scipy.interpolate import InterpolatedUnivariateSpline
1819
from jax_cosmo.utils import a2z
1920
from jax_cosmo.utils import z2a
2021

@@ -54,7 +55,7 @@ def find_index(a, b):
5455

5556

5657
def angular_cl(
57-
cosmo, ell, probes, transfer_fn=tklib.Eisenstein_Hu, nonlinear_fn=power.halofit
58+
cosmo, ell, probes, transfer_fn=tklib.Eisenstein_Hu, nonlinear_fn=power.linear
5859
):
5960
"""
6061
Computes angular Cls for the provided probes
@@ -93,13 +94,16 @@ def combine_kernels(inds):
9394

9495
# Now kernels has shape [ncls, na]
9596
kernels = lax.map(combine_kernels, cl_index)
96-
9797
result = pk * kernels * bkgrd.dchioverda(cosmo, a) / np.clip(chi ** 2, 1.0)
98+
return result
9899

99-
# We transpose the result just to make sure that na is first
100-
return result.T
101-
102-
return simps(integrand, z2a(zmax), 1.0, 512) / const.c ** 2
100+
atab = np.linspace(z2a(zmax), 1.0, 64)
101+
eval_integral = vmap(
102+
lambda x: np.squeeze(
103+
InterpolatedUnivariateSpline(atab, x).integral(z2a(zmax), 1.0)
104+
)
105+
)
106+
return eval_integral(integrand(atab)) / const.c ** 2
103107

104108
return cl(ell)
105109

@@ -161,7 +165,7 @@ def gaussian_cl_covariance_and_mean(
161165
ell,
162166
probes,
163167
transfer_fn=tklib.Eisenstein_Hu,
164-
nonlinear_fn=power.halofit,
168+
nonlinear_fn=power.linear,
165169
f_sky=0.25,
166170
):
167171
"""

jax_cosmo/background.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import jax.numpy as np
77

88
import jax_cosmo.constants as const
9-
from jax_cosmo.scipy.interpolate import interp
9+
from jax_cosmo.scipy.interpolate import InterpolatedUnivariateSpline
1010
from jax_cosmo.scipy.ode import odeint
1111

1212
__all__ = [
@@ -202,7 +202,7 @@ def Omega_de_a(cosmo, a):
202202
return cosmo.Omega_de * np.power(a, f_de(cosmo, a)) / Esqr(cosmo, a)
203203

204204

205-
def radial_comoving_distance(cosmo, a, log10_amin=-3, steps=256):
205+
def radial_comoving_distance(cosmo, a, log10_amin=-3, steps=64):
206206
r"""Radial comoving distance in [Mpc/h] for a given scale factor.
207207
208208
Parameters
@@ -235,17 +235,19 @@ def dchioverdlna(y, x):
235235
return dchioverda(cosmo, xa) * xa
236236

237237
chitab = odeint(dchioverdlna, 0.0, np.log(atab))
238-
# np.clip(- 3000*np.log(atab), 0, 10000)#odeint(dchioverdlna, 0., np.log(atab), cosmo)
239238
chitab = chitab[-1] - chitab
240239

241-
cache = {"a": atab, "chi": chitab}
240+
cache = {
241+
"a2chi": InterpolatedUnivariateSpline(atab, chitab),
242+
"chi2a": InterpolatedUnivariateSpline(chitab, atab),
243+
}
242244
cosmo._workspace["background.radial_comoving_distance"] = cache
243245
else:
244246
cache = cosmo._workspace["background.radial_comoving_distance"]
245247

246248
a = np.atleast_1d(a)
247249
# Return the results as an interpolation of the table
248-
return np.clip(interp(a, cache["a"], cache["chi"]), 0.0)
250+
return np.clip(cache["a2chi"](a), 0.0)
249251

250252

251253
def a_of_chi(cosmo, chi):
@@ -270,7 +272,7 @@ def a_of_chi(cosmo, chi):
270272
radial_comoving_distance(cosmo, 1.0)
271273
cache = cosmo._workspace["background.radial_comoving_distance"]
272274
chi = np.atleast_1d(chi)
273-
return interp(chi, cache["chi"], cache["a"])
275+
return cache["chi2a"](chi)
274276

275277

276278
def dchioverda(cosmo, a):
@@ -437,7 +439,7 @@ def growth_rate(cosmo, a):
437439
return _growth_rate_ODE(cosmo, a)
438440

439441

440-
def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4):
442+
def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=64):
441443
""" Compute linear growth factor D(a) at a given scale factor,
442444
normalised such that D(a=1) = 1.
443445
@@ -478,11 +480,14 @@ def D_derivs(y, x):
478480
# To transform from dD/da to dlnD/dlna: dlnD/dlna = a / D dD/da
479481
ftab = y[:, 1] / y1[-1] * atab / gtab
480482

481-
cache = {"a": atab, "g": gtab, "f": ftab}
483+
cache = {
484+
"g": InterpolatedUnivariateSpline(atab, gtab),
485+
"f": InterpolatedUnivariateSpline(atab, ftab),
486+
}
482487
cosmo._workspace["background.growth_factor"] = cache
483488
else:
484489
cache = cosmo._workspace["background.growth_factor"]
485-
return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0)
490+
return np.clip(cache["g"](a), 0.0, 1.0)
486491

487492

488493
def _growth_rate_ODE(cosmo, a):
@@ -506,10 +511,10 @@ def _growth_rate_ODE(cosmo, a):
506511
if not "background.growth_factor" in cosmo._workspace.keys():
507512
_growth_factor_ODE(cosmo, np.atleast_1d(1.0))
508513
cache = cosmo._workspace["background.growth_factor"]
509-
return interp(a, cache["a"], cache["f"])
514+
return cache["f"](a)
510515

511516

512-
def _growth_factor_gamma(cosmo, a, log10_amin=-3, steps=128):
517+
def _growth_factor_gamma(cosmo, a, log10_amin=-3, steps=64):
513518
r""" Computes growth factor by integrating the growth rate provided by the
514519
\gamma parametrization. Normalized such that D( a=1) =1
515520
@@ -538,11 +543,11 @@ def integrand(y, loga):
538543

539544
gtab = np.exp(odeint(integrand, np.log(atab[0]), np.log(atab)))
540545
gtab = gtab / gtab[-1] # Normalize to a=1.
541-
cache = {"a": atab, "g": gtab}
546+
cache = {"g": InterpolatedUnivariateSpline(atab, gtab)}
542547
cosmo._workspace["background.growth_factor"] = cache
543548
else:
544549
cache = cosmo._workspace["background.growth_factor"]
545-
return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0)
550+
return np.clip(cache["g"](a), 0.0, 1.0)
546551

547552

548553
def _growth_rate_gamma(cosmo, a):

tests/test_angular_cl.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def test_lensing_cl():
2222
n_s=0.96,
2323
Neff=0,
2424
transfer_function="eisenstein_hu",
25-
matter_power_spectrum="halofit",
25+
matter_power_spectrum="linear",
2626
)
2727

2828
cosmo_jax = Cosmology(
@@ -43,13 +43,13 @@ def test_lensing_cl():
4343
tracer_jax = probes.WeakLensing([nz])
4444

4545
# Get an ell range for the cls
46-
ell = np.logspace(0.1, 4)
46+
ell = np.logspace(1, 4)
4747

4848
# Compute the cls
4949
cl_ccl = ccl.angular_cl(cosmo_ccl, tracer_ccl, tracer_ccl, ell)
5050
cl_jax = angular_cl(cosmo_jax, ell, [tracer_jax])
5151

52-
assert_allclose(cl_ccl, cl_jax[0], rtol=1.0e-2)
52+
assert_allclose(cl_ccl, cl_jax[0], rtol=1e-2)
5353

5454

5555
def test_lensing_cl_IA():
@@ -62,7 +62,7 @@ def test_lensing_cl_IA():
6262
n_s=0.96,
6363
Neff=0,
6464
transfer_function="eisenstein_hu",
65-
matter_power_spectrum="halofit",
65+
matter_power_spectrum="linear",
6666
)
6767

6868
cosmo_jax = Cosmology(
@@ -90,13 +90,13 @@ def test_lensing_cl_IA():
9090
tracer_jax = probes.WeakLensing([nz], bias)
9191

9292
# Get an ell range for the cls
93-
ell = np.logspace(0.1, 4)
93+
ell = np.logspace(1, 4)
9494

9595
# Compute the cls
9696
cl_ccl = ccl.angular_cl(cosmo_ccl, tracer_ccl, tracer_ccl, ell)
9797
cl_jax = angular_cl(cosmo_jax, ell, [tracer_jax])
9898

99-
assert_allclose(cl_ccl, cl_jax[0], rtol=1e-2)
99+
assert_allclose(cl_ccl, cl_jax[0], rtol=5e-3)
100100

101101

102102
def test_clustering_cl():
@@ -109,7 +109,7 @@ def test_clustering_cl():
109109
n_s=0.96,
110110
Neff=0,
111111
transfer_function="eisenstein_hu",
112-
matter_power_spectrum="halofit",
112+
matter_power_spectrum="linear",
113113
)
114114

115115
cosmo_jax = Cosmology(
@@ -136,7 +136,7 @@ def test_clustering_cl():
136136
tracer_jax = probes.NumberCounts([nz], bias)
137137

138138
# Get an ell range for the cls
139-
ell = np.logspace(0.1, 4)
139+
ell = np.logspace(1, 4)
140140

141141
# Compute the cls
142142
cl_ccl = ccl.angular_cl(cosmo_ccl, tracer_ccl, tracer_ccl, ell)

tests/test_background.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,15 @@ def test_distances_flat():
3636

3737
chi_ccl = ccl.comoving_radial_distance(cosmo_ccl, a)
3838
chi_jax = bkgrd.radial_comoving_distance(cosmo_jax, a) / cosmo_jax.h
39-
assert_allclose(chi_ccl, chi_jax, rtol=0.5e-2)
39+
assert_allclose(chi_ccl, chi_jax, rtol=1e-3)
4040

4141
chi_ccl = ccl.comoving_angular_distance(cosmo_ccl, a)
4242
chi_jax = bkgrd.transverse_comoving_distance(cosmo_jax, a) / cosmo_jax.h
43-
assert_allclose(chi_ccl, chi_jax, rtol=0.5e-2)
43+
assert_allclose(chi_ccl, chi_jax, rtol=1e-3)
4444

4545
chi_ccl = ccl.angular_diameter_distance(cosmo_ccl, a)
4646
chi_jax = bkgrd.angular_diameter_distance(cosmo_jax, a) / cosmo_jax.h
47-
assert_allclose(chi_ccl, chi_jax, rtol=0.5e-2)
47+
assert_allclose(chi_ccl, chi_jax, rtol=1e-3)
4848

4949

5050
def test_growth():
@@ -72,12 +72,12 @@ def test_growth():
7272
)
7373

7474
# Test array of scale factors
75-
a = np.linspace(0.01, 1.0)
75+
a = np.linspace(0.1, 1.0)
7676

7777
gccl = ccl.growth_factor(cosmo_ccl, a)
7878
gjax = bkgrd.growth_factor(cosmo_jax, a)
7979

80-
assert_allclose(gccl, gjax, rtol=1e-2)
80+
assert_allclose(gccl, gjax, rtol=1e-3)
8181

8282

8383
def test_growth_rate():
@@ -105,12 +105,12 @@ def test_growth_rate():
105105
)
106106

107107
# Test array of scale factors
108-
a = np.linspace(0.01, 1.0)
108+
a = np.linspace(0.1, 1.0)
109109

110110
fccl = ccl.growth_rate(cosmo_ccl, a)
111111
fjax = bkgrd.growth_rate(cosmo_jax, a)
112112

113-
assert_allclose(fccl, fjax, rtol=1e-2)
113+
assert_allclose(fccl, fjax, rtol=1e-3)
114114

115115

116116
def test_growth_rate_gamma():
@@ -140,12 +140,12 @@ def test_growth_rate_gamma():
140140
)
141141

142142
# Test array of scale factors
143-
a = np.linspace(0.01, 1.0)
143+
a = np.linspace(0.1, 1.0)
144144

145145
fccl = ccl.growth_rate(cosmo_ccl, a)
146146
fjax = bkgrd.growth_rate(cosmo_jax, a)
147147

148-
assert_allclose(fccl, fjax, rtol=1e-2)
148+
assert_allclose(fccl, fjax, rtol=5e-3)
149149

150150

151151
def test_growth_gamma():
@@ -174,9 +174,9 @@ def test_growth_gamma():
174174
)
175175

176176
# Test array of scale factors
177-
a = np.linspace(0.01, 1.0)
177+
a = np.linspace(0.1, 1.0)
178178

179179
gccl = ccl.growth_factor(cosmo_ccl, a)
180180
gjax = bkgrd.growth_factor(cosmo_jax, a)
181181

182-
assert_allclose(gccl, gjax, rtol=1e-2)
182+
assert_allclose(gccl, gjax, rtol=1e-3)

0 commit comments

Comments
 (0)