Skip to content

Commit 3f3e5c2

Browse files
Merge pull request #90 from pnkraemer/DiagonalEK0
Implement `DiagonalEK0`, an EK0 with vector-valued diffusion
2 parents 8782352 + 1d8ef0e commit 3f3e5c2

File tree

2 files changed

+91
-0
lines changed

2 files changed

+91
-0
lines changed

tests/test_ek0.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def scipy_solution(ivp):
5151
EK0_VERSIONS = [
5252
tornadox.ek0.ReferenceEK0,
5353
tornadox.ek0.KroneckerEK0,
54+
tornadox.ek0.DiagonalEK0,
5455
]
5556
all_ek0_versions = pytest.mark.parametrize("ek0_version", EK0_VERSIONS)
5657

tornadox/ek0.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
import dataclasses
2+
from functools import partial
23

4+
import jax
35
import jax.numpy as jnp
6+
import jax.scipy.linalg
47

58
import tornadox.iwp
69
from tornadox import init, ivp, iwp, odefilter, rv, sqrt, step
10+
from tornadox.ek1 import BatchedEK1
711

812

913
class ReferenceEK0(odefilter.ODEFilter):
@@ -173,3 +177,89 @@ def attempt_step(self, state, dt, verbose=False):
173177
)
174178
info_dict = dict(num_f_evaluations=1)
175179
return new_state, info_dict
180+
181+
182+
class DiagonalEK0(BatchedEK1):
183+
@partial(jax.jit, static_argnums=(0, 1, 2, 3))
184+
def attempt_unit_step(self, f, df, df_diagonal, p_1d_raw, m, sc, t):
185+
m_pred = self.predict_mean(m, phi_1d=self.phi_1d)
186+
f, z = self.evaluate_ode(
187+
t=t, f=f, df_diagonal=df_diagonal, p_1d_raw=p_1d_raw, m_pred=m_pred
188+
)
189+
error, sigma = self.estimate_error(
190+
p_1d_raw=p_1d_raw,
191+
sq_bd=self.batched_sq,
192+
z=z,
193+
)
194+
sc_pred = self.predict_cov_sqrtm(
195+
sc_bd=sc, phi_1d=self.phi_1d, sq_bd=sigma[:, None, None] * self.batched_sq
196+
)
197+
ss, kgain = self.observe_cov_sqrtm(p_1d_raw=p_1d_raw, sc_bd=sc_pred)
198+
cov_sqrtm = self.correct_cov_sqrtm(
199+
p_1d_raw=p_1d_raw,
200+
sc_bd=sc_pred,
201+
kgain=kgain,
202+
)
203+
new_mean = self.correct_mean(m=m_pred, kgain=kgain, z=z)
204+
info_dict = dict(num_f_evaluations=1)
205+
return new_mean, cov_sqrtm, error, info_dict
206+
207+
@staticmethod
208+
@partial(jax.jit, static_argnums=(1, 2))
209+
def evaluate_ode(t, f, df_diagonal, p_1d_raw, m_pred):
210+
m_pred_no_precon = p_1d_raw[:, None] * m_pred
211+
m_at = m_pred_no_precon[0]
212+
fx = f(t, m_at)
213+
z = m_pred_no_precon[1] - fx
214+
215+
return fx, z
216+
217+
@staticmethod
218+
@jax.jit
219+
def estimate_error(p_1d_raw, sq_bd, z):
220+
221+
sq_bd_no_precon = p_1d_raw[None, :, None] * sq_bd # shape (d,n,n)
222+
sq_bd_no_precon_0 = sq_bd_no_precon[:, 0, :] # shape (d,n)
223+
sq_bd_no_precon_1 = sq_bd_no_precon[:, 1, :] # shape (d,n)
224+
h_sq_bd = sq_bd_no_precon_1 # shape (d,n)
225+
226+
s = jnp.einsum("dn,dn->d", h_sq_bd, h_sq_bd) # shape (d,)
227+
228+
xi = z / jnp.sqrt(s) # shape (d,)
229+
sigma = jnp.abs(xi) # shape (d,)
230+
error_estimate = sigma * jnp.sqrt(s) # shape (d,)
231+
232+
return error_estimate, sigma
233+
234+
@staticmethod
235+
@jax.jit
236+
def observe_cov_sqrtm(p_1d_raw, sc_bd):
237+
238+
sc_bd_no_precon = p_1d_raw[None, :, None] * sc_bd # shape (d,n,n)
239+
sc_bd_no_precon_0 = sc_bd_no_precon[:, 0, :] # shape (d,n)
240+
sc_bd_no_precon_1 = sc_bd_no_precon[:, 1, :] # shape (d,n)
241+
h_sc_bd = sc_bd_no_precon_1 # shape (d,n)
242+
243+
s = jnp.einsum("dn,dn->d", h_sc_bd, h_sc_bd) # shape (d,)
244+
cross = sc_bd @ h_sc_bd[..., None] # shape (d,n,1)
245+
kgain = cross / s[..., None, None] # shape (d,n,1)
246+
247+
return jnp.sqrt(s), kgain
248+
249+
@staticmethod
250+
@jax.jit
251+
def correct_cov_sqrtm(p_1d_raw, sc_bd, kgain):
252+
sc_bd_no_precon = p_1d_raw[None, :, None] * sc_bd # shape (d,n,n)
253+
sc_bd_no_precon_0 = sc_bd_no_precon[:, 0, :] # shape (d,n)
254+
sc_bd_no_precon_1 = sc_bd_no_precon[:, 1, :] # shape (d,n)
255+
h_sc_bd = sc_bd_no_precon_1 # shape (d,n)
256+
kh_sc_bd = kgain @ h_sc_bd[:, None, :] # shape (d,n,n)
257+
new_sc = sc_bd - kh_sc_bd # shape (d,n,n)
258+
return new_sc
259+
260+
@staticmethod
261+
@jax.jit
262+
def correct_mean(m, kgain, z):
263+
correction = kgain @ z[:, None, None] # shape (d,n,1)
264+
new_mean = m - correction[:, :, 0].T # shape (n,d)
265+
return new_mean

0 commit comments

Comments
 (0)