Skip to content

EK1 reference implementation #29

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Aug 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions tests/test_ek1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Tests for the EK1 implementation."""

import jax.numpy as jnp
from scipy.integrate import solve_ivp

import tornado


def test_reference_ek1_constant_steps():
"""Assert the reference solver returns a similar solution to SciPy.

As long as this test passes, we can test the more efficient solvers against this one here.
"""

ivp = tornado.ivp.vanderpol(t0=0.0, tmax=0.5, stiffness_constant=1.0)
scipy_sol = solve_ivp(ivp.f, t_span=(ivp.t0, ivp.tmax), y0=ivp.y0)
final_t_scipy = scipy_sol.t[-1]
final_y_scipy = scipy_sol.y[:, -1]

dt = jnp.mean(jnp.diff(scipy_sol.t))

steps = tornado.step.ConstantSteps(dt)
ek1 = tornado.ek1.ReferenceEK1(num_derivatives=4, ode_dimension=2, steprule=steps)
sol_gen = ek1.solution_generator(ivp=ivp)
for state in sol_gen:
pass

final_t_ek1 = state.t
final_y_ek1 = ek1.P0 @ state.y.mean
assert jnp.allclose(final_t_scipy, final_t_ek1)
assert jnp.allclose(final_y_scipy, final_y_ek1, rtol=1e-3, atol=1e-3)
2 changes: 1 addition & 1 deletion tornado/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@

from jax.config import config

from . import ivp, iwp, odesolver, rv, sqrt, step, taylor_mode
from . import ek1, ivp, iwp, odesolver, rv, sqrt, step, taylor_mode

config.update("jax_enable_x64", True)
79 changes: 79 additions & 0 deletions tornado/ek1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
"""EK1 solvers."""


import dataclasses

import jax.numpy as jnp

from tornado import ivp, iwp, odesolver, rv, sqrt, taylor_mode


@dataclasses.dataclass
class ODEFilterState:

ivp: "tornado.ivp.InitialValueProblem"
t: float
y: "rv.MultivariateNormal"
error_estimate: jnp.ndarray
reference_state: jnp.ndarray


class ReferenceEK1(odesolver.ODESolver):
"""Naive, reference EK1 implementation. Use this to test against."""

def __init__(self, num_derivatives, ode_dimension, steprule):
super().__init__(steprule=steprule, solver_order=num_derivatives)

# Prior integrated Wiener process
self.iwp = iwp.IntegratedWienerTransition(
num_derivatives=num_derivatives, wiener_process_dimension=ode_dimension
)
self.P0 = self.iwp.make_projection_matrix(0)
self.P1 = self.iwp.make_projection_matrix(1)

# Initialization strategy
self.tm = taylor_mode.TaylorModeInitialization()

def initialize(self, ivp):
initial_rv = self.tm(ivp=ivp, prior=self.iwp)
return ODEFilterState(
ivp=ivp,
t=ivp.t0,
y=initial_rv,
error_estimate=None,
reference_state=None,
)

def attempt_step(self, state, dt):
# Extract system matrices
m, SC = state.y.mean, state.y.cov_cholesky
A, SQ = self.iwp.non_preconditioned_discretize(dt)

# Prediction
m_pred = A @ m
SC_pred = sqrt.propagate_cholesky_factor(A @ SC, SQ)

# Evaluate ODE
t = state.t + dt
m_at = self.P0 @ m_pred
f = state.ivp.f(t, m_at)
J = state.ivp.df(t, m_at)

# Create linearisation
H = self.P1 - J @ self.P0
b = J @ m_at - f

# Update
cov_cholesky, Kgain, sqrt_S = sqrt.update_sqrt(H, SC_pred)
z = H @ m_pred + b
new_mean = m_pred - Kgain @ z
new_rv = rv.MultivariateNormal(new_mean, cov_cholesky)

# Return new state
return ODEFilterState(
ivp=state.ivp,
t=t,
y=new_rv,
error_estimate=None,
reference_state=None,
)