Skip to content

Commit fb9603d

Browse files
committed
Add more tests for Solve
1 parent 6de9acc commit fb9603d

File tree

1 file changed

+174
-28
lines changed

1 file changed

+174
-28
lines changed

ndarray-linalg/tests/solve.rs

Lines changed: 174 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,188 @@
1-
use ndarray::*;
2-
use ndarray_linalg::*;
1+
use ndarray::prelude::*;
2+
use ndarray_linalg::{
3+
assert_aclose, assert_close_l2, c32, c64, random, random_hpd, solve::*, OperationNorm, Scalar,
4+
};
5+
6+
macro_rules! test_solve {
7+
(
8+
[$($elem_type:ty => $rtol:expr),*],
9+
$a_ident:ident = $a:expr,
10+
$x_ident:ident = $x:expr,
11+
b = $b:expr,
12+
$solve:ident,
13+
) => {
14+
$({
15+
let $a_ident: Array2<$elem_type> = $a;
16+
let $x_ident: Array1<$elem_type> = $x;
17+
let b: Array1<$elem_type> = $b;
18+
let a = $a_ident;
19+
let x = $x_ident;
20+
let rtol = $rtol;
21+
assert_close_l2!(&a.$solve(&b).unwrap(), &x, rtol);
22+
assert_close_l2!(&a.factorize().unwrap().$solve(&b).unwrap(), &x, rtol);
23+
assert_close_l2!(&a.factorize_into().unwrap().$solve(&b).unwrap(), &x, rtol);
24+
})*
25+
};
26+
}
27+
28+
macro_rules! test_solve_into {
29+
(
30+
[$($elem_type:ty => $rtol:expr),*],
31+
$a_ident:ident = $a:expr,
32+
$x_ident:ident = $x:expr,
33+
b = $b:expr,
34+
$solve_into:ident,
35+
) => {
36+
$({
37+
let $a_ident: Array2<$elem_type> = $a;
38+
let $x_ident: Array1<$elem_type> = $x;
39+
let b: Array1<$elem_type> = $b;
40+
let a = $a_ident;
41+
let x = $x_ident;
42+
let rtol = $rtol;
43+
assert_close_l2!(&a.$solve_into(b.clone()).unwrap(), &x, rtol);
44+
assert_close_l2!(&a.factorize().unwrap().$solve_into(b.clone()).unwrap(), &x, rtol);
45+
assert_close_l2!(&a.factorize_into().unwrap().$solve_into(b.clone()).unwrap(), &x, rtol);
46+
})*
47+
};
48+
}
49+
50+
macro_rules! test_solve_inplace {
51+
(
52+
[$($elem_type:ty => $rtol:expr),*],
53+
$a_ident:ident = $a:expr,
54+
$x_ident:ident = $x:expr,
55+
b = $b:expr,
56+
$solve_inplace:ident,
57+
) => {
58+
$({
59+
let $a_ident: Array2<$elem_type> = $a;
60+
let $x_ident: Array1<$elem_type> = $x;
61+
let b: Array1<$elem_type> = $b;
62+
let a = $a_ident;
63+
let x = $x_ident;
64+
let rtol = $rtol;
65+
{
66+
let mut b = b.clone();
67+
assert_close_l2!(&a.$solve_inplace(&mut b).unwrap(), &x, rtol);
68+
assert_close_l2!(&b, &x, rtol);
69+
}
70+
{
71+
let mut b = b.clone();
72+
assert_close_l2!(&a.factorize().unwrap().$solve_inplace(&mut b).unwrap(), &x, rtol);
73+
assert_close_l2!(&b, &x, rtol);
74+
}
75+
{
76+
let mut b = b.clone();
77+
assert_close_l2!(&a.factorize_into().unwrap().$solve_inplace(&mut b).unwrap(), &x, rtol);
78+
assert_close_l2!(&b, &x, rtol);
79+
}
80+
})*
81+
};
82+
}
83+
84+
macro_rules! test_solve_all {
85+
(
86+
[$($elem_type:ty => $rtol:expr),*],
87+
$a_ident:ident = $a:expr,
88+
$x_ident:ident = $x:expr,
89+
b = $b:expr,
90+
[$solve:ident, $solve_into:ident, $solve_inplace:ident],
91+
) => {
92+
test_solve!([$($elem_type => $rtol),*], $a_ident = $a, $x_ident = $x, b = $b, $solve,);
93+
test_solve_into!([$($elem_type => $rtol),*], $a_ident = $a, $x_ident = $x, b = $b, $solve_into,);
94+
test_solve_inplace!([$($elem_type => $rtol),*], $a_ident = $a, $x_ident = $x, b = $b, $solve_inplace,);
95+
};
96+
}
97+
98+
#[test]
99+
fn solve_random_float() {
100+
for n in 0..=8 {
101+
for &set_f in &[false, true] {
102+
test_solve_all!(
103+
[f32 => 1e-3, f64 => 1e-9],
104+
a = random([n; 2].set_f(set_f)),
105+
x = random(n),
106+
b = a.dot(&x),
107+
[solve, solve_into, solve_inplace],
108+
);
109+
}
110+
}
111+
}
112+
113+
#[test]
114+
fn solve_random_complex() {
115+
for n in 0..=8 {
116+
for &set_f in &[false, true] {
117+
test_solve_all!(
118+
[c32 => 1e-3, c64 => 1e-9],
119+
a = random([n; 2].set_f(set_f)),
120+
x = random(n),
121+
b = a.dot(&x),
122+
[solve, solve_into, solve_inplace],
123+
);
124+
}
125+
}
126+
}
3127

