JIT the whole perform_step and everything that is lower-level than that #110
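The recurring pattern in the diffs below is JAX's `jax.jit` combined with `static_argnums`: arguments that are Python objects or callables (`self`, the ODE right-hand side `f`, and the Jacobian-related callables `df`, `df_diagonal`) are declared static and treated as compile-time constants, while the array/pytree arguments (`state`, `dt`, `t0`, `tmax`, `y0`) are traced. A minimal, self-contained sketch of the same decorator pattern (class and function names here are illustrative, not tornadox API):

```python
from functools import partial

import jax
import jax.numpy as jnp


class ToySolver:
    # argnum 0 is `self`, argnum 3 is the callable `f`; both must be static,
    # because jax.jit can only trace array-like (pytree-of-array) arguments.
    @partial(jax.jit, static_argnums=(0, 3))
    def attempt_step(self, y, dt, f):
        # Plain array math gets traced; `f` is baked into the compiled code.
        return y + dt * f(y)


solver = ToySolver()
y_new = solver.attempt_step(jnp.ones(3), 0.1, jnp.sin)  # compiles on first call
```

Retracing is triggered whenever a static argument changes (by hash/equality), which is why only `self` and the ODE callables are marked static while everything step-dependent stays traced.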
Changed file: step-rule tests
@@ -10,7 +10,7 @@ def test_propose_first_dt():

     ivp = tornadox.ivp.vanderpol()

-    dt = tornadox.step.propose_first_dt(ivp)
+    dt = tornadox.step.propose_first_dt(ivp.f, ivp.t0, ivp.y0)
     assert dt > 0
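The call-site change above reflects the new signature of `tornadox.step.propose_first_dt`, which now takes only the pieces it actually needs (the ODE right-hand side, the initial time, and the initial value) instead of the whole IVP object, presumably to keep its inputs JIT-friendly in line with the PR's goal. Usage, exactly as the updated test exercises it:

```python
import tornadox

ivp = tornadox.ivp.vanderpol()
dt = tornadox.step.propose_first_dt(ivp.f, ivp.t0, ivp.y0)
assert dt > 0
```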
@@ -48,7 +48,7 @@ def test_error_estimate_is_none(steprule):

     @staticmethod
     def test_first_dt_is_dt(steprule, ivp, dt):
-        first_dt = steprule.first_dt(ivp=ivp)
+        first_dt = steprule.first_dt(*ivp)
         assert first_dt == dt
@@ -123,27 +123,7 @@ def test_scale_error_estimate_2d(steprule, abstol, reltol):
         ) / jnp.sqrt(2)
         assert jnp.allclose(E, scaled_error)

-    @staticmethod
-    def test_min_step_exception(steprule):
-        steprule.min_step = 0.1
-        with pytest.raises(ValueError):
-            steprule.suggest(
-                previous_dt=1e-1,
-                scaled_error_estimate=1_000_000_000,
-                local_convergence_rate=1,
-            )
-
-    @staticmethod
-    def test_max_step_exception(steprule):
-        steprule.max_step = 10.0
-        with pytest.raises(ValueError):
-            steprule.suggest(
-                previous_dt=9.0,
-                scaled_error_estimate=1 / 1_000_000_000,
-                local_convergence_rate=1,
-            )
-
     @staticmethod
     def test_first_dt(steprule, ivp):
-        dt = steprule.first_dt(ivp)
+        dt = steprule.first_dt(*ivp)
         assert dt > 0.0

Comment on lines -127 to -143: "This functionality is removed. I left a comment in perform_step, because maybe we want something similar back. The current stuff was just unjittable..." — Reply: "Who needs a safety net."
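The deleted `test_min_step_exception` / `test_max_step_exception` relied on `suggest` raising a `ValueError`, and raising on a traced comparison is exactly the kind of Python control flow that `jax.jit` cannot stage out, hence the removal noted in the comment above. If a safety net is wanted back later, a jittable variant would have to express the bounds as array arithmetic rather than exceptions; one possible sketch (illustrative only, not part of this PR):

```python
import jax
import jax.numpy as jnp


@jax.jit
def clamp_dt(dt_suggested, min_step, max_step):
    # jnp.clip is traceable; `raise ValueError` on a traced value is not.
    return jnp.clip(dt_suggested, min_step, max_step)


clamp_dt(jnp.asarray(1e9), 1e-12, 10.0)  # -> 10.0
```

Alternatively, `jax.experimental.checkify` can turn such runtime checks into functional error values that survive `jit`.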
New file: Brusselator covariance-animation script (timings/)
@@ -0,0 +1,54 @@ (new file; all lines added)

import celluloid

import jax.numpy as jnp
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.colors import LogNorm

import tornadox

bruss = tornadox.ivp.brusselator(N=100)
tol = 1e-8
ek1 = tornadox.ek1.ReferenceEK1(
    steprule=tornadox.step.AdaptiveSteps(abstol=1e-2 * tol, reltol=tol),
    initialization=tornadox.init.RungeKutta(method="Radau"),
    num_derivatives=4,
)

ms = []
ts = []

fig, _ = plt.subplots()

camera = celluloid.Camera(fig)

for state, _ in ek1.solution_generator(bruss):
    ms.append(state.y.mean[0])
    ts.append(state.t)

    t = state.t

    plt.title("LogScale(vmin=1e-10, vmax=1e0, clip=True)")
    plt.imshow(
        state.y.cov, norm=LogNorm(vmin=1e-10, vmax=1e0, clip="True"), cmap=cm.Greys
    )
    camera.snap()

animation = camera.animate()
animation.save("animation.mp4")


means = jnp.stack(ms)
ts = jnp.stack(ts)

plt.subplots(dpi=200)
plt.title(len(ts))
plt.plot(ts, means)
# plt.show()
#
# fig, ax = plt.subplots(dpi=200)
# stride = 2
# for t, cov in zip(ts[::stride], covariances[::stride]):
#     plt.title(t)
#     ax.imshow(cov)
#     plt.pause(0.001)
# plt.show()

Comment on the new file: "Oops, I had this lying around somewhere apparently... well, it is in timings/, so why not leave it in :)"
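The script above renders a log-scaled heatmap of the solver covariance at each state yielded by `solution_generator` and stitches the frames into `animation.mp4` via `celluloid`. One small aside: `LogNorm(..., clip="True")` passes a string where matplotlib expects a bool; it behaves like `True` only because any non-empty string is truthy, and the intended call is presumably

```python
from matplotlib.colors import LogNorm

norm = LogNorm(vmin=1e-10, vmax=1e0, clip=True)
```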
New file: Van der Pol timing script (timings/)
@@ -0,0 +1,25 @@ (new file; all lines added)

import jax.numpy as jnp

import matplotlib.pyplot as plt

import tornadox

vdp = tornadox.ivp.vanderpol(stiffness_constant=0.1)
tol = 1e-4
ek1 = tornadox.ek1.InformationEK1(
    steprule=tornadox.step.AdaptiveSteps(abstol=tol, reltol=tol),
    initialization=tornadox.init.RungeKutta(method="Radau"),
    num_derivatives=2,
)

ms = []
ts = []
for state, _ in ek1.solution_generator(vdp):
    ms.append(state.y.mean()[0])
    ts.append(state.t)

means = jnp.stack(ms)
ts = jnp.stack(ts)

plt.subplots(dpi=200)
plt.plot(ts, means)
plt.show()

Comment on the new file: "Same here as above... sorry."
Changed file: EK1 implementation
@@ -1,12 +1,11 @@
 import dataclasses
 from functools import partial

 import jax
 import jax.numpy as jnp
 import jax.scipy.linalg

-import tornadox.iwp
-from tornadox import init, ivp, iwp, odefilter, rv, sqrt, step
+from tornadox import odefilter, rv, sqrt
 from tornadox.ek1 import BatchedEK1
@@ -38,24 +37,24 @@ def initialize(self, ivp):
             mean=extended_dy0, cov_sqrtm=jnp.kron(jnp.eye(ivp.dimension), cov_sqrtm)
         )
         return odefilter.ODEFilterState(
-            ivp=ivp,
             t=ivp.t0,
             y=y,
-            error_estimate=None,
-            reference_state=None,
+            error_estimate=jnp.nan * jnp.ones(self.iwp.wiener_process_dimension),
+            reference_state=jnp.nan * jnp.ones(self.iwp.wiener_process_dimension),
         )

-    def attempt_step(self, state, dt, verbose=False):
+    @partial(jax.jit, static_argnums=(0, 3, 7, 8))
+    def attempt_step(self, state, dt, f, t0, tmax, y0, df, df_diagonal):
         # [Setup]
         m, Cl = state.y.mean.reshape((-1,), order="F"), state.y.cov_sqrtm
         A, Ql = self.iwp.non_preconditioned_discretize(dt)
-        n, d = self.num_derivatives + 1, state.ivp.dimension
+        n, d = self.num_derivatives + 1, self.iwp.wiener_process_dimension

         # [Predict]
         mp = A @ m

         # Measure / calibrate
-        z = self.E1 @ mp - state.ivp.f(state.t + dt, self.E0 @ mp)
+        z = self.E1 @ mp - f(state.t + dt, self.E0 @ mp)
         H = self.E1

         S = H @ Ql @ Ql.T @ H.T
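The decorator in the hunk above follows the static-argument pattern sketched near the top of the PR (`self`, `f`, `df`, `df_diagonal` static; `state`, `dt`, `t0`, `tmax`, `y0` traced). The other change is that `ODEFilterState` no longer carries the IVP object, and fields that used to be `None` before the first step become NaN arrays of fixed shape. This is presumably because a jitted step function (and any `lax.cond`/`lax.while_loop`-style control flow built on top of it) requires the state pytree to keep the same structure, shape, and dtype on every iteration, whereas switching between `None` and an array changes the pytree structure. A tiny illustration of that constraint (not tornadox code; `d` is arbitrary):

```python
import jax
import jax.numpy as jnp

d = 3
error_estimate = jnp.nan * jnp.ones(d)  # fixed-shape "not available yet" marker

# Both branches of lax.cond must return the same pytree structure and shapes;
# `None` in one branch and an array in the other would be rejected.
updated = jax.lax.cond(
    True,
    lambda e: jnp.zeros_like(e),  # pretend an estimate was computed
    lambda e: e,                  # keep the NaN placeholder
    error_estimate,
)
```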
@@ -73,7 +72,6 @@ def attempt_step(self, state, dt, verbose=False):
         y_new = jnp.abs(m_new[0])

         new_state = odefilter.ODEFilterState(
-            ivp=state.ivp,
             t=state.t + dt,
             error_estimate=error,
             reference_state=y_new,
@@ -115,9 +113,8 @@ def initialize(self, ivp):
         y = rv.MatrixNormal(mean=mean, cov_sqrtm_1=jnp.eye(d), cov_sqrtm_2=cov_sqrtm)

         return odefilter.ODEFilterState(
-            ivp=ivp,
             t=ivp.t0,
-            error_estimate=None,
+            error_estimate=jnp.nan,
             reference_state=ivp.y0,
             y=y,
         )
@@ -151,12 +148,12 @@ def evaluate_ode(t, f, mp, P, e1):
         H = e1 @ P
         return z, H

-    def attempt_step(self, state, dt, verbose=False):
+    @partial(jax.jit, static_argnums=(0, 3, 7, 8))
+    def attempt_step(self, state, dt, f, t0, tmax, y0, df, df_diagonal):
         # [Setup]
         Y = state.y
         _m, _Cl = Y.mean, Y.cov_sqrtm_2
         A, Ql = self.A, self.Ql

         t_new = state.t + dt

         # [Preconditioners]

Comment on the new attempt_step signature: "Do we pass all the unused stuff here, just such that we can call" — Reply: "Can we do it some other time and merge this in now? There are a few PRs in the pipeline now..." — Reply: "Sure 👍🏼"
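Regarding the thread above about threading unused arguments through `attempt_step`: one alternative (deferred per the replies) would be to bind the problem-specific pieces once with a closure or `functools.partial` and jit the result, keeping the per-step signature small. A rough sketch of that idea (illustrative only, not what this PR implements):

```python
import jax
import jax.numpy as jnp


def make_attempt_step(f):
    # `f` is closed over instead of being threaded through as a static
    # argument on every call.
    @jax.jit
    def attempt_step(y, dt):
        return y + dt * f(y)  # stand-in for the real predict/update math

    return attempt_step


step = make_attempt_step(jnp.cos)
y1 = step(jnp.ones(2), 0.1)
```

Whether that reads better than a uniform-but-wide signature is the trade-off the thread above defers to a later PR.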
@@ -168,7 +165,7 @@ def attempt_step(self, state, dt, verbose=False):
         mp = A @ m

         # [Measure]
-        z, H = self.evaluate_ode(t_new, state.ivp.f, mp, P, self.e1)
+        z, H = self.evaluate_ode(t_new, f, mp, P, self.e1)

         # [Calibration]
         sigma_squared, error_estimate = self.compute_sigmasquared_error(P, Ql, z)
@@ -186,7 +183,6 @@ def attempt_step(self, state, dt, verbose=False):
         y_new = jnp.abs(_m_new[0])

         new_state = odefilter.ODEFilterState(
-            ivp=state.ivp,
             t=t_new,
             error_estimate=error_estimate,
             reference_state=y_new,