File tree Expand file tree Collapse file tree 2 files changed +16
-12
lines changed Expand file tree Collapse file tree 2 files changed +16
-12
lines changed Original file line number Diff line number Diff line change 11# This module implements various functions for the background COSMOLOGY
22import jax .numpy as np
3+ from jax import lax
34
45import jax_cosmo .constants as const
56from jax_cosmo .scipy .interpolate import interp
@@ -328,14 +329,23 @@ def transverse_comoving_distance(cosmo, a):
328329 \end{matrix}
329330 \right.
330331 """
331- chi = radial_comoving_distance (cosmo , a )
332- if cosmo .k < 0 : # Open universe
332+ index = cosmo .k + 1
333+
334+ def open_universe (chi ):
333335 return const .rh / cosmo .sqrtk * np .sinh (cosmo .sqrtk * chi / const .rh )
334- elif cosmo .k > 0 : # Closed Universe
335- return const .rh / cosmo .sqrtk * np .sin (cosmo .sqrtk * chi / const .rh )
336- else :
336+
337+ def flat_universe (chi ):
337338 return chi
338339
340+ def close_universe (chi ):
341+ return const .rh / cosmo .sqrtk * np .sin (cosmo .sqrtk * chi / const .rh )
342+
343+ branches = (open_universe , flat_universe , close_universe )
344+
345+ chi = radial_comoving_distance (cosmo , a )
346+
347+ return lax .switch (cosmo .k + 1 , branches , chi )
348+
339349
340350def angular_diameter_distance (cosmo , a ):
341351 r"""Angular diameter distance in [Mpc/h] for a given scale factor.
Original file line number Diff line number Diff line change @@ -168,13 +168,7 @@ def Omega_k(self):
168168
169169 @property
170170 def k (self ):
171- if self .Omega > 1.0 : # Closed universe
172- k = 1.0
173- elif self .Omega == 1.0 : # Flat universe
174- k = 0
175- elif self .Omega < 1.0 : # Open Universe
176- k = - 1.0
177- return k
171+ return - np .sign (self ._Omega_k ).astype (np .int8 )
178172
179173 @property
180174 def sqrtk (self ):
You can’t perform that action at this time.
0 commit comments