Skip to content

Removed ode_dimension from ODEFilter constructor #70

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions tests/test_ek0.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def num_derivatives():

@pytest.fixture
def ek0_solution(ek0_version, num_derivatives, ivp, steps):
ek0 = ek0_version(num_derivatives=num_derivatives, ode_dimension=2, steprule=steps)
ek0 = ek0_version(num_derivatives=num_derivatives, steprule=steps)
sol_gen = ek0.solution_generator(ivp=ivp)
for state in sol_gen:
if state.t > ivp.t0:
Expand Down Expand Up @@ -78,10 +78,10 @@ def test_full_solve_compare_scipy(ek0_solution, scipy_solution):
@pytest.fixture
def solver_tuple(steps, num_derivatives, d):
reference_ek0 = tornado.ek0.ReferenceEK0(
num_derivatives=num_derivatives, ode_dimension=d, steprule=steps
num_derivatives=num_derivatives, steprule=steps
)
kronecker_ek0 = tornado.ek0.KroneckerEK0(
num_derivatives=num_derivatives, ode_dimension=d, steprule=steps
num_derivatives=num_derivatives, steprule=steps
)

return kronecker_ek0, reference_ek0
Expand Down
8 changes: 3 additions & 5 deletions tests/test_ek1.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_full_solve_compare_scipy(
"""Assert the ODEFilter solves an ODE appropriately."""
final_t_scipy, final_y_scipy = scipy_solution

ek1 = ek1_version(num_derivatives=num_derivatives, ode_dimension=2, steprule=steps)
ek1 = ek1_version(num_derivatives=num_derivatives, steprule=steps)
sol_gen = ek1.solution_generator(ivp=ivp)
for state in sol_gen:
if state.t > ivp.t0:
Expand Down Expand Up @@ -111,10 +111,8 @@ def solver_triple(ivp, steps, num_derivatives, approx_solver):
)

d, n = ivp.dimension, num_derivatives
reference_ek1 = tornado.ek1.ReferenceEK1(
num_derivatives=n, ode_dimension=d, steprule=steps
)
ek1_approx = approx_solver(num_derivatives=n, ode_dimension=d, steprule=steps)
reference_ek1 = tornado.ek1.ReferenceEK1(num_derivatives=n, steprule=steps)
ek1_approx = approx_solver(num_derivatives=n, steprule=steps)

return ek1_approx, reference_ek1, ivp

Expand Down
1 change: 0 additions & 1 deletion tests/test_odesolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ def test_odesolver():
constant_steps = tornado.step.ConstantSteps(dt=0.1)
solver_order = 2
solver = EulerAsODEFilter(
ode_dimension=ivp.dimension,
steprule=constant_steps,
num_derivatives=solver_order,
)
Expand Down
26 changes: 25 additions & 1 deletion tornado/ek0.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,24 @@

import jax.numpy as jnp

import tornado.iwp
from tornado import init, ivp, iwp, odesolver, rv, sqrt, step


class ReferenceEK0(odesolver.ODEFilter):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.P0 = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are these lines necessary?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not for any functionality, it is rather a stylistic convention to not introduce attributes outside the init.

self.E0 = None
self.E1 = None

def initialize(self, ivp):

self.iwp = tornado.iwp.IntegratedWienerTransition(
num_derivatives=self.num_derivatives,
wiener_process_dimension=ivp.dimension,
)

self.P0 = self.E0 = self.iwp.projection_matrix(0)
self.E1 = self.iwp.projection_matrix(1)

Expand Down Expand Up @@ -65,8 +78,19 @@ def attempt_step(self, state, dt, verbose=False):


class KroneckerEK0(odesolver.ODEFilter):
def initialize(self, ivp):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.P0 = None
self.A = None
self.Ql = None
self.e0 = None
self.e1 = None

def initialize(self, ivp):
self.iwp = tornado.iwp.IntegratedWienerTransition(
num_derivatives=self.num_derivatives,
wiener_process_dimension=ivp.dimension,
)
self.A, self.Ql = self.iwp.preconditioned_discretize_1d

