Skip to content

Commit 8a82d48

Browse files
authored
Merge pull request #18 from DifferentiableUniverseInitiative/kde_nz
Kde nz
2 parents 3d0ad6b + 96dcdfe commit 8a82d48

File tree

3 files changed

+36
-20
lines changed

3 files changed

+36
-20
lines changed

jax_cosmo/angular_cl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,8 @@ def gaussian_cl_covariance(cosmo, ell, probes, f_sky=0.25, return_cls=True):
129129

130130
def get_cov_block(inds):
131131
a, b, c, d = inds
132-
cov = (cl_obs[a]*cl_obs[b] + cl_obs[c]*cl_obs[d])*np.eye(n_ell) / norm
133-
return cov
132+
cov = (cl_obs[a]*cl_obs[b] + cl_obs[c]*cl_obs[d]) / norm
133+
return cov*np.eye(n_ell)
134134

135135
cov_mat = lax.map(get_cov_block, cov_blocks)
136136

jax_cosmo/kernels.py

Lines changed: 0 additions & 18 deletions
This file was deleted.

jax_cosmo/redshift.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,37 @@ class smail_nz(redshift_distribution):
8686
def pz_fn(self, z):
8787
a, b, z0 = self.params
8888
return z**a * np.exp(-(z / z0)**b)
89+
90+
@register_pytree_node_class
91+
class kde_nz(redshift_distribution):
92+
"""
93+
A redshift distribution based on a KDE estimate of the nz of a given catalog
94+
currently uses a Gaussian kernel, TODO: add more if necessary
95+
96+
Parameters:
97+
-----------
98+
zcat: redshift catalog
99+
weights: weight for each galaxy between 0 and 1
100+
101+
Configuration:
102+
--------------
103+
bw: Bandwidth for the KDE
104+
105+
Example:
106+
nz = kde_nz(redshift_catalog, w, bw=0.1)
107+
"""
108+
109+
def _kernel(self, bw, X, x):
110+
"""
111+
Gaussian kernel for KDE
112+
"""
113+
return (1. / np.sqrt(2 * np.pi )/bw) * np.exp(-(X - x)**2 / (bw**2 * 2.))
114+
115+
def pz_fn(self, z):
116+
# Extract parameters
117+
zcat, weight = self.params[:2]
118+
w = np.atleast_1d(weight)
119+
q = np.sum(w)
120+
X = np.expand_dims(zcat, axis=-1)
121+
k = self._kernel(self.config['bw'], X , z)
122+
return np.dot(k.T, w)/(q)

0 commit comments

Comments
 (0)