Skip to content

Commit 4ce2ac0

Browse files
Added tests
1 parent 03ccd5e commit 4ce2ac0

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

tests/test_iwp.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,14 @@ def test_preconditioned_system_matrices(dt, iwp):
6262
assert jnp.allclose(precond @ precond_proc_noice_chol, non_precond_proc_noice_chol)
6363

6464

65+
def test_projection_matrices(iwp):
66+
P = iwp.make_projection_matrix(0)
67+
assert isinstance(P, jnp.ndarray)
68+
d, q = iwp.wiener_process_dimension, iwp.num_derivatives
69+
assert P.shape == (d, q + 1)
70+
assert (P == 1).sum() == d
71+
72+
6573
def test_reorder_states():
6674
# Transition handles reordering
6775
iwp = tornado.iwp.IntegratedWienerTransition(

tornado/iwp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def non_preconditioned_discretize(self, dt):
123123

124124
def make_projection_matrix(self, derivative_to_project_onto):
125125
"""Creates a projection matrix kron(I_d, e_p)"""
126-
d, q = self.num_derivatives, self.wiener_process_dimension
126+
d, q = self.wiener_process_dimension, self.num_derivatives
127127
I_d = jnp.eye(d)
128128
e_p = jnp.eye(1, q + 1, derivative_to_project_onto)
129129
return jnp.kron(I_d, e_p)

0 commit comments

Comments
 (0)