11use super :: * ;
22use crate :: { error:: * , layout:: MatrixLayout } ;
33use cauchy:: * ;
4- use num_traits:: Zero ;
4+ use num_traits:: { ToPrimitive , Zero } ;
55
66/// Specifies how many of the columns of *U* and rows of *V*ᵀ are computed and returned.
77///
@@ -21,56 +21,106 @@ pub trait SVDDC_: Scalar {
2121 unsafe fn svddc ( l : MatrixLayout , jobz : UVTFlag , a : & mut [ Self ] ) -> Result < SVDOutput < Self > > ;
2222}
2323
24- macro_rules! impl_svdd {
25- ( $scalar: ty, $gesdd: path) => {
24+ macro_rules! impl_svddc {
25+ ( @real, $scalar: ty, $gesdd: path) => {
26+ impl_svddc!( @body, $scalar, $gesdd, ) ;
27+ } ;
28+ ( @complex, $scalar: ty, $gesdd: path) => {
29+ impl_svddc!( @body, $scalar, $gesdd, rwork) ;
30+ } ;
31+ ( @body, $scalar: ty, $gesdd: path, $( $rwork_ident: ident) ,* ) => {
2632 impl SVDDC_ for $scalar {
2733 unsafe fn svddc(
2834 l: MatrixLayout ,
2935 jobz: UVTFlag ,
3036 mut a: & mut [ Self ] ,
3137 ) -> Result <SVDOutput <Self >> {
32- let ( m, n) = l. size( ) ;
38+ let m = l. lda( ) ;
39+ let n = l. len( ) ;
3340 let k = m. min( n) ;
34- let lda = l. lda( ) ;
35- let ( ucol, vtrow) = match jobz {
36- UVTFlag :: Full => ( m, n) ,
41+ let mut s = vec![ Self :: Real :: zero( ) ; k as usize ] ;
42+
43+ let ( u_col, vt_row) = match jobz {
44+ UVTFlag :: Full | UVTFlag :: None => ( m, n) ,
3745 UVTFlag :: Some => ( k, k) ,
38- UVTFlag :: None => ( 1 , 1 ) ,
3946 } ;
40- let mut s = vec![ Self :: Real :: zero( ) ; k. max( 1 ) as usize ] ;
41- let mut u = vec![ Self :: zero( ) ; ( m * ucol) . max( 1 ) as usize ] ;
42- let ldu = l. resized( m, ucol) . lda( ) ;
43- let mut vt = vec![ Self :: zero( ) ; ( vtrow * n) . max( 1 ) as usize ] ;
44- let ldvt = l. resized( vtrow, n) . lda( ) ;
47+ let ( mut u, mut vt) = match jobz {
48+ UVTFlag :: Full => (
49+ Some ( vec![ Self :: zero( ) ; ( m * m) as usize ] ) ,
50+ Some ( vec![ Self :: zero( ) ; ( n * n) as usize ] ) ,
51+ ) ,
52+ UVTFlag :: Some => (
53+ Some ( vec![ Self :: zero( ) ; ( m * u_col) as usize ] ) ,
54+ Some ( vec![ Self :: zero( ) ; ( n * vt_row) as usize ] ) ,
55+ ) ,
56+ UVTFlag :: None => ( None , None ) ,
57+ } ;
58+
59+ $( // for complex only
60+ let mx = n. max( m) as usize ;
61+ let mn = n. min( m) as usize ;
62+ let lrwork = match jobz {
63+ UVTFlag :: None => 7 * mn,
64+ _ => std:: cmp:: max( 5 * mn* mn + 5 * mn, 2 * mx* mn + 2 * mn* mn + mn) ,
65+ } ;
66+ let mut $rwork_ident = vec![ Self :: Real :: zero( ) ; lrwork] ;
67+ ) *
68+
69+ // eval work size
70+ let mut info = 0 ;
71+ let mut iwork = vec![ 0 ; 8 * k as usize ] ;
72+ let mut work_size = [ Self :: zero( ) ] ;
4573 $gesdd(
46- l. lapacke_layout( ) ,
4774 jobz as u8 ,
4875 m,
4976 n,
5077 & mut a,
51- lda ,
78+ m ,
5279 & mut s,
53- & mut u,
54- ldu,
55- & mut vt,
56- ldvt,
57- )
58- . as_lapack_result( ) ?;
59- Ok ( SVDOutput {
60- s,
61- u: if jobz == UVTFlag :: None { None } else { Some ( u) } ,
62- vt: if jobz == UVTFlag :: None {
63- None
64- } else {
65- Some ( vt)
66- } ,
67- } )
80+ u. as_mut( ) . map( |x| x. as_mut_slice( ) ) . unwrap_or( & mut [ ] ) ,
81+ m,
82+ vt. as_mut( ) . map( |x| x. as_mut_slice( ) ) . unwrap_or( & mut [ ] ) ,
83+ vt_row,
84+ & mut work_size,
85+ -1 ,
86+ $( & mut $rwork_ident, ) *
87+ & mut iwork,
88+ & mut info,
89+ ) ;
90+ info. as_lapack_result( ) ?;
91+
92+ // do svd
93+ let lwork = work_size[ 0 ] . to_usize( ) . unwrap( ) ;
94+ let mut work = vec![ Self :: zero( ) ; lwork] ;
95+ $gesdd(
96+ jobz as u8 ,
97+ m,
98+ n,
99+ & mut a,
100+ m,
101+ & mut s,
102+ u. as_mut( ) . map( |x| x. as_mut_slice( ) ) . unwrap_or( & mut [ ] ) ,
103+ m,
104+ vt. as_mut( ) . map( |x| x. as_mut_slice( ) ) . unwrap_or( & mut [ ] ) ,
105+ vt_row,
106+ & mut work,
107+ lwork as i32 ,
108+ $( & mut $rwork_ident, ) *
109+ & mut iwork,
110+ & mut info,
111+ ) ;
112+ info. as_lapack_result( ) ?;
113+
114+ match l {
115+ MatrixLayout :: F { .. } => Ok ( SVDOutput { s, u, vt } ) ,
116+ MatrixLayout :: C { .. } => Ok ( SVDOutput { s, u: vt, vt: u } ) ,
117+ }
68118 }
69119 }
70120 } ;
71121}
72122
73- impl_svdd ! ( f32 , lapacke :: sgesdd) ;
74- impl_svdd ! ( f64 , lapacke :: dgesdd) ;
75- impl_svdd ! ( c32, lapacke :: cgesdd) ;
76- impl_svdd ! ( c64, lapacke :: zgesdd) ;
123+ impl_svddc ! ( @real , f32 , lapack :: sgesdd) ;
124+ impl_svddc ! ( @real , f64 , lapack :: dgesdd) ;
125+ impl_svddc ! ( @complex , c32, lapack :: cgesdd) ;
126+ impl_svddc ! ( @complex , c64, lapack :: zgesdd) ;
0 commit comments