Skip to content

JIT the whole perform_step and everything that is lower-level than that #110

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Sep 21, 2021
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions tests/test_ek0.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import jax
import jax.numpy as jnp
import numpy as np
import pytest
Expand Down Expand Up @@ -29,7 +30,7 @@ def num_derivatives():
@pytest.fixture
def ek0_solution(ek0_version, num_derivatives, ivp, steps):
ek0 = ek0_version(num_derivatives=num_derivatives, steprule=steps)
state, _ = ek0.simulate_final_state(ivp=ivp)
state, info = ek0.simulate_final_state(ivp=ivp)

final_t_ek0 = state.t
final_y_ek0 = state.y.mean[0]
Expand Down Expand Up @@ -112,8 +113,8 @@ def stepped_both(solver_tuple, ivp, initialized_both):
kronecker_ek0, reference_ek0 = solver_tuple
kronecker_init, reference_init = initialized_both

kronecker_stepped, _ = kronecker_ek0.attempt_step(state=kronecker_init, dt=0.12345)
reference_stepped, _ = reference_ek0.attempt_step(state=reference_init, dt=0.12345)
kronecker_stepped, _ = kronecker_ek0.attempt_step(kronecker_init, 0.12345, *ivp)
reference_stepped, _ = reference_ek0.attempt_step(reference_init, 0.12345, *ivp)

return kronecker_stepped, reference_stepped

Expand Down
6 changes: 3 additions & 3 deletions tests/test_ek1.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,11 @@ def approx_initialized(solver_triple):
@pytest.fixture
def approx_stepped(solver_triple, approx_initialized, dt):
"""Attempt a step with the to-be-tested-EK1 and the reference EK1."""
ek1_approx, reference_ek1, _ = solver_triple
ek1_approx, reference_ek1, ivp = solver_triple
init_ref, init_approx = approx_initialized

step_ref, _ = reference_ek1.attempt_step(state=init_ref, dt=dt)
step_approx, _ = ek1_approx.attempt_step(state=init_approx, dt=dt)
step_ref, _ = reference_ek1.attempt_step(init_ref, dt, *ivp)
step_approx, _ = ek1_approx.attempt_step(init_approx, dt, *ivp)

return step_ref, step_approx

Expand Down
12 changes: 0 additions & 12 deletions tests/test_iwp.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,18 +114,6 @@ def test_projection_matrix_num_nonzeros(projection_matrix, iwp):
assert (projection_matrix == 1).sum() == iwp.wiener_process_dimension


# Tests for the projection operator


@pytest.fixture
def projection_operator(iwp):
return iwp.projection_operator_1d(0)


def test_projection_operator(projection_operator):
assert isinstance(projection_operator, tornadox.linops.DerivativeSelection)


def test_reorder_states():
# Transition handles reordering
iwp = tornadox.iwp.IntegratedWienerTransition(
Expand Down
35 changes: 22 additions & 13 deletions tests/test_odefilter.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,16 @@
"""Tests for ODEFilter interfaces."""

from collections import namedtuple

import dataclasses

import jax
import jax.numpy as jnp
import pytest

import tornadox


@dataclasses.dataclass
class EulerState:
ivp: tornadox.ivp.InitialValueProblem
y: jnp.array
t: float
error_estimate: jnp.array
reference_state: jnp.array
class EulerState(namedtuple("_EulerState", "t y error_estimate reference_state")):
pass


class EulerAsODEFilter(tornadox.odefilter.ODEFilter):
Expand All @@ -24,17 +19,17 @@ def initialize(self, ivp):
ivp.y0, cov_sqrtm=jnp.zeros((ivp.y0.shape[0], ivp.y0.shape[0]))
)
return EulerState(
ivp=ivp, y=y, t=ivp.t0, error_estimate=None, reference_state=ivp.y0
y=y, t=ivp.t0, error_estimate=jnp.nan * ivp.y0, reference_state=ivp.y0
)

def attempt_step(self, state, dt):
y = state.y.mean + dt * state.ivp.f(state.t, state.y.mean)
def attempt_step(self, state, dt, f, t0, tmax, y0, df, df_diagonal):
y = state.y.mean + dt * f(state.t, state.y.mean)
t = state.t + dt
y = tornadox.rv.MultivariateNormal(
y, cov_sqrtm=jnp.zeros((y.shape[0], y.shape[0]))
)
new_state = EulerState(
ivp=state.ivp, y=y, t=t, error_estimate=None, reference_state=y
y=y, t=t, error_estimate=jnp.nan * y0, reference_state=y.mean
)
return new_state, {}

