Skip to content

Commit 6c24baf

Browse files
authored
Merge pull request #75 from pnkraemer/odefilter-interface
odesolver -> odefilter
2 parents 61a958b + da42a8f commit 6c24baf

File tree

6 files changed

+23
-23
lines changed

6 files changed

+23
-23
lines changed

tests/test_odesolver.py renamed to tests/test_odefilter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class EulerState:
1717
reference_state: jnp.array
1818

1919

20-
class EulerAsODEFilter(tornado.odesolver.ODEFilter):
20+
class EulerAsODEFilter(tornado.odefilter.ODEFilter):
2121
def initialize(self, ivp):
2222
return EulerState(
2323
ivp=ivp, y=ivp.y0, t=ivp.t0, error_estimate=None, reference_state=ivp.y0
@@ -31,15 +31,15 @@ def attempt_step(self, state, dt):
3131
)
3232

3333

34-
def test_odesolver():
34+
def test_odefilter():
3535
ivp = tornado.ivp.vanderpol(t0=0.0, tmax=1.5)
3636
constant_steps = tornado.step.ConstantSteps(dt=0.1)
3737
solver_order = 2
3838
solver = EulerAsODEFilter(
3939
steprule=constant_steps,
4040
num_derivatives=solver_order,
4141
)
42-
assert isinstance(solver, tornado.odesolver.ODEFilter)
42+
assert isinstance(solver, tornado.odefilter.ODEFilter)
4343

4444
gen_sol = solver.solution_generator(ivp)
4545
for idx, _ in enumerate(gen_sol):

tornado/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
iwp,
1212
kalman,
1313
linops,
14-
odesolver,
14+
odefilter,
1515
rv,
1616
sqrt,
1717
step,

tornado/ek0.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
import jax.numpy as jnp
44

55
import tornado.iwp
6-
from tornado import init, ivp, iwp, odesolver, rv, sqrt, step
6+
from tornado import init, ivp, iwp, odefilter, rv, sqrt, step
77

88

