Skip to content

Commit acb88aa

Browse files
committed
Impl SVDDC_ for c32/c64
1 parent 797b71b commit acb88aa

File tree

1 file changed

+24
-56
lines changed

1 file changed

+24
-56
lines changed

lax/src/svddc.rs

Lines changed: 24 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,14 @@ pub trait SVDDC_: Scalar {
2121
unsafe fn svddc(l: MatrixLayout, jobz: UVTFlag, a: &mut [Self]) -> Result<SVDOutput<Self>>;
2222
}
2323

24-
macro_rules! impl_svddc_real {
25-
($scalar:ty, $gesdd:path) => {
24+
macro_rules! impl_svddc {
25+
(@real, $scalar:ty, $gesdd:path) => {
26+
impl_svddc!(@body, $scalar, $gesdd, );
27+
};
28+
(@complex, $scalar:ty, $gesdd:path) => {
29+
impl_svddc!(@body, $scalar, $gesdd, rwork);
30+
};
31+
(@body, $scalar:ty, $gesdd:path, $($rwork_ident:ident),*) => {
2632
impl SVDDC_ for $scalar {
2733
unsafe fn svddc(
2834
l: MatrixLayout,
@@ -50,6 +56,16 @@ macro_rules! impl_svddc_real {
5056
UVTFlag::None => (None, None),
5157
};
5258

59+
$( // for complex only
60+
let mx = n.max(m) as usize;
61+
let mn = n.min(m) as usize;
62+
let lrwork = match jobz {
63+
UVTFlag::None => 7 * mn,
64+
_ => std::cmp::max(5*mn*mn + 5*mn, 2*mx*mn + 2*mn*mn + mn),
65+
};
66+
let mut $rwork_ident = vec![Self::Real::zero(); lrwork];
67+
)*
68+
5369
// eval work size
5470
let mut info = 0;
5571
let mut iwork = vec![0; 8 * k as usize];
@@ -67,6 +83,7 @@ macro_rules! impl_svddc_real {
6783
vt_row,
6884
&mut work_size,
6985
-1,
86+
$(&mut $rwork_ident,)*
7087
&mut iwork,
7188
&mut info,
7289
);
@@ -88,6 +105,7 @@ macro_rules! impl_svddc_real {
88105
vt_row,
89106
&mut work,
90107
lwork as i32,
108+
$(&mut $rwork_ident,)*
91109
&mut iwork,
92110
&mut info,
93111
);
@@ -102,57 +120,7 @@ macro_rules! impl_svddc_real {
102120
};
103121
}
104122

105-
impl_svddc_real!(f32, lapack::sgesdd);
106-
impl_svddc_real!(f64, lapack::dgesdd);
107-
108-
macro_rules! impl_svddc_complex {
109-
($scalar:ty, $gesdd:path) => {
110-
impl SVDDC_ for $scalar {
111-
unsafe fn svddc(
112-
l: MatrixLayout,
113-
jobz: UVTFlag,
114-
mut a: &mut [Self],
115-
) -> Result<SVDOutput<Self>> {
116-
let (m, n) = l.size();
117-
let k = m.min(n);
118-
let lda = l.lda();
119-
let (ucol, vtrow) = match jobz {
120-
UVTFlag::Full => (m, n),
121-
UVTFlag::Some => (k, k),
122-
UVTFlag::None => (1, 1),
123-
};
124-
let mut s = vec![Self::Real::zero(); k.max(1) as usize];
125-
let mut u = vec![Self::zero(); (m * ucol).max(1) as usize];
126-
let ldu = l.resized(m, ucol).lda();
127-
let mut vt = vec![Self::zero(); (vtrow * n).max(1) as usize];
128-
let ldvt = l.resized(vtrow, n).lda();
129-
$gesdd(
130-
l.lapacke_layout(),
131-
jobz as u8,
132-
m,
133-
n,
134-
&mut a,
135-
lda,
136-
&mut s,
137-
&mut u,
138-
ldu,
139-
&mut vt,
140-
ldvt,
141-
)
142-
.as_lapack_result()?;
143-
Ok(SVDOutput {
144-
s,
145-
u: if jobz == UVTFlag::None { None } else { Some(u) },
146-
vt: if jobz == UVTFlag::None {
147-
None
148-
} else {
149-
Some(vt)
150-
},
151-
})
152-
}
153-
}
154-
};
155-
}
156-
157-
impl_svddc_complex!(c32, lapacke::cgesdd);
158-
impl_svddc_complex!(c64, lapacke::zgesdd);
123+
impl_svddc!(@real, f32, lapack::sgesdd);
124+
impl_svddc!(@real, f64, lapack::dgesdd);
125+
impl_svddc!(@complex, c32, lapack::cgesdd);
126+
impl_svddc!(@complex, c64, lapack::zgesdd);

0 commit comments

Comments
 (0)