Diagonal EK1 #30

Merged
43 commits merged on Aug 16, 2021
Commits (43)
e7ac258
initialize ek1 diagonal tested
pnkraemer Aug 14, 2021
1fb3674
failing step test for attempt step
pnkraemer Aug 14, 2021
42af06e
transposed linop
pnkraemer Aug 14, 2021
105b454
Merge branch 'main' of github.com:pnkraemer/tornado into nk/diagonal-ek1
pnkraemer Aug 14, 2021
9f4cf49
subtraction of block matrices
pnkraemer Aug 14, 2021
c128994
diagonal ek1 runs through, but error remains large
pnkraemer Aug 14, 2021
d137df3
reference implementation with preconditioning
pnkraemer Aug 14, 2021
0e5501b
preconditioning in diagonal ek1
pnkraemer Aug 14, 2021
182cfa0
some docs -- need a better QR now
pnkraemer Aug 14, 2021
d120023
Merge branch 'main' of github.com:pnkraemer/tornado into nk/diagonal-ek1
pnkraemer Aug 14, 2021
215c299
batched prediction step in sqrt
pnkraemer Aug 14, 2021
4727298
optional S2 in batched sqrt prediction
pnkraemer Aug 14, 2021
cdd8a37
removed commented code
pnkraemer Aug 14, 2021
d275fa6
reproduced ek1 error in backward sqrt
pnkraemer Aug 14, 2021
60d4318
realised the bug was not a bug
pnkraemer Aug 14, 2021
09a1fb3
ek1 uses batched version things now
pnkraemer Aug 14, 2021
0ae9d41
replace batch predict with vmap
pnkraemer Aug 14, 2021
836f00d
removed optional argument inputs
pnkraemer Aug 15, 2021
9f288d4
wrote failing test for error estimation
pnkraemer Aug 15, 2021
7aa9045
error estimation
pnkraemer Aug 15, 2021
9244146
full solve adaptive ek1
pnkraemer Aug 15, 2021
8929529
fixed brainfart
pnkraemer Aug 15, 2021
98888e0
comment
pnkraemer Aug 15, 2021
77352fc
renamed batched propagate
pnkraemer Aug 15, 2021
15a9273
sqrtm-to-cholesky
pnkraemer Aug 15, 2021
0c7316f
batched sqrtm to cholesky
pnkraemer Aug 15, 2021
53d524b
introduced fixtures to sqrt test
pnkraemer Aug 15, 2021
d6066bd
sqrt update simplified
pnkraemer Aug 15, 2021
294dba0
batched updated test simplified
pnkraemer Aug 15, 2021
09a80f7
docs in test sqrt
pnkraemer Aug 15, 2021
fde0074
eliminated a QR decomp from EK1
pnkraemer Aug 15, 2021
f01f151
Updated comments
pnkraemer Aug 15, 2021
9860a35
cov_cholesky is cov_sqrtm
pnkraemer Aug 15, 2021
e5f7ebb
no redundant qr anymore
pnkraemer Aug 15, 2021
1b5aac5
no error estimation for constant steps
pnkraemer Aug 15, 2021
e721fc0
fixed tests
pnkraemer Aug 15, 2021
3ce66b9
removed qr call
Aug 16, 2021
c770e02
comments while reviewing
Aug 16, 2021
0982307
fixed conflicts
pnkraemer Aug 16, 2021
13dbae2
Proper dynamic calibration in EK1
pnkraemer Aug 16, 2021
5371ad5
simplified test
pnkraemer Aug 16, 2021
3ef878f
some check
pnkraemer Aug 16, 2021
fd1f2cc
kicked tril_to_positive_tril and fixed test
pnkraemer Aug 16, 2021
101 changes: 100 additions & 1 deletion tests/test_ek1.py
@@ -1,5 +1,7 @@
"""Tests for the EK1 implementation."""

import dataclasses

import jax.numpy as jnp
from scipy.integrate import solve_ivp

@@ -12,7 +14,7 @@ def test_reference_ek1_constant_steps():
As long as this test passes, we can test the more efficient solvers against this one here.
"""

ivp = tornado.ivp.vanderpol(t0=0.0, tmax=0.5, stiffness_constant=1.0)
ivp = tornado.ivp.vanderpol(t0=0.0, tmax=0.25, stiffness_constant=1.0)
scipy_sol = solve_ivp(ivp.f, t_span=(ivp.t0, ivp.tmax), y0=ivp.y0)
final_t_scipy = scipy_sol.t[-1]
final_y_scipy = scipy_sol.y[:, -1]
@@ -29,3 +31,100 @@ def test_reference_ek1_constant_steps():
final_y_ek1 = ek1.P0 @ state.y.mean
assert jnp.allclose(final_t_scipy, final_t_ek1)
assert jnp.allclose(final_y_scipy, final_y_ek1, rtol=1e-3, atol=1e-3)


def test_diagonal_ek1_constant_steps():
# only "constant steps", because there is no error estimation yet.
old_ivp = tornado.ivp.vanderpol(t0=0.0, tmax=0.5, stiffness_constant=1.0)

# Diagonal Jacobian
new_df = lambda t, y: jnp.diag(jnp.diag(old_ivp.df(t, y)))
ivp = tornado.ivp.InitialValueProblem(
f=old_ivp.f,
df=new_df,
t0=old_ivp.t0,
tmax=old_ivp.tmax,
y0=old_ivp.y0,
)

steps = tornado.step.ConstantSteps(0.1)
reference_ek1 = tornado.ek1.ReferenceEK1(
num_derivatives=4, ode_dimension=2, steprule=steps
)
diagonal_ek1 = tornado.ek1.DiagonalEK1(
num_derivatives=4, ode_dimension=2, steprule=steps
)

# Initialize works as expected
init_ref = reference_ek1.initialize(ivp=ivp)
init_diag = diagonal_ek1.initialize(ivp=ivp)
assert jnp.allclose(init_diag.t, init_ref.t)
assert jnp.allclose(init_diag.y.mean, init_ref.y.mean)
assert isinstance(init_diag.y.cov_sqrtm, tornado.linops.BlockDiagonal)
assert jnp.allclose(init_diag.y.cov_sqrtm.todense(), init_ref.y.cov_sqrtm)

# Attempt step works as expected
step_ref = reference_ek1.attempt_step(state=init_ref, dt=0.12345)
step_diag = diagonal_ek1.attempt_step(state=init_diag, dt=0.12345)
assert jnp.allclose(init_diag.t, init_ref.t)
assert jnp.allclose(step_diag.y.mean, step_ref.y.mean)
assert isinstance(step_diag.y.cov_sqrtm, tornado.linops.BlockDiagonal)
received = (step_diag.y.cov_sqrtm @ step_diag.y.cov_sqrtm.T).todense()
expected = step_ref.y.cov_sqrtm @ step_ref.y.cov_sqrtm.T
assert received.shape == expected.shape
assert jnp.allclose(received, expected), received - expected


def test_diagonal_ek1_adaptive_steps():
"""Error estimation is only computed for adaptive steps. This test computes the result of attempt_step()."""
old_ivp = tornado.ivp.vanderpol(t0=0.0, tmax=0.5, stiffness_constant=1.0)

# Diagonal Jacobian
new_df = lambda t, y: jnp.diag(jnp.diag(old_ivp.df(t, y)))
ivp = tornado.ivp.InitialValueProblem(
f=old_ivp.f,
df=new_df,
t0=old_ivp.t0,
tmax=old_ivp.tmax,
y0=old_ivp.y0,
)

steps = tornado.step.AdaptiveSteps(0.1, abstol=1e-1, reltol=1e-1)
diagonal_ek1 = tornado.ek1.DiagonalEK1(
num_derivatives=4, ode_dimension=2, steprule=steps
)
init_diag = diagonal_ek1.initialize(ivp=ivp)
assert isinstance(init_diag.y.cov_sqrtm, tornado.linops.BlockDiagonal)

# Attempt step works as expected
d = diagonal_ek1.iwp.wiener_process_dimension
n = diagonal_ek1.iwp.num_derivatives
step_diag = diagonal_ek1.attempt_step(state=init_diag, dt=0.12345)
assert isinstance(step_diag.y.cov_sqrtm, tornado.linops.BlockDiagonal)
assert isinstance(step_diag.y.mean, jnp.ndarray)
assert isinstance(step_diag.reference_state, jnp.ndarray)
assert isinstance(step_diag.error_estimate, jnp.ndarray)
assert step_diag.y.mean.shape == (d * (n + 1),)
assert step_diag.reference_state.shape == (d,)
assert step_diag.error_estimate.shape == (d,)
assert jnp.all(step_diag.reference_state >= 0)


def test_diagonal_ek1_adaptive_steps_full_solve():

ivp = tornado.ivp.vanderpol(t0=0.0, tmax=0.25, stiffness_constant=1.0)
scipy_sol = solve_ivp(ivp.f, t_span=(ivp.t0, ivp.tmax), y0=ivp.y0)
final_t_scipy = scipy_sol.t[-1]
final_y_scipy = scipy_sol.y[:, -1]

dt = jnp.mean(jnp.diff(scipy_sol.t))
steps = tornado.step.AdaptiveSteps(first_dt=dt, abstol=1e-3, reltol=1e-3)
ek1 = tornado.ek1.DiagonalEK1(num_derivatives=4, ode_dimension=2, steprule=steps)
sol_gen = ek1.solution_generator(ivp=ivp)
for state in sol_gen:
pass

final_t_ek1 = state.t
final_y_ek1 = ek1.P0 @ state.y.mean
assert jnp.allclose(final_t_scipy, final_t_ek1)
assert jnp.allclose(final_y_scipy, final_y_ek1, rtol=1e-3, atol=1e-3)
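
The tests above compare DiagonalEK1 against ReferenceEK1 on an IVP whose Jacobian has been replaced by its diagonal, so both solvers linearize the same, decoupled model. A standalone sketch of that diagonalization in plain JAX, independent of tornado; the Van der Pol right-hand side and the initial value below are assumptions chosen only for illustration:

import jax
import jax.numpy as jnp

# Van der Pol with stiffness constant mu = 1 (illustrative; tornado.ivp.vanderpol
# supplies its own f, df, and y0).
mu = 1.0
f = lambda t, y: jnp.array([y[1], mu * (1.0 - y[0] ** 2) * y[1] - y[0]])
df = jax.jacfwd(f, argnums=1)                          # dense 2x2 Jacobian
df_diag = lambda t, y: jnp.diag(jnp.diag(df(t, y)))    # keep only the diagonal

y0 = jnp.array([2.0, 0.0])
print(df(0.0, y0))       # [[ 0.  1.]  [-1. -3.]]
print(df_diag(0.0, y0))  # [[ 0.  0.]  [ 0. -3.]]

With a diagonal Jacobian the EK1 measurement model decouples across ODE dimensions, so the block-diagonal solver can reproduce the dense reference up to numerical error, which is what the assertions above rely on.
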
17 changes: 17 additions & 0 deletions tests/test_linops.py
@@ -61,6 +61,23 @@ def test_sum_block_diagonals(A, B):
assert jnp.allclose(new.todense(), expected)


def test_diff_block_diagonals(A, B):
B1 = tornado.linops.BlockDiagonal.from_arrays(A, B)
B2 = tornado.linops.BlockDiagonal.from_arrays(B, A)
new = B1 - B2
expected = B1.todense() - B2.todense()
assert isinstance(new, tornado.linops.BlockDiagonal)
assert jnp.allclose(new.todense(), expected)


def test_transpose_block_diagonals(A, B):
BD = tornado.linops.BlockDiagonal.from_arrays(A, B)
new = BD.T
expected = BD.todense().T
assert isinstance(new, tornado.linops.BlockDiagonal)
assert jnp.allclose(new.todense(), expected)


@pytest.fixture
def P0():
return tornado.linops.DerivativeSelection(derivative=0)
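
For reference, a minimal stand-in that satisfies the contract the new BlockDiagonal tests above check: construction from equally shaped blocks, densification, blockwise subtraction, and blockwise transpose. This is an illustrative sketch, not tornado's implementation:

import jax.numpy as jnp
from jax.scipy.linalg import block_diag


class BlockDiagonalSketch:
    """Stack of equally shaped blocks interpreted as a block-diagonal matrix."""

    def __init__(self, array_stack):
        self.array_stack = array_stack  # shape (num_blocks, rows, cols)

    @classmethod
    def from_arrays(cls, *arrays):
        return cls(jnp.stack(arrays))

    def todense(self):
        return block_diag(*self.array_stack)

    def __sub__(self, other):
        # The blockwise difference is again block-diagonal.
        return BlockDiagonalSketch(self.array_stack - other.array_stack)

    @property
    def T(self):
        # Transposing a block-diagonal matrix transposes each block in place.
        return BlockDiagonalSketch(jnp.transpose(self.array_stack, (0, 2, 1)))

In the tests, only todense() is needed to compare the structured result against ordinary dense arithmetic.
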
2 changes: 1 addition & 1 deletion tests/test_rv.py
@@ -10,7 +10,7 @@ def test_rv():
"""Random variables work as expected."""
mean = jax.numpy.array([1.0, 2.0])
cov_cholesky = jax.numpy.array([[1.0, 0.0], [1.0, 1.0]])
normal = tornado.rv.MultivariateNormal(mean=mean, cov_cholesky=cov_cholesky)
normal = tornado.rv.MultivariateNormal(mean=mean, cov_sqrtm=cov_cholesky)

assert isinstance(normal, tornado.rv.MultivariateNormal)
assert jax.numpy.allclose(normal.cov, cov_cholesky @ cov_cholesky.T)
176 changes: 121 additions & 55 deletions tests/test_sqrt.py
@@ -8,76 +8,142 @@

@pytest.fixture
def iwp():
"""Steal system matrices from an IWP transition."""
return tornado.iwp.IntegratedWienerTransition(
wiener_process_dimension=1, num_derivatives=2
wiener_process_dimension=1, num_derivatives=1
)


def test_propagate_cholesky_factor(iwp):
transition_matrix, process_noise_cholesky = iwp.preconditioned_discretize_1d
@pytest.fixture
def H_and_SQ(iwp, measurement_style):
"""Measurement model via IWP system matrices."""
H, SQ = iwp.preconditioned_discretize_1d

if measurement_style == "full":
return H, SQ
return H[:1], SQ[:1, :1]


@pytest.fixture
def SC(iwp):
"""Initial covariance via IWP process noise."""
return iwp.preconditioned_discretize_1d[1]


# dummy cholesky factor
some_chol1 = process_noise_cholesky.copy()
some_chol2 = process_noise_cholesky.copy()
@pytest.fixture
def batch_size():
"""Batch size > 1. Test batched transitions."""
return 5


@pytest.mark.parametrize("measurement_style", ["full", "partial"])
def test_propagate_cholesky_factor(H_and_SQ, SC, measurement_style):
"""Assert that sqrt propagation coincides with non-sqrt propagation."""
H, SQ = H_and_SQ

# First test: Non-optional S2
chol = tornado.sqrt.propagate_cholesky_factor(
S1=(transition_matrix @ some_chol1), S2=process_noise_cholesky
)
cov = (
transition_matrix @ some_chol1 @ some_chol1.T @ transition_matrix.T
+ process_noise_cholesky @ process_noise_cholesky.T
)
chol = tornado.sqrt.propagate_cholesky_factor(S1=(H @ SC), S2=SQ)
cov = H @ SC @ SC.T @ H.T + SQ @ SQ.T
assert jnp.allclose(chol @ chol.T, cov)
assert jnp.allclose(jnp.linalg.cholesky(cov), chol)
assert jnp.all(jnp.diag(chol) > 0)
assert jnp.allclose(jnp.tril(chol), chol)

# Second test: Optional S2
chol = tornado.sqrt.propagate_cholesky_factor(S1=(transition_matrix @ some_chol2))
cov = transition_matrix @ some_chol2 @ some_chol2.T @ transition_matrix.T

# Relax tolerance because ill-conditioned...
assert jnp.allclose(chol @ chol.T, cov)
assert jnp.allclose(jnp.linalg.cholesky(cov), chol, rtol=1e-4, atol=1e-4)
assert jnp.all(jnp.diag(chol) > 0)


def test_tril_to_positive_tril():
"""Assert that the weird sign(0)=0 behaviour is made up for."""
matrix = jnp.array(
[
[1.0, 0.0, 0.0],
[-1.0, 0.0, 0.0],
[1.0, 2.0, 3.0],
]
@pytest.mark.parametrize("measurement_style", ["full", "partial"])
def test_batched_propagate_cholesky_factors(
H_and_SQ, SC, measurement_style, batch_size
):
"""Batched propagation coincides with non-batched propagation."""

H, SQ = H_and_SQ
H = tornado.linops.BlockDiagonal(jnp.stack([H] * batch_size))
SQ = tornado.linops.BlockDiagonal(jnp.stack([SQ] * batch_size))
SC = tornado.linops.BlockDiagonal(jnp.stack([SC] * batch_size))

chol = tornado.sqrt.batched_propagate_cholesky_factor(
(H @ SC).array_stack, SQ.array_stack
)
result = tornado.sqrt.tril_to_positive_tril(matrix)
assert jnp.allclose(matrix, result)
chol_as_bd = tornado.linops.BlockDiagonal(chol)
reference = tornado.sqrt.propagate_cholesky_factor((H @ SC).todense(), SQ.todense())
assert jnp.allclose(chol_as_bd.todense(), reference)


@pytest.mark.parametrize("measurement_style", ["full", "partial"])
def test_batched_sqrtm_to_cholesky(H_and_SQ, SC, measurement_style, batch_size):
"""Sqrtm-to-cholesky is the same for batched and non-batched."""
H, SQ = H_and_SQ
d = H.shape[0]
H = tornado.linops.BlockDiagonal(jnp.stack([H] * batch_size))
SC = tornado.linops.BlockDiagonal(jnp.stack([SC] * batch_size))

chol = tornado.sqrt.batched_sqrtm_to_cholesky((H @ SC).T.array_stack)
chol_as_bd = tornado.linops.BlockDiagonal(chol)

reference = tornado.sqrt.sqrtm_to_cholesky((H @ SC).T.todense())
assert jnp.allclose(chol_as_bd.todense(), reference)
assert chol_as_bd.array_stack.shape == (batch_size, d, d)
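
Both propagate_cholesky_factor and sqrtm_to_cholesky come down to the standard QR trick from square-root Kalman filtering: a Cholesky factor of B.T @ B is R.T from a QR decomposition of B. A minimal sketch under that assumption (the real routines may additionally flip signs so the diagonal is positive, and the batched variants are presumably a jax.vmap over the leading block axis):

import jax.numpy as jnp


def sqrtm_to_cholesky_sketch(B):
    # B is a transposed square root: the target covariance is B.T @ B.
    # With B = Q R (R upper triangular), B.T @ B = R.T @ R, so R.T is a
    # lower-triangular Cholesky factor (up to the signs of its diagonal).
    _, R = jnp.linalg.qr(B)
    return R.T


def propagate_cholesky_factor_sketch(S1, S2):
    # Cholesky factor of S1 @ S1.T + S2 @ S2.T, without forming the covariance.
    return sqrtm_to_cholesky_sketch(jnp.concatenate([S1.T, S2.T], axis=0))
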


def test_update_sqrt(iwp):
"""Test the square-root updates."""
# Use sqrt(Q) as a dummy for a sqrt(C)
A, SC = iwp.preconditioned_discretize_1d
@pytest.mark.parametrize("measurement_style", ["full", "partial"])
def test_update_sqrt(H_and_SQ, SC, measurement_style):
"""Sqrt-update coincides with non-square-root update."""

# Check square and non-square!
for H in [A, A[:1]]:
SC_new, kalman_gain, innov_chol = tornado.sqrt.update_sqrt(H, SC)
H, _ = H_and_SQ

# expected:
S = H @ SC @ SC.T @ H.T
K = SC @ SC.T @ H.T @ jnp.linalg.inv(S)
C = SC @ SC.T - K @ S @ K.T
SC_new, kalman_gain, innov_chol = tornado.sqrt.update_sqrt(H, SC)
assert isinstance(SC_new, jnp.ndarray)
assert isinstance(kalman_gain, jnp.ndarray)
assert isinstance(innov_chol, jnp.ndarray)
assert SC_new.shape == SC.shape
assert kalman_gain.shape == (H.shape[1], H.shape[0])
assert innov_chol.shape == (H.shape[0], H.shape[0])

# Test SC
assert jnp.allclose(SC_new @ SC_new.T, C)
assert jnp.allclose(SC_new, jnp.tril(SC_new))
assert jnp.all(jnp.diag(SC_new) >= 0)
# expected:
S = H @ SC @ SC.T @ H.T
K = SC @ SC.T @ H.T @ jnp.linalg.inv(S)
C = SC @ SC.T - K @ S @ K.T

# Test K
assert jnp.allclose(K, kalman_gain)
# Test SC
assert jnp.allclose(SC_new @ SC_new.T, C)
assert jnp.allclose(SC_new, jnp.tril(SC_new))

# Test S
assert jnp.allclose(innov_chol @ innov_chol.T, S)
assert jnp.allclose(innov_chol, jnp.tril(innov_chol))
assert jnp.all(jnp.diag(innov_chol) >= 0)
# Test K
assert jnp.allclose(K, kalman_gain)

# Test S
assert jnp.allclose(innov_chol @ innov_chol.T, S)
assert jnp.allclose(innov_chol, jnp.tril(innov_chol))
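
The square-root update exercised here can be written as a single QR decomposition of the stacked matrix [H @ SC; SC]; the blocks of the resulting triangular factor contain the innovation Cholesky factor, the Kalman gain, and the updated covariance factor. A sketch of that algebra, illustrative only and not necessarily tornado's exact update_sqrt:

import jax.numpy as jnp


def update_sqrt_sketch(H, SC):
    d, n = H.shape
    # Triangularize X = [H @ SC; SC]: with X.T = Q R we have X @ X.T = R.T @ R,
    # whose blocks are [[S, H C], [C H.T, C]] for C = SC @ SC.T.
    X = jnp.concatenate([H @ SC, SC], axis=0)  # shape (d + n, n)
    _, R = jnp.linalg.qr(X.T)
    Rt = R.T                                   # lower trapezoidal, (d + n, n)

    innov_chol = Rt[:d, :d]                    # S = innov_chol @ innov_chol.T
    crosscov = Rt[d:, :d]                      # equals C @ H.T @ inv(innov_chol).T
    kalman_gain = crosscov @ jnp.linalg.inv(innov_chol)
    # Posterior covariance C - K S K.T equals Rt[d:, d:] @ Rt[d:, d:].T;
    # pad with zero columns so the factor keeps the shape of SC.
    SC_new = jnp.pad(Rt[d:, d:], ((0, 0), (0, d)))
    return SC_new, kalman_gain, innov_chol
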


@pytest.mark.parametrize("measurement_style", ["full", "partial"])
def test_batched_update_sqrt(H_and_SQ, SC, measurement_style, batch_size):
"""Batched updated coincides with non-batched update."""
H, _ = H_and_SQ
d_out, d_in = H.shape
H = tornado.linops.BlockDiagonal(jnp.stack([H] * batch_size))
SC = tornado.linops.BlockDiagonal(jnp.stack([SC] * batch_size))

chol, K, S = tornado.sqrt.batched_update_sqrt(
H.array_stack,
SC.array_stack,
)
assert isinstance(chol, jnp.ndarray)
assert isinstance(K, jnp.ndarray)
assert isinstance(S, jnp.ndarray)
assert K.shape == (batch_size, d_in, d_out)
assert chol.shape == (batch_size, d_in, d_in)
assert S.shape == (batch_size, d_out, d_out)

ref_chol, ref_K, ref_S = tornado.sqrt.update_sqrt(H.todense(), SC.todense())
chol_as_bd = tornado.linops.BlockDiagonal(chol)
K_as_bd = tornado.linops.BlockDiagonal(K)
S_as_bd = tornado.linops.BlockDiagonal(S)

# K can be compared elementwise, S and chol not (see below).
assert jnp.allclose(K_as_bd.todense(), ref_K)

# The Cholesky-factor of positive semi-definite matrices is only unique
# up to column operations (e.g. column reordering), i.e. there could be slightly
# different Cholesky factors in batched and non-batched versions.
# Therefore, we only check that the results are valid Cholesky factors themselves
assert jnp.allclose((S_as_bd @ S_as_bd.T).todense(), ref_S @ ref_S.T)
assert jnp.allclose((chol_as_bd @ chol_as_bd.T).todense(), ref_chol @ ref_chol.T)
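
A tiny illustration of the point made in the comment above (plain JAX, unrelated to tornado): right-multiplying a square-root factor by any orthogonal matrix changes the factor but not the covariance, which is why the batched and non-batched factors are compared through their Gram products rather than elementwise.

import jax.numpy as jnp

L1 = jnp.array([[2.0, 0.0], [1.0, 1.0]])
Q = jnp.array([[0.0, -1.0], [1.0, 0.0]])   # a rotation: Q @ Q.T = I
L2 = L1 @ Q

assert not jnp.allclose(L1, L2)            # different square-root factors ...
assert jnp.allclose(L1 @ L1.T, L2 @ L2.T)  # ... of the same covariance
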