4128
#[test]
5-
fn solve_random() {
6-
let a: Array2<f64> = random((3, 3));
7-
let x: Array1<f64> = random(3);
8-
let b = a.dot(&x);
9-
let y = a.solve_into(b).unwrap();
10-
assert_close_l2!(&x, &y, 1e-7);
129+
fn solve_t_random_float() {
130+
for n in 0..=8 {
131+
for &set_f in &[false, true] {
132+
test_solve_all!(
133+
[f32 => 1e-3, f64 => 1e-9],
134+
a = random([n; 2].set_f(set_f)),
135+
x = random(n),
136+
b = a.t().dot(&x),
137+
[solve_t, solve_t_into, solve_t_inplace],
138+
);
139+
}
140+
}
11141
}
12142

13143
#[test]
14-
fn solve_random_t() {
15-
let a: Array2<f64> = random((3, 3).f());
16-
let x: Array1<f64> = random(3);
17-
let b = a.dot(&x);
18-
let y = a.solve_into(b).unwrap();
19-
assert_close_l2!(&x, &y, 1e-7);
144+
fn solve_t_random_complex() {
145+
for n in 0..=8 {
146+
for &set_f in &[false, true] {
147+
test_solve_all!(
148+
[c32 => 1e-3, c64 => 1e-9],
149+
a = random([n; 2].set_f(set_f)),
150+
x = random(n),
151+
b = a.t().dot(&x),
152+
[solve_t, solve_t_into, solve_t_inplace],
153+
);
154+
}
155+
}
20156
}
21157

22158
#[test]
23-
fn solve_factorized() {
24-
let a: Array2<f64> = random((3, 3));
25-
let ans: Array1<f64> = random(3);
26-
let b = a.dot(&ans);
27-
let f = a.factorize_into().unwrap();
28-
let x = f.solve_into(b).unwrap();
29-
assert_close_l2!(&x, &ans, 1e-7);
159+
fn solve_h_random_float() {
160+
for n in 0..=8 {
161+
for &set_f in &[false, true] {
162+
test_solve_all!(
163+
[f32 => 1e-3, f64 => 1e-9],
164+
a = random([n; 2].set_f(set_f)),
165+
x = random(n),
166+
b = a.t().mapv(|x| x.conj()).dot(&x),
167+
[solve_h, solve_h_into, solve_h_inplace],
168+
);
169+
}
170+
}
30171
}
31172

32173
#[test]
33-
fn solve_factorized_t() {
34-
let a: Array2<f64> = random((3, 3).f());
35-
let ans: Array1<f64> = random(3);
36-
let b = a.dot(&ans);
37-
let f = a.factorize_into().unwrap();
38-
let x = f.solve_into(b).unwrap();
39-
assert_close_l2!(&x, &ans, 1e-7);
174+
fn solve_h_random_complex() {
175+
for n in 0..=8 {
176+
for &set_f in &[false, true] {
177+
test_solve_all!(
178+
[c32 => 1e-3, c64 => 1e-9],
179+
a = random([n; 2].set_f(set_f)),
180+
x = random(n),
181+
b = a.t().mapv(|x| x.conj()).dot(&x),
182+
[solve_h, solve_h_into, solve_h_inplace],
183+
);
184+
}
185+
}
40186
}
41187

42188
#[test]

0 commit comments

Comments
 (0)