extended_dy0, cov_sqrtm = self.init(
Expand Down
52 changes: 32 additions & 20 deletions tornado/ek1.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,20 @@
class ReferenceEK1(odesolver.ODEFilter):
"""Naive, reference EK1 implementation. Use this to test against."""

def __init__(self, ode_dimension, steprule, num_derivatives):
super().__init__(
ode_dimension=ode_dimension,
steprule=steprule,
num_derivatives=num_derivatives,
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.P0 = None
self.P1 = None

def initialize(self, ivp):

self.iwp = iwp.IntegratedWienerTransition(
num_derivatives=self.num_derivatives,
wiener_process_dimension=ivp.dimension,
)
self.P0 = self.iwp.projection_matrix(0)
self.P1 = self.iwp.projection_matrix(1)

def initialize(self, ivp):
extended_dy0, cov_sqrtm = self.init(
f=ivp.f,
df=ivp.df,
Expand Down Expand Up @@ -126,21 +130,25 @@ def estimate_error(h, sq, z):
class BatchedEK1(odesolver.ODEFilter):
"""Common functionality for EK1 variations that act on batched multivariate normals."""

def __init__(self, num_derivatives, ode_dimension, steprule):
super().__init__(
ode_dimension=ode_dimension,
steprule=steprule,
num_derivatives=num_derivatives,
)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.phi_1d = None
self.sq_1d = None
self.batched_sq = None

def initialize(self, ivp):

self.iwp = iwp.IntegratedWienerTransition(
num_derivatives=self.num_derivatives,
wiener_process_dimension=ivp.dimension,
)
d = self.iwp.wiener_process_dimension
self.phi_1d, self.sq_1d = self.iwp.preconditioned_discretize_1d

# No broadcasting possible here (ad-hoc, that is) bc. jax.vmap expects matching batch sizes
# This can be solved by batching propagate_cholesky_factor differently, but maybe this is not necessary
self.batched_sq = jnp.stack([self.sq_1d] * d)

def initialize(self, ivp):
extended_dy0, cov_sqrtm = self.init(
f=ivp.f,
df=ivp.df,
Expand Down Expand Up @@ -421,21 +429,25 @@ class EarlyTruncationEK1(odesolver.ODEFilter):
(This also means that for the covariance update, we use the inverse of the diagonal of S, not the diagonal of the inverse of S.)
"""

def __init__(self, num_derivatives, ode_dimension, steprule):
super().__init__(
ode_dimension=ode_dimension,
steprule=steprule,
num_derivatives=num_derivatives,
)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.P0_1d = None
self.P1_1d = None
self.P0 = None
self.P1 = None

def initialize(self, ivp):

self.iwp = iwp.IntegratedWienerTransition(
num_derivatives=self.num_derivatives, wiener_process_dimension=ivp.dimension
)
self.P0_1d = self.iwp.projection_matrix_1d(0)
self.P1_1d = self.iwp.projection_matrix_1d(1)

d = self.iwp.wiener_process_dimension
self.P0 = linops.BlockDiagonal(jnp.stack([self.P0_1d] * d))
self.P1 = linops.BlockDiagonal(jnp.stack([self.P1_1d] * d))

def initialize(self, ivp):
extended_dy0, cov_sqrtm = self.init(
f=ivp.f,
df=ivp.df,
Expand Down
1 change: 0 additions & 1 deletion tornado/ivpsolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ def solve(
try:
solver = _SOLVER_REGISTRY[method](
num_derivatives=num_derivatives,
ode_dimension=ivp.dimension,
steprule=steprule,
)
except KeyError:
Expand Down
13 changes: 5 additions & 8 deletions tornado/odesolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,14 @@ class ODEFilterState:
class ODEFilter(ABC):
"""Interface for filtering-based ODE solvers in ProbNum."""

def __init__(self, ode_dimension, steprule, num_derivatives, initialization=None):
def __init__(self, steprule, num_derivatives, initialization=None):
self.steprule = steprule
self.num_derivatives = (
num_derivatives # e.g.: RK45 has order=5, IBM(q) has order=q
)

self.num_derivatives = num_derivatives
self.num_steps = 0

# Prior integrated Wiener process
self.iwp = iwp.IntegratedWienerTransition(
num_derivatives=num_derivatives, wiener_process_dimension=ode_dimension
)
# IWP(nu) prior -- will be assembled in initialize()
self.iwp = None

# Initialization strategy
self.init = initialization or init.TaylorMode()
Expand Down