Expand Down Expand Up @@ -78,3 +73,17 @@ def locations():
def test_solve_stop_at(ivp, solver, locations):
sol = solver.solve(ivp, stop_at=locations)
assert jnp.isin(locations[0], jnp.array(sol.t))


def test_odefilter_state_jittable(ivp):
def fun(state):
t, y, err, ref = state
return tornadox.odefilter.ODEFilterState(t, y, err, ref)

fun_jitted = jax.jit(fun)
x = jnp.zeros(3)
state = tornadox.odefilter.ODEFilterState(
t=0, y=x, error_estimate=x, reference_state=x
)
out = fun_jitted(state)
assert type(out) == type(state)
26 changes: 3 additions & 23 deletions tests/test_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test_propose_first_dt():

ivp = tornadox.ivp.vanderpol()

dt = tornadox.step.propose_first_dt(ivp)
dt = tornadox.step.propose_first_dt(ivp.f, ivp.t0, ivp.y0)
assert dt > 0


Expand Down Expand Up @@ -48,7 +48,7 @@ def test_error_estimate_is_none(steprule):

@staticmethod
def test_first_dt_is_dt(steprule, ivp, dt):
first_dt = steprule.first_dt(ivp=ivp)
first_dt = steprule.first_dt(*ivp)
assert first_dt == dt


Expand Down Expand Up @@ -123,27 +123,7 @@ def test_scale_error_estimate_2d(steprule, abstol, reltol):
) / jnp.sqrt(2)
assert jnp.allclose(E, scaled_error)

@staticmethod
def test_min_step_exception(steprule):
steprule.min_step = 0.1
with pytest.raises(ValueError):
steprule.suggest(
previous_dt=1e-1,
scaled_error_estimate=1_000_000_000,
local_convergence_rate=1,
)

@staticmethod
def test_max_step_exception(steprule):
steprule.max_step = 10.0
with pytest.raises(ValueError):
steprule.suggest(
previous_dt=9.0,
scaled_error_estimate=1 / 1_000_000_000,
local_convergence_rate=1,
Comment on lines -127 to -143
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this functionality is removed. I left a comment in the perform_step, because maybe we want something similar back. THe current stuff was just unjittable...

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

who needs a safetynet

)

@staticmethod
def test_first_dt(steprule, ivp):
dt = steprule.first_dt(ivp)
dt = steprule.first_dt(*ivp)
assert dt > 0.0
54 changes: 54 additions & 0 deletions timings/diagonal_ek1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import celluloid
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oops, I had this lying aroudn somewhere apparently... well, it is in timings/ so why not leave it in :)

import jax.numpy as jnp
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.colors import LogNorm

import tornadox

bruss = tornadox.ivp.brusselator(N=100)
tol = 1e-8
ek1 = tornadox.ek1.ReferenceEK1(
steprule=tornadox.step.AdaptiveSteps(abstol=1e-2 * tol, reltol=tol),
initialization=tornadox.init.RungeKutta(method="Radau"),
num_derivatives=4,
)

ms = []
ts = []

fig, _ = plt.subplots()

camera = celluloid.Camera(fig)

for state, _ in ek1.solution_generator(bruss):
ms.append(state.y.mean[0])
ts.append(state.t)

t = state.t

plt.title("LogScale(vmin=1e-10, vmax=1e0, clip=True)")
plt.imshow(
state.y.cov, norm=LogNorm(vmin=1e-10, vmax=1e0, clip="True"), cmap=cm.Greys
)
camera.snap()

animation = camera.animate()
animation.save("animation.mp4")


means = jnp.stack(ms)
ts = jnp.stack(ts)

plt.subplots(dpi=200)
plt.title(len(ts))
plt.plot(ts, means)
# plt.show()
#
# fig, ax = plt.subplots(dpi=200)
# stride = 2
# for t, cov in zip(ts[::stride], covariances[::stride]):
# plt.title(t)
# ax.imshow(cov)
# plt.pause(0.001)
# plt.show()
25 changes: 25 additions & 0 deletions timings/information.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import jax.numpy as jnp
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here as above... sorry

import matplotlib.pyplot as plt

import tornadox

vdp = tornadox.ivp.vanderpol(stiffness_constant=0.1)
tol = 1e-4
ek1 = tornadox.ek1.InformationEK1(
steprule=tornadox.step.AdaptiveSteps(abstol=tol, reltol=tol),
initialization=tornadox.init.RungeKutta(method="Radau"),
num_derivatives=2,
)

