|
1 | 1 | import dataclasses
|
| 2 | +from functools import partial |
2 | 3 |
|
| 4 | +import jax |
3 | 5 | import jax.numpy as jnp
|
| 6 | +import jax.scipy.linalg |
4 | 7 |
|
5 | 8 | import tornadox.iwp
|
6 | 9 | from tornadox import init, ivp, iwp, odefilter, rv, sqrt, step
|
| 10 | +from tornadox.ek1 import BatchedEK1 |
7 | 11 |
|
8 | 12 |
|
9 | 13 | class ReferenceEK0(odefilter.ODEFilter):
|
@@ -173,3 +177,89 @@ def attempt_step(self, state, dt, verbose=False):
|
173 | 177 | )
|
174 | 178 | info_dict = dict(num_f_evaluations=1)
|
175 | 179 | return new_state, info_dict
|
| 180 | + |
| 181 | + |
| 182 | +class DiagonalEK0(BatchedEK1): |
| 183 | + @partial(jax.jit, static_argnums=(0, 1, 2, 3)) |
| 184 | + def attempt_unit_step(self, f, df, df_diagonal, p_1d_raw, m, sc, t): |
| 185 | + m_pred = self.predict_mean(m, phi_1d=self.phi_1d) |
| 186 | + f, z = self.evaluate_ode( |
| 187 | + t=t, f=f, df_diagonal=df_diagonal, p_1d_raw=p_1d_raw, m_pred=m_pred |
| 188 | + ) |
| 189 | + error, sigma = self.estimate_error( |
| 190 | + p_1d_raw=p_1d_raw, |
| 191 | + sq_bd=self.batched_sq, |
| 192 | + z=z, |
| 193 | + ) |
| 194 | + sc_pred = self.predict_cov_sqrtm( |
| 195 | + sc_bd=sc, phi_1d=self.phi_1d, sq_bd=sigma[:, None, None] * self.batched_sq |
| 196 | + ) |
| 197 | + ss, kgain = self.observe_cov_sqrtm(p_1d_raw=p_1d_raw, sc_bd=sc_pred) |
| 198 | + cov_sqrtm = self.correct_cov_sqrtm( |
| 199 | + p_1d_raw=p_1d_raw, |
| 200 | + sc_bd=sc_pred, |
| 201 | + kgain=kgain, |
| 202 | + ) |
| 203 | + new_mean = self.correct_mean(m=m_pred, kgain=kgain, z=z) |
| 204 | + info_dict = dict(num_f_evaluations=1) |
| 205 | + return new_mean, cov_sqrtm, error, info_dict |
| 206 | + |
| 207 | + @staticmethod |
| 208 | + @partial(jax.jit, static_argnums=(1, 2)) |
| 209 | + def evaluate_ode(t, f, df_diagonal, p_1d_raw, m_pred): |
| 210 | + m_pred_no_precon = p_1d_raw[:, None] * m_pred |
| 211 | + m_at = m_pred_no_precon[0] |
| 212 | + fx = f(t, m_at) |
| 213 | + z = m_pred_no_precon[1] - fx |
| 214 | + |
| 215 | + return fx, z |
| 216 | + |
| 217 | + @staticmethod |
| 218 | + @jax.jit |
| 219 | + def estimate_error(p_1d_raw, sq_bd, z): |
| 220 | + |
| 221 | + sq_bd_no_precon = p_1d_raw[None, :, None] * sq_bd # shape (d,n,n) |
| 222 | + sq_bd_no_precon_0 = sq_bd_no_precon[:, 0, :] # shape (d,n) |
| 223 | + sq_bd_no_precon_1 = sq_bd_no_precon[:, 1, :] # shape (d,n) |
| 224 | + h_sq_bd = sq_bd_no_precon_1 # shape (d,n) |
| 225 | + |
| 226 | + s = jnp.einsum("dn,dn->d", h_sq_bd, h_sq_bd) # shape (d,) |
| 227 | + |
| 228 | + xi = z / jnp.sqrt(s) # shape (d,) |
| 229 | + sigma = jnp.abs(xi) # shape (d,) |
| 230 | + error_estimate = sigma * jnp.sqrt(s) # shape (d,) |
| 231 | + |
| 232 | + return error_estimate, sigma |
| 233 | + |
| 234 | + @staticmethod |
| 235 | + @jax.jit |
| 236 | + def observe_cov_sqrtm(p_1d_raw, sc_bd): |
| 237 | + |
| 238 | + sc_bd_no_precon = p_1d_raw[None, :, None] * sc_bd # shape (d,n,n) |
| 239 | + sc_bd_no_precon_0 = sc_bd_no_precon[:, 0, :] # shape (d,n) |
| 240 | + sc_bd_no_precon_1 = sc_bd_no_precon[:, 1, :] # shape (d,n) |
| 241 | + h_sc_bd = sc_bd_no_precon_1 # shape (d,n) |
| 242 | + |
| 243 | + s = jnp.einsum("dn,dn->d", h_sc_bd, h_sc_bd) # shape (d,) |
| 244 | + cross = sc_bd @ h_sc_bd[..., None] # shape (d,n,1) |
| 245 | + kgain = cross / s[..., None, None] # shape (d,n,1) |
| 246 | + |
| 247 | + return jnp.sqrt(s), kgain |
| 248 | + |
| 249 | + @staticmethod |
| 250 | + @jax.jit |
| 251 | + def correct_cov_sqrtm(p_1d_raw, sc_bd, kgain): |
| 252 | + sc_bd_no_precon = p_1d_raw[None, :, None] * sc_bd # shape (d,n,n) |
| 253 | + sc_bd_no_precon_0 = sc_bd_no_precon[:, 0, :] # shape (d,n) |
| 254 | + sc_bd_no_precon_1 = sc_bd_no_precon[:, 1, :] # shape (d,n) |
| 255 | + h_sc_bd = sc_bd_no_precon_1 # shape (d,n) |
| 256 | + kh_sc_bd = kgain @ h_sc_bd[:, None, :] # shape (d,n,n) |
| 257 | + new_sc = sc_bd - kh_sc_bd # shape (d,n,n) |
| 258 | + return new_sc |
| 259 | + |
| 260 | + @staticmethod |
| 261 | + @jax.jit |
| 262 | + def correct_mean(m, kgain, z): |
| 263 | + correction = kgain @ z[:, None, None] # shape (d,n,1) |
| 264 | + new_mean = m - correction[:, :, 0].T # shape (n,d) |
| 265 | + return new_mean |
0 commit comments