Skip to content

Commit 8c1b069

Browse files
authored
Merge pull request #219 from rust-ndarray/lapack-svddc
SVD divide-and-conquer by LAPACK
2 parents 7e0f582 + cf56af5 commit 8c1b069

File tree

3 files changed

+118
-65
lines changed

3 files changed

+118
-65
lines changed

lax/src/svddc.rs

Lines changed: 84 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use super::*;
22
use crate::{error::*, layout::MatrixLayout};
33
use cauchy::*;
4-
use num_traits::Zero;
4+
use num_traits::{ToPrimitive, Zero};
55

66
/// Specifies how many of the columns of *U* and rows of *V*ᵀ are computed and returned.
77
///
@@ -21,56 +21,106 @@ pub trait SVDDC_: Scalar {
2121
unsafe fn svddc(l: MatrixLayout, jobz: UVTFlag, a: &mut [Self]) -> Result<SVDOutput<Self>>;
2222
}
2323

24-
macro_rules! impl_svdd {
25-
($scalar:ty, $gesdd:path) => {
24+
macro_rules! impl_svddc {
25+
(@real, $scalar:ty, $gesdd:path) => {
26+
impl_svddc!(@body, $scalar, $gesdd, );
27+
};
28+
(@complex, $scalar:ty, $gesdd:path) => {
29+
impl_svddc!(@body, $scalar, $gesdd, rwork);
30+
};
31+
(@body, $scalar:ty, $gesdd:path, $($rwork_ident:ident),*) => {
2632
impl SVDDC_ for $scalar {
2733
unsafe fn svddc(
2834
l: MatrixLayout,
2935
jobz: UVTFlag,
3036
mut a: &mut [Self],
3137
) -> Result<SVDOutput<Self>> {
32-
let (m, n) = l.size();
38+
let m = l.lda();
39+
let n = l.len();
3340
let k = m.min(n);
34-
let lda = l.lda();
35-
let (ucol, vtrow) = match jobz {
36-
UVTFlag::Full => (m, n),
41+
let mut s = vec![Self::Real::zero(); k as usize];
42+
43+
let (u_col, vt_row) = match jobz {
44+
UVTFlag::Full | UVTFlag::None => (m, n),
3745
UVTFlag::Some => (k, k),
38-
UVTFlag::None => (1, 1),
3946
};
40-
let mut s = vec![Self::Real::zero(); k.max(1) as usize];
41-
let mut u = vec![Self::zero(); (m * ucol).max(1) as usize];
42-
let ldu = l.resized(m, ucol).lda();
43-
let mut vt = vec![Self::zero(); (vtrow * n).max(1) as usize];
44-
let ldvt = l.resized(vtrow, n).lda();
47+
let (mut u, mut vt) = match jobz {
48+
UVTFlag::Full => (
49+
Some(vec![Self::zero(); (m * m) as usize]),
50+
Some(vec![Self::zero(); (n * n) as usize]),
51+
),
52+
UVTFlag::Some => (
53+
Some(vec![Self::zero(); (m * u_col) as usize]),
54+
Some(vec![Self::zero(); (n * vt_row) as usize]),
55+
),
56+
UVTFlag::None => (None, None),
57+
};
58+
59+
$( // for complex only
60+
let mx = n.max(m) as usize;
61+
let mn = n.min(m) as usize;
62+
let lrwork = match jobz {
63+
UVTFlag::None => 7 * mn,
64+
_ => std::cmp::max(5*mn*mn + 5*mn, 2*mx*mn + 2*mn*mn + mn),
65+
};
66+
let mut $rwork_ident = vec![Self::Real::zero(); lrwork];
67+
)*
68+
69+
// eval work size
70+
let mut info = 0;
71+
let mut iwork = vec![0; 8 * k as usize];
72+
let mut work_size = [Self::zero()];
4573
$gesdd(
46-
l.lapacke_layout(),
4774
jobz as u8,
4875
m,
4976
n,
5077
&mut a,
51-
lda,
78+
m,
5279
&mut s,
53-
&mut u,
54-
ldu,
55-
&mut vt,
56-
ldvt,
57-
)
58-
.as_lapack_result()?;
59-
Ok(SVDOutput {
60-
s,
61-
u: if jobz == UVTFlag::None { None } else { Some(u) },
62-
vt: if jobz == UVTFlag::None {
63-
None
64-
} else {
65-
Some(vt)
66-
},
67-
})
80+
u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
81+
m,
82+
vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
83+
vt_row,
84+
&mut work_size,
85+
-1,
86+
$(&mut $rwork_ident,)*
87+
&mut iwork,
88+
&mut info,
89+
);
90+
info.as_lapack_result()?;
91+
92+
// do svd
93+
let lwork = work_size[0].to_usize().unwrap();
94+
let mut work = vec![Self::zero(); lwork];
95+
$gesdd(
96+
jobz as u8,
97+
m,
98+
n,
99+
&mut a,
100+
m,
101+
&mut s,
102+
u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
103+
m,
104+
vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
105+
vt_row,
106+
&mut work,
107+
lwork as i32,
108+
$(&mut $rwork_ident,)*
109+
&mut iwork,
110+
&mut info,
111+
);
112+
info.as_lapack_result()?;
113+
114+
match l {
115+
MatrixLayout::F { .. } => Ok(SVDOutput { s, u, vt }),
116+
MatrixLayout::C { .. } => Ok(SVDOutput { s, u: vt, vt: u }),
117+
}
68118
}
69119
}
70120
};
71121
}
72122

73-
impl_svdd!(f32, lapacke::sgesdd);
74-
impl_svdd!(f64, lapacke::dgesdd);
75-
impl_svdd!(c32, lapacke::cgesdd);
76-
impl_svdd!(c64, lapacke::zgesdd);
123+
impl_svddc!(@real, f32, lapack::sgesdd);
124+
impl_svddc!(@real, f64, lapack::dgesdd);
125+
impl_svddc!(@complex, c32, lapack::cgesdd);
126+
impl_svddc!(@complex, c64, lapack::zgesdd);

ndarray-linalg/src/svddc.rs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
11
//! Singular-value decomposition (SVD) by divide-and-conquer (?gesdd)
22
3+
use super::{convert::*, error::*, layout::*, types::*};
34
use ndarray::*;
45

5-
use super::convert::*;
6-
use super::error::*;
7-
use super::layout::*;
8-
use super::types::*;
9-
106
pub use lapack::svddc::UVTFlag;
117

128
/// Singular-value decomposition of matrix (copying) by divide-and-conquer
@@ -87,17 +83,21 @@ where
8783
let svd_res = unsafe { A::svddc(l, uvt_flag, self.as_allocated_mut()?)? };
8884
let (m, n) = l.size();
8985
let k = m.min(n);
90-
let (ldu, tdu, ldvt, tdvt) = match uvt_flag {
91-
UVTFlag::Full => (m, m, n, n),
92-
UVTFlag::Some => (m, k, k, n),
93-
UVTFlag::None => (1, 1, 1, 1),
86+
87+
let (u_col, vt_row) = match uvt_flag {
88+
UVTFlag::Full => (m, n),
89+
UVTFlag::Some => (k, k),
90+
UVTFlag::None => (0, 0),
9491
};
92+
9593
let u = svd_res
9694
.u
97-
.map(|u| into_matrix(l.resized(ldu, tdu), u).expect("Size of U mismatches"));
95+
.map(|u| into_matrix(l.resized(m, u_col), u).unwrap());
96+
9897
let vt = svd_res
9998
.vt
100-
.map(|vt| into_matrix(l.resized(ldvt, tdvt), vt).expect("Size of VT mismatches"));
99+
.map(|vt| into_matrix(l.resized(vt_row, n), vt).unwrap());
100+
101101
let s = ArrayBase::from(svd_res.s);
102102
Ok((u, s, vt))
103103
}

ndarray-linalg/tests/svddc.rs

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
use ndarray::*;
22
use ndarray_linalg::*;
33

4-
fn test(a: &Array2<f64>, flag: UVTFlag) {
4+
fn test<T: Scalar + Lapack>(a: &Array2<T>, flag: UVTFlag) {
55
let (n, m) = a.dim();
66
let k = n.min(m);
77
let answer = a.clone();
88
println!("a = \n{:?}", a);
99
let (u, s, vt): (_, Array1<_>, _) = a.svddc(flag).unwrap();
10-
let mut sm = match flag {
10+
let mut sm: Array2<T> = match flag {
1111
UVTFlag::Full => Array::zeros((n, m)),
1212
UVTFlag::Some => Array::zeros((k, k)),
1313
UVTFlag::None => {
@@ -22,53 +22,56 @@ fn test(a: &Array2<f64>, flag: UVTFlag) {
2222
println!("s = \n{:?}", &s);
2323
println!("v = \n{:?}", &vt);
2424
for i in 0..k {
25-
sm[(i, i)] = s[i];
25+
sm[(i, i)] = T::from_real(s[i]);
2626
}
27-
assert_close_l2!(&u.dot(&sm).dot(&vt), &answer, 1e-7);
27+
assert_close_l2!(&u.dot(&sm).dot(&vt), &answer, T::real(1e-7));
2828
}
2929

3030
macro_rules! test_svd_impl {
31-
($n:expr, $m:expr) => {
31+
($scalar:ty, $n:expr, $m:expr) => {
3232
paste::item! {
3333
#[test]
34-
fn [<svddc_full_ $n x $m>]() {
34+
fn [<svddc_ $scalar _full_ $n x $m>]() {
3535
let a = random(($n, $m));
36-
test(&a, UVTFlag::Full);
36+
test::<$scalar>(&a, UVTFlag::Full);
3737
}
3838

3939
#[test]
40-
fn [<svddc_some_ $n x $m>]() {
40+
fn [<svddc_ $scalar _some_ $n x $m>]() {
4141
let a = random(($n, $m));
42-
test(&a, UVTFlag::Some);
42+
test::<$scalar>(&a, UVTFlag::Some);
4343
}
4444

4545
#[test]
46-
fn [<svddc_none_ $n x $m>]() {
46+
fn [<svddc_ $scalar _none_ $n x $m>]() {
4747
let a = random(($n, $m));
48-
test(&a, UVTFlag::None);
48+
test::<$scalar>(&a, UVTFlag::None);
4949
}
5050

5151
#[test]
52-
fn [<svddc_full_ $n x $m _t>]() {
52+
fn [<svddc_ $scalar _full_ $n x $m _t>]() {
5353
let a = random(($n, $m).f());
54-
test(&a, UVTFlag::Full);
54+
test::<$scalar>(&a, UVTFlag::Full);
5555
}
5656

5757
#[test]
58-
fn [<svddc_some_ $n x $m _t>]() {
58+
fn [<svddc_ $scalar _some_ $n x $m _t>]() {
5959
let a = random(($n, $m).f());
60-
test(&a, UVTFlag::Some);
60+
test::<$scalar>(&a, UVTFlag::Some);
6161
}
6262

6363
#[test]
64-
fn [<svddc_none_ $n x $m _t>]() {
64+
fn [<svddc_ $scalar _none_ $n x $m _t>]() {
6565
let a = random(($n, $m).f());
66-
test(&a, UVTFlag::None);
66+
test::<$scalar>(&a, UVTFlag::None);
6767
}
6868
}
6969
};
7070
}
7171

72-
test_svd_impl!(3, 3);
73-
test_svd_impl!(4, 3);
74-
test_svd_impl!(3, 4);
72+
test_svd_impl!(f64, 3, 3);
73+
test_svd_impl!(f64, 4, 3);
74+
test_svd_impl!(f64, 3, 4);
75+
test_svd_impl!(c64, 3, 3);
76+
test_svd_impl!(c64, 4, 3);
77+
test_svd_impl!(c64, 3, 4);

0 commit comments

Comments
 (0)