ms = []
ts = []
for state, _ in ek1.solution_generator(vdp):
ms.append(state.y.mean()[0])
ts.append(state.t)

means = jnp.stack(ms)
ts = jnp.stack(ts)

plt.subplots(dpi=200)
plt.plot(ts, means)
plt.show()
26 changes: 11 additions & 15 deletions tornadox/ek0.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import dataclasses
from functools import partial

import jax
import jax.numpy as jnp
import jax.scipy.linalg

import tornadox.iwp
from tornadox import init, ivp, iwp, odefilter, rv, sqrt, step
from tornadox import odefilter, rv, sqrt
from tornadox.ek1 import BatchedEK1


Expand Down Expand Up @@ -38,24 +37,24 @@ def initialize(self, ivp):
mean=extended_dy0, cov_sqrtm=jnp.kron(jnp.eye(ivp.dimension), cov_sqrtm)
)
return odefilter.ODEFilterState(
ivp=ivp,
t=ivp.t0,
y=y,
error_estimate=None,
reference_state=None,
error_estimate=jnp.nan * jnp.ones(self.iwp.wiener_process_dimension),
reference_state=jnp.nan * jnp.ones(self.iwp.wiener_process_dimension),
)

def attempt_step(self, state, dt, verbose=False):
@partial(jax.jit, static_argnums=(0, 3, 7, 8))
def attempt_step(self, state, dt, f, t0, tmax, y0, df, df_diagonal):
# [Setup]
m, Cl = state.y.mean.reshape((-1,), order="F"), state.y.cov_sqrtm
A, Ql = self.iwp.non_preconditioned_discretize(dt)
n, d = self.num_derivatives + 1, state.ivp.dimension
n, d = self.num_derivatives + 1, self.iwp.wiener_process_dimension

# [Predict]
mp = A @ m

# Measure / calibrate
z = self.E1 @ mp - state.ivp.f(state.t + dt, self.E0 @ mp)
z = self.E1 @ mp - f(state.t + dt, self.E0 @ mp)
H = self.E1

S = H @ Ql @ Ql.T @ H.T
Expand All @@ -73,7 +72,6 @@ def attempt_step(self, state, dt, verbose=False):
y_new = jnp.abs(m_new[0])

new_state = odefilter.ODEFilterState(
ivp=state.ivp,
t=state.t + dt,
error_estimate=error,
reference_state=y_new,
Expand Down Expand Up @@ -115,9 +113,8 @@ def initialize(self, ivp):
y = rv.MatrixNormal(mean=mean, cov_sqrtm_1=jnp.eye(d), cov_sqrtm_2=cov_sqrtm)

return odefilter.ODEFilterState(
ivp=ivp,
t=ivp.t0,
error_estimate=None,
error_estimate=jnp.nan,
reference_state=ivp.y0,
y=y,
)
Expand Down Expand Up @@ -151,12 +148,12 @@ def evaluate_ode(t, f, mp, P, e1):
H = e1 @ P
return z, H

def attempt_step(self, state, dt, verbose=False):
@partial(jax.jit, static_argnums=(0, 3, 7, 8))
def attempt_step(self, state, dt, f, t0, tmax, y0, df, df_diagonal):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we pass all the unused stuff here, just such that we can call attempt_step(..., *ivp)? If so, I vote in favor for a long-but-explicit alternative (i.e., pass (keyword-)args explicitly) to avoid annoying bugs

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we do it some other time and merge this in now? There are a few PRs in the pipeline now...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure 👍🏼

# [Setup]
Y = state.y
_m, _Cl = Y.mean, Y.cov_sqrtm_2
A, Ql = self.A, self.Ql

t_new = state.t + dt

# [Preconditioners]
Expand All @@ -168,7 +165,7 @@ def attempt_step(self, state, dt, verbose=False):
mp = A @ m

# [Measure]
z, H = self.evaluate_ode(t_new, state.ivp.f, mp, P, self.e1)
z, H = self.evaluate_ode(t_new, f, mp, P, self.e1)

# [Calibration]
sigma_squared, error_estimate = self.compute_sigmasquared_error(P, Ql, z)
Expand All @@ -186,7 +183,6 @@ def attempt_step(self, state, dt, verbose=False):
y_new = jnp.abs(_m_new[0])

new_state = odefilter.ODEFilterState(
ivp=state.ivp,
t=t_new,
error_estimate=error_estimate,
reference_state=y_new,
Expand Down
Loading