@@ -21,8 +21,14 @@ pub trait SVDDC_: Scalar {
2121 unsafe fn svddc ( l : MatrixLayout , jobz : UVTFlag , a : & mut [ Self ] ) -> Result < SVDOutput < Self > > ;
2222}
2323
24- macro_rules! impl_svddc_real {
25- ( $scalar: ty, $gesdd: path) => {
24+ macro_rules! impl_svddc {
25+ ( @real, $scalar: ty, $gesdd: path) => {
26+ impl_svddc!( @body, $scalar, $gesdd, ) ;
27+ } ;
28+ ( @complex, $scalar: ty, $gesdd: path) => {
29+ impl_svddc!( @body, $scalar, $gesdd, rwork) ;
30+ } ;
31+ ( @body, $scalar: ty, $gesdd: path, $( $rwork_ident: ident) ,* ) => {
2632 impl SVDDC_ for $scalar {
2733 unsafe fn svddc(
2834 l: MatrixLayout ,
@@ -50,6 +56,16 @@ macro_rules! impl_svddc_real {
5056 UVTFlag :: None => ( None , None ) ,
5157 } ;
5258
59+ $( // for complex only
60+ let mx = n. max( m) as usize ;
61+ let mn = n. min( m) as usize ;
62+ let lrwork = match jobz {
63+ UVTFlag :: None => 7 * mn,
64+ _ => std:: cmp:: max( 5 * mn* mn + 5 * mn, 2 * mx* mn + 2 * mn* mn + mn) ,
65+ } ;
66+ let mut $rwork_ident = vec![ Self :: Real :: zero( ) ; lrwork] ;
67+ ) *
68+
5369 // eval work size
5470 let mut info = 0 ;
5571 let mut iwork = vec![ 0 ; 8 * k as usize ] ;
@@ -67,6 +83,7 @@ macro_rules! impl_svddc_real {
6783 vt_row,
6884 & mut work_size,
6985 -1 ,
86+ $( & mut $rwork_ident, ) *
7087 & mut iwork,
7188 & mut info,
7289 ) ;
@@ -88,6 +105,7 @@ macro_rules! impl_svddc_real {
88105 vt_row,
89106 & mut work,
90107 lwork as i32 ,
108+ $( & mut $rwork_ident, ) *
91109 & mut iwork,
92110 & mut info,
93111 ) ;
@@ -102,57 +120,7 @@ macro_rules! impl_svddc_real {
102120 } ;
103121}
104122
105- impl_svddc_real ! ( f32 , lapack:: sgesdd) ;
106- impl_svddc_real ! ( f64 , lapack:: dgesdd) ;
107-
108- macro_rules! impl_svddc_complex {
109- ( $scalar: ty, $gesdd: path) => {
110- impl SVDDC_ for $scalar {
111- unsafe fn svddc(
112- l: MatrixLayout ,
113- jobz: UVTFlag ,
114- mut a: & mut [ Self ] ,
115- ) -> Result <SVDOutput <Self >> {
116- let ( m, n) = l. size( ) ;
117- let k = m. min( n) ;
118- let lda = l. lda( ) ;
119- let ( ucol, vtrow) = match jobz {
120- UVTFlag :: Full => ( m, n) ,
121- UVTFlag :: Some => ( k, k) ,
122- UVTFlag :: None => ( 1 , 1 ) ,
123- } ;
124- let mut s = vec![ Self :: Real :: zero( ) ; k. max( 1 ) as usize ] ;
125- let mut u = vec![ Self :: zero( ) ; ( m * ucol) . max( 1 ) as usize ] ;
126- let ldu = l. resized( m, ucol) . lda( ) ;
127- let mut vt = vec![ Self :: zero( ) ; ( vtrow * n) . max( 1 ) as usize ] ;
128- let ldvt = l. resized( vtrow, n) . lda( ) ;
129- $gesdd(
130- l. lapacke_layout( ) ,
131- jobz as u8 ,
132- m,
133- n,
134- & mut a,
135- lda,
136- & mut s,
137- & mut u,
138- ldu,
139- & mut vt,
140- ldvt,
141- )
142- . as_lapack_result( ) ?;
143- Ok ( SVDOutput {
144- s,
145- u: if jobz == UVTFlag :: None { None } else { Some ( u) } ,
146- vt: if jobz == UVTFlag :: None {
147- None
148- } else {
149- Some ( vt)
150- } ,
151- } )
152- }
153- }
154- } ;
155- }
156-
157- impl_svddc_complex ! ( c32, lapacke:: cgesdd) ;
158- impl_svddc_complex ! ( c64, lapacke:: zgesdd) ;
123+ impl_svddc ! ( @real, f32 , lapack:: sgesdd) ;
124+ impl_svddc ! ( @real, f64 , lapack:: dgesdd) ;
125+ impl_svddc ! ( @complex, c32, lapack:: cgesdd) ;
126+ impl_svddc ! ( @complex, c64, lapack:: zgesdd) ;
0 commit comments