|
1 | 1 | //! Singular-value decomposition (SVD) by divide-and-conquer (?gesdd) |
2 | 2 |
|
| 3 | +use super::{convert::*, error::*, layout::*, types::*}; |
3 | 4 | use ndarray::*; |
4 | 5 |
|
5 | | -use super::error::*; |
6 | | -use super::layout::*; |
7 | | -use super::types::*; |
8 | | - |
9 | 6 | pub use lapack::svddc::UVTFlag; |
10 | 7 |
|
11 | 8 | /// Singular-value decomposition of matrix (copying) by divide-and-conquer |
@@ -84,35 +81,22 @@ where |
84 | 81 | ) -> Result<(Option<Self::U>, Self::Sigma, Option<Self::VT>)> { |
85 | 82 | let l = self.layout()?; |
86 | 83 | let svd_res = unsafe { A::svddc(l, uvt_flag, self.as_allocated_mut()?)? }; |
87 | | - let (n, m) = l.size(); |
88 | | - let k = std::cmp::min(n, m); |
89 | | - let n = n as usize; |
90 | | - let m = m as usize; |
91 | | - let k = k as usize; |
| 84 | + let (m, n) = l.size(); |
| 85 | + let k = m.min(n); |
92 | 86 |
|
93 | 87 | let (u_col, vt_row) = match uvt_flag { |
94 | | - UVTFlag::Full => (n, m), |
| 88 | + UVTFlag::Full => (m, n), |
95 | 89 | UVTFlag::Some => (k, k), |
96 | 90 | UVTFlag::None => (0, 0), |
97 | 91 | }; |
98 | 92 |
|
99 | | - let u = svd_res.u.map(|u| { |
100 | | - assert_eq!(u.len(), n * u_col); |
101 | | - match l { |
102 | | - MatrixLayout::F { .. } => Array::from_shape_vec((n, u_col).f(), u), |
103 | | - MatrixLayout::C { .. } => Array::from_shape_vec((n, u_col), u), |
104 | | - } |
105 | | - .unwrap() |
106 | | - }); |
| 93 | + let u = svd_res |
| 94 | + .u |
| 95 | + .map(|u| into_matrix(l.resized(m, u_col), u).unwrap()); |
107 | 96 |
|
108 | | - let vt = svd_res.vt.map(|vt| { |
109 | | - assert_eq!(vt.len(), m * vt_row); |
110 | | - match l { |
111 | | - MatrixLayout::F { .. } => Array::from_shape_vec((vt_row, m).f(), vt), |
112 | | - MatrixLayout::C { .. } => Array::from_shape_vec((vt_row, m), vt), |
113 | | - } |
114 | | - .unwrap() |
115 | | - }); |
| 97 | + let vt = svd_res |
| 98 | + .vt |
| 99 | + .map(|vt| into_matrix(l.resized(vt_row, n), vt).unwrap()); |
116 | 100 |
|
117 | 101 | let s = ArrayBase::from(svd_res.s); |
118 | 102 | Ok((u, s, vt)) |
|
0 commit comments