9-
class ReferenceEK0(odesolver.ODEFilter):
9+
class ReferenceEK0(odefilter.ODEFilter):
1010
def __init__(self, *args, **kwargs):
1111
super().__init__(*args, **kwargs)
1212
self.P0 = None
@@ -34,7 +34,7 @@ def initialize(self, ivp):
3434
y = rv.MultivariateNormal(
3535
mean=mean, cov_sqrtm=jnp.kron(jnp.eye(ivp.dimension), cov_sqrtm)
3636
)
37-
return odesolver.ODEFilterState(
37+
return odefilter.ODEFilterState(
3838
ivp=ivp,
3939
t=ivp.t0,
4040
y=y,
@@ -68,7 +68,7 @@ def attempt_step(self, state, dt, verbose=False):
6868

6969
y_new = jnp.abs(self.E0 @ m_new)
7070

71-
return odesolver.ODEFilterState(
71+
return odefilter.ODEFilterState(
7272
ivp=state.ivp,
7373
t=state.t + dt,
7474
error_estimate=error,
@@ -77,7 +77,7 @@ def attempt_step(self, state, dt, verbose=False):
7777
)
7878

7979

80-
class KroneckerEK0(odesolver.ODEFilter):
80+
class KroneckerEK0(odefilter.ODEFilter):
8181
def __init__(self, *args, **kwargs):
8282
super().__init__(*args, **kwargs)
8383
self.P0 = None
@@ -109,7 +109,7 @@ def initialize(self, ivp):
109109

110110
y = rv.MatrixNormal(mean=mean, cov_sqrtm_1=jnp.eye(d), cov_sqrtm_2=cov_sqrtm)
111111

112-
return odesolver.ODEFilterState(
112+
return odefilter.ODEFilterState(
113113
ivp=ivp,
114114
t=ivp.t0,
115115
error_estimate=None,
@@ -162,7 +162,7 @@ def attempt_step(self, state, dt, verbose=False):
162162
d = z.shape[0]
163163
error_estimate = jnp.stack([jnp.sqrt(sigma_squared * HQH)] * d)
164164

165-
return odesolver.ODEFilterState(
165+
return odefilter.ODEFilterState(
166166
ivp=state.ivp,
167167
t=t_new,
168168
error_estimate=error_estimate,

tornado/ek1.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
import jax.numpy as jnp
66
import jax.scipy.linalg
77

8-
from tornado import init, iwp, linops, odesolver, rv, sqrt
8+
from tornado import init, iwp, linops, odefilter, rv, sqrt
99

1010

11-
class ReferenceEK1(odesolver.ODEFilter):
11+
class ReferenceEK1(odefilter.ODEFilter):
1212
"""Naive, reference EK1 implementation. Use this to test against."""
1313

1414
def __init__(self, *args, **kwargs):
@@ -34,7 +34,7 @@ def initialize(self, ivp):
3434
)
3535
mean = extended_dy0.reshape((-1,), order="F")
3636
y = rv.MultivariateNormal(mean, jnp.kron(jnp.eye(ivp.dimension), cov_sqrtm))
37-
return odesolver.ODEFilterState(
37+
return odefilter.ODEFilterState(
3838
ivp=ivp,
3939
t=ivp.t0,
4040
y=y,
@@ -65,7 +65,7 @@ def attempt_step(self, state, dt):
6565
reference_state = jnp.maximum(y1, y2)
6666

6767
# Return new state
68-
return odesolver.ODEFilterState(
68+
return odefilter.ODEFilterState(
6969
ivp=state.ivp,
7070
t=t,
7171
y=new_rv,
@@ -127,7 +127,7 @@ def estimate_error(h, sq, z):
127127
return error_estimate, sigma
128128

129129

130-
class BatchedEK1(odesolver.ODEFilter):
130+
class BatchedEK1(odefilter.ODEFilter):
131131
"""Common functionality for EK1 variations that act on batched multivariate normals."""
132132

133133
def __init__(self, *args, **kwargs):
@@ -159,7 +159,7 @@ def initialize(self, ivp):
159159
d, n = self.iwp.wiener_process_dimension, self.iwp.num_derivatives + 1
160160
cov_sqrtm = jnp.stack([cov_sqrtm] * d)
161161
new_rv = rv.BatchedMultivariateNormal(extended_dy0, cov_sqrtm)
162-
return odesolver.ODEFilterState(
162+
return odefilter.ODEFilterState(
163163
ivp=ivp,
164164
t=ivp.t0,
165165
y=new_rv,
@@ -186,7 +186,7 @@ def attempt_step(self, state, dt):
186186
reference_state = jnp.maximum(y1, y2)
187187

188188
new_rv = rv.BatchedMultivariateNormal(new_mean, cov_sqrtm)
189-
return odesolver.ODEFilterState(
189+
return odefilter.ODEFilterState(
190190
ivp=state.ivp,
191191
t=t,
192192
y=new_rv,
@@ -420,7 +420,7 @@ def correct_cov_sqrtm(p_1d_raw, Jx, sc_bd, kgain):
420420
return new_sc
421421

422422

423-
class EarlyTruncationEK1(odesolver.ODEFilter):
423+
class EarlyTruncationEK1(odefilter.ODEFilter):
424424
"""Use full Jacobians for mean-updates, but truncate cleverly to enforce a block-diagonal posterior covariance.
425425
426426
"Cleverly" means:
@@ -458,7 +458,7 @@ def initialize(self, ivp):
458458
d, n = self.iwp.wiener_process_dimension, self.iwp.num_derivatives + 1
459459
cov_sqrtm = jnp.stack([cov_sqrtm] * d)
460460
new_rv = rv.BatchedMultivariateNormal(extended_dy0, cov_sqrtm)
461-
return odesolver.ODEFilterState(
461+
return odefilter.ODEFilterState(
462462
ivp=ivp,
463463
t=ivp.t0,
464464
y=new_rv,
@@ -616,7 +616,7 @@ def attempt_step(self, state, dt):
616616

617617
# Return new state
618618
new_rv = rv.BatchedMultivariateNormal(new_mean, cov_sqrtm)
619-
return odesolver.ODEFilterState(
619+
return odefilter.ODEFilterState(
620620
ivp=state.ivp,
621621
t=t,
622622
y=new_rv,

tornado/ivpsolve.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44

55
import jax.numpy as jnp
66

7-
from tornado import ek0, ek1, ivp, odesolver, rv, step
7+
from tornado import ek0, ek1, ivp, odefilter, rv, step
88

99
# Will be extended in the dev process
10-
_SOLVER_REGISTRY: Dict[str, odesolver.ODEFilter] = {
10+
_SOLVER_REGISTRY: Dict[str, odefilter.ODEFilter] = {
1111
"ek1_reference": ek1.ReferenceEK1,
1212
"ek1_diagonal": ek1.DiagonalEK1,
1313
"ek1_truncation": ek1.TruncationEK1,
File renamed without changes.

0 commit comments

Comments
 (0)