Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 166 additions & 33 deletions lax/src/svd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,27 @@

use crate::{error::*, layout::MatrixLayout};
use cauchy::*;
use num_traits::Zero;
use num_traits::{ToPrimitive, Zero};

#[repr(u8)]
#[derive(Debug, Copy, Clone)]
enum FlagSVD {
All = b'A',
// OverWrite = b'O',
// Separately = b'S',
No = b'N',
}

impl FlagSVD {
fn from_bool(calc_uv: bool) -> Self {
if calc_uv {
FlagSVD::All
} else {
FlagSVD::No
}
}
}

/// Result of SVD
pub struct SVDOutput<A: Scalar> {
/// diagonal values
Expand All @@ -24,6 +35,7 @@ pub struct SVDOutput<A: Scalar> {

/// Wraps `*gesvd`
pub trait SVD_: Scalar {
/// Calculate singular value decomposition $ A = U \Sigma V^T $
unsafe fn svd(
l: MatrixLayout,
calc_u: bool,
Expand All @@ -32,7 +44,7 @@ pub trait SVD_: Scalar {
) -> Result<SVDOutput<Self>>;
}

macro_rules! impl_svd {
macro_rules! impl_svd_real {
($scalar:ty, $gesvd:path) => {
impl SVD_ for $scalar {
unsafe fn svd(
Expand All @@ -41,48 +53,169 @@ macro_rules! impl_svd {
calc_vt: bool,
mut a: &mut [Self],
) -> Result<SVDOutput<Self>> {
let (m, n) = l.size();
let k = ::std::cmp::min(n, m);
let lda = l.lda();
let (ju, ldu, mut u) = if calc_u {
(FlagSVD::All, m, vec![Self::zero(); (m * m) as usize])
} else {
(FlagSVD::No, 1, Vec::new())
let ju = match l {
MatrixLayout::F { .. } => FlagSVD::from_bool(calc_u),
MatrixLayout::C { .. } => FlagSVD::from_bool(calc_vt),
};
let (jvt, ldvt, mut vt) = if calc_vt {
(FlagSVD::All, n, vec![Self::zero(); (n * n) as usize])
} else {
(FlagSVD::No, n, Vec::new())
let jvt = match l {
MatrixLayout::F { .. } => FlagSVD::from_bool(calc_vt),
MatrixLayout::C { .. } => FlagSVD::from_bool(calc_u),
};

let m = l.lda();
let mut u = match ju {
FlagSVD::All => Some(vec![Self::zero(); (m * m) as usize]),
FlagSVD::No => None,
};

let n = l.len();
let mut vt = match jvt {
FlagSVD::All => Some(vec![Self::zero(); (n * n) as usize]),
FlagSVD::No => None,
};

let k = std::cmp::min(m, n);
let mut s = vec![Self::Real::zero(); k as usize];
let mut superb = vec![Self::Real::zero(); (k - 1) as usize];

// eval work size
let mut info = 0;
let mut work_size = [Self::zero()];
$gesvd(
ju as u8,
jvt as u8,
m,
n,
&mut a,
m,
&mut s,
u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
m,
vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
n,
&mut work_size,
-1,
&mut info,
);
info.as_lapack_result()?;

// calc
let lwork = work_size[0].to_usize().unwrap();
let mut work = vec![Self::zero(); lwork];
$gesvd(
l.lapacke_layout(),
ju as u8,
jvt as u8,
m,
n,
&mut a,
lda,
m,
&mut s,
&mut u,
ldu,
&mut vt,
ldvt,
&mut superb,
)
.as_lapack_result()?;
Ok(SVDOutput {
s,
u: if calc_u { Some(u) } else { None },
vt: if calc_vt { Some(vt) } else { None },
})
u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
m,
vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
n,
&mut work,
lwork as i32,
&mut info,
);
info.as_lapack_result()?;
match l {
MatrixLayout::F { .. } => Ok(SVDOutput { s, u, vt }),
MatrixLayout::C { .. } => Ok(SVDOutput { s, u: vt, vt: u }),
}
}
}
};
} // impl_svd_real!

impl_svd_real!(f64, lapack::dgesvd);
impl_svd_real!(f32, lapack::sgesvd);

macro_rules! impl_svd_complex {
($scalar:ty, $gesvd:path) => {
impl SVD_ for $scalar {
unsafe fn svd(
l: MatrixLayout,
calc_u: bool,
calc_vt: bool,
mut a: &mut [Self],
) -> Result<SVDOutput<Self>> {
let ju = match l {
MatrixLayout::F { .. } => FlagSVD::from_bool(calc_u),
MatrixLayout::C { .. } => FlagSVD::from_bool(calc_vt),
};
let jvt = match l {
MatrixLayout::F { .. } => FlagSVD::from_bool(calc_vt),
MatrixLayout::C { .. } => FlagSVD::from_bool(calc_u),
};

let m = l.lda();
let mut u = match ju {
FlagSVD::All => Some(vec![Self::zero(); (m * m) as usize]),
FlagSVD::No => None,
};

let n = l.len();
let mut vt = match jvt {
FlagSVD::All => Some(vec![Self::zero(); (n * n) as usize]),
FlagSVD::No => None,
};

let k = std::cmp::min(m, n);
let mut s = vec![Self::Real::zero(); k as usize];

let mut rwork = vec![Self::Real::zero(); 5 * k as usize];

// eval work size
let mut info = 0;
let mut work_size = [Self::zero()];
$gesvd(
ju as u8,
jvt as u8,
m,
n,
&mut a,
m,
&mut s,
u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
m,
vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
n,
&mut work_size,
-1,
&mut rwork,
&mut info,
);
info.as_lapack_result()?;

// calc
let lwork = work_size[0].to_usize().unwrap();
let mut work = vec![Self::zero(); lwork];
$gesvd(
ju as u8,
jvt as u8,
m,
n,
&mut a,
m,
&mut s,
u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
m,
vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
n,
&mut work,
lwork as i32,
&mut rwork,
&mut info,
);
info.as_lapack_result()?;
match l {
MatrixLayout::F { .. } => Ok(SVDOutput { s, u, vt }),
MatrixLayout::C { .. } => Ok(SVDOutput { s, u: vt, vt: u }),
}
}
}
};
} // impl_svd!
} // impl_svd_real!

impl_svd!(f64, lapacke::dgesvd);
impl_svd!(f32, lapacke::sgesvd);
impl_svd!(c64, lapacke::zgesvd);
impl_svd!(c32, lapacke::cgesvd);
impl_svd_complex!(c64, lapack::zgesvd);
impl_svd_complex!(c32, lapack::cgesvd);
28 changes: 21 additions & 7 deletions ndarray-linalg/src/svd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

use ndarray::*;

use super::convert::*;
use super::error::*;
use super::layout::*;
use super::types::*;
Expand Down Expand Up @@ -99,12 +98,27 @@ where
let l = self.layout()?;
let svd_res = unsafe { A::svd(l, calc_u, calc_vt, self.as_allocated_mut()?)? };
let (n, m) = l.size();
let u = svd_res
.u
.map(|u| into_matrix(l.resized(n, n), u).expect("Size of U mismatches"));
let vt = svd_res
.vt
.map(|vt| into_matrix(l.resized(m, m), vt).expect("Size of VT mismatches"));
let n = n as usize;
let m = m as usize;

let u = svd_res.u.map(|u| {
assert_eq!(u.len(), n * n);
match l {
MatrixLayout::F { .. } => Array::from_shape_vec((n, n).f(), u),
MatrixLayout::C { .. } => Array::from_shape_vec((n, n), u),
}
.unwrap()
});

let vt = svd_res.vt.map(|vt| {
assert_eq!(vt.len(), m * m);
match l {
MatrixLayout::F { .. } => Array::from_shape_vec((m, m).f(), vt),
MatrixLayout::C { .. } => Array::from_shape_vec((m, m), vt),
}
.unwrap()
});

let s = ArrayBase::from(svd_res.s);
Ok((u, s, vt))
}
Expand Down
60 changes: 36 additions & 24 deletions ndarray-linalg/tests/svd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use ndarray::*;
use ndarray_linalg::*;
use std::cmp::min;

fn test(a: &Array2<f64>) {
fn test<T: Scalar + Lapack>(a: &Array2<T>) {
let (n, m) = a.dim();
let answer = a.clone();
println!("a = \n{:?}", a);
Expand All @@ -12,14 +12,14 @@ fn test(a: &Array2<f64>) {
println!("u = \n{:?}", &u);
println!("s = \n{:?}", &s);
println!("v = \n{:?}", &vt);
let mut sm = Array::zeros((n, m));
let mut sm = Array::<T, _>::zeros((n, m));
for i in 0..min(n, m) {
sm[(i, i)] = s[i];
sm[(i, i)] = T::from(s[i]).unwrap();
}
assert_close_l2!(&u.dot(&sm).dot(&vt), &answer, 1e-7);
assert_close_l2!(&u.dot(&sm).dot(&vt), &answer, T::real(1e-7));
}

fn test_no_vt(a: &Array2<f64>) {
fn test_no_vt<T: Scalar + Lapack>(a: &Array2<T>) {
let (n, _m) = a.dim();
println!("a = \n{:?}", a);
let (u, _s, vt): (_, Array1<_>, _) = a.svd(true, false).unwrap();
Expand All @@ -30,7 +30,7 @@ fn test_no_vt(a: &Array2<f64>) {
assert_eq!(u.dim().1, n);
}

fn test_no_u(a: &Array2<f64>) {
fn test_no_u<T: Scalar + Lapack>(a: &Array2<T>) {
let (_n, m) = a.dim();
println!("a = \n{:?}", a);
let (u, _s, vt): (_, Array1<_>, _) = a.svd(false, true).unwrap();
Expand All @@ -41,40 +41,52 @@ fn test_no_u(a: &Array2<f64>) {
assert_eq!(vt.dim().1, m);
}

fn test_diag_only(a: &Array2<f64>) {
fn test_diag_only<T: Scalar + Lapack>(a: &Array2<T>) {
println!("a = \n{:?}", a);
let (u, _s, vt): (_, Array1<_>, _) = a.svd(false, false).unwrap();
assert!(u.is_none());
assert!(vt.is_none());
}

macro_rules! test_svd_impl {
($test:ident, $n:expr, $m:expr) => {
($type:ty, $test:ident, $n:expr, $m:expr) => {
paste::item! {
#[test]
fn [<svd_ $test _ $n x $m>]() {
fn [<svd_ $type _ $test _ $n x $m>]() {
let a = random(($n, $m));
$test(&a);
$test::<$type>(&a);
}

#[test]
fn [<svd_ $test _ $n x $m _t>]() {
fn [<svd_ $type _ $test _ $n x $m _t>]() {
let a = random(($n, $m).f());
$test(&a);
$test::<$type>(&a);
}
}
};
}

test_svd_impl!(test, 3, 3);
test_svd_impl!(test_no_vt, 3, 3);
test_svd_impl!(test_no_u, 3, 3);
test_svd_impl!(test_diag_only, 3, 3);
test_svd_impl!(test, 4, 3);
test_svd_impl!(test_no_vt, 4, 3);
test_svd_impl!(test_no_u, 4, 3);
test_svd_impl!(test_diag_only, 4, 3);
test_svd_impl!(test, 3, 4);
test_svd_impl!(test_no_vt, 3, 4);
test_svd_impl!(test_no_u, 3, 4);
test_svd_impl!(test_diag_only, 3, 4);
test_svd_impl!(f64, test, 3, 3);
test_svd_impl!(f64, test_no_vt, 3, 3);
test_svd_impl!(f64, test_no_u, 3, 3);
test_svd_impl!(f64, test_diag_only, 3, 3);
test_svd_impl!(f64, test, 4, 3);
test_svd_impl!(f64, test_no_vt, 4, 3);
test_svd_impl!(f64, test_no_u, 4, 3);
test_svd_impl!(f64, test_diag_only, 4, 3);
test_svd_impl!(f64, test, 3, 4);
test_svd_impl!(f64, test_no_vt, 3, 4);
test_svd_impl!(f64, test_no_u, 3, 4);
test_svd_impl!(f64, test_diag_only, 3, 4);
test_svd_impl!(c64, test, 3, 3);
test_svd_impl!(c64, test_no_vt, 3, 3);
test_svd_impl!(c64, test_no_u, 3, 3);
test_svd_impl!(c64, test_diag_only, 3, 3);
test_svd_impl!(c64, test, 4, 3);
test_svd_impl!(c64, test_no_vt, 4, 3);
test_svd_impl!(c64, test_no_u, 4, 3);
test_svd_impl!(c64, test_diag_only, 4, 3);
test_svd_impl!(c64, test, 3, 4);
test_svd_impl!(c64, test_no_vt, 3, 4);
test_svd_impl!(c64, test_no_u, 3, 4);
test_svd_impl!(c64, test_diag_only, 3, 4);