Skip to content

Commit d5c7464

Browse files
authored
Merge pull request #84 from eelregit/eelregit_transverse_distance_patch
Make k and transverse distance jitable
2 parents 96634cc + cf8500c commit d5c7464

File tree

2 files changed

+16
-12
lines changed

2 files changed

+16
-12
lines changed

jax_cosmo/background.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# This module implements various functions for the background COSMOLOGY
22
import jax.numpy as np
3+
from jax import lax
34

45
import jax_cosmo.constants as const
56
from jax_cosmo.scipy.interpolate import interp
@@ -328,14 +329,23 @@ def transverse_comoving_distance(cosmo, a):
328329
\end{matrix}
329330
\right.
330331
"""
331-
chi = radial_comoving_distance(cosmo, a)
332-
if cosmo.k < 0: # Open universe
332+
index = cosmo.k + 1
333+
334+
def open_universe(chi):
333335
return const.rh / cosmo.sqrtk * np.sinh(cosmo.sqrtk * chi / const.rh)
334-
elif cosmo.k > 0: # Closed Universe
335-
return const.rh / cosmo.sqrtk * np.sin(cosmo.sqrtk * chi / const.rh)
336-
else:
336+
337+
def flat_universe(chi):
337338
return chi
338339

340+
def close_universe(chi):
341+
return const.rh / cosmo.sqrtk * np.sin(cosmo.sqrtk * chi / const.rh)
342+
343+
branches = (open_universe, flat_universe, close_universe)
344+
345+
chi = radial_comoving_distance(cosmo, a)
346+
347+
return lax.switch(cosmo.k + 1, branches, chi)
348+
339349

340350
def angular_diameter_distance(cosmo, a):
341351
r"""Angular diameter distance in [Mpc/h] for a given scale factor.

jax_cosmo/core.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -168,13 +168,7 @@ def Omega_k(self):
168168

169169
@property
170170
def k(self):
171-
if self.Omega > 1.0: # Closed universe
172-
k = 1.0
173-
elif self.Omega == 1.0: # Flat universe
174-
k = 0
175-
elif self.Omega < 1.0: # Open Universe
176-
k = -1.0
177-
return k
171+
return -np.sign(self._Omega_k).astype(np.int8)
178172

179173
@property
180174
def sqrtk(self):

0 commit comments

Comments
 (0)