|
3 | 3 | use super::*; |
4 | 4 | use crate::{error::*, layout::MatrixLayout}; |
5 | 5 | use cauchy::*; |
6 | | -use num_traits::Zero; |
| 6 | +use num_traits::{ToPrimitive, Zero}; |
7 | 7 |
|
8 | | -/// Wraps `*getrf`, `*getri`, and `*getrs` |
9 | 8 | pub trait Solve_: Scalar + Sized { |
10 | 9 | /// Computes the LU factorization of a general `m x n` matrix `a` using |
11 | 10 | /// partial pivoting with row interchanges. |
12 | 11 | /// |
13 | | - /// If the result matches `Err(LinalgError::Lapack(LapackError { |
14 | | - /// return_code )) if return_code > 0`, then `U[(return_code-1, |
15 | | - /// return_code-1)]` is exactly zero. The factorization has been completed, |
16 | | - /// but the factor `U` is exactly singular, and division by zero will occur |
17 | | - /// if it is used to solve a system of equations. |
18 | | - unsafe fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot>; |
19 | | - unsafe fn inv(l: MatrixLayout, a: &mut [Self], p: &Pivot) -> Result<()>; |
20 | | - /// Estimates the the reciprocal of the condition number of the matrix in 1-norm. |
| 12 | + /// $ PA = LU $ |
21 | 13 | /// |
22 | | - /// `anorm` should be the 1-norm of the matrix `a`. |
23 | | - unsafe fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result<Self::Real>; |
24 | | - unsafe fn solve( |
25 | | - l: MatrixLayout, |
26 | | - t: Transpose, |
27 | | - a: &[Self], |
28 | | - p: &Pivot, |
29 | | - b: &mut [Self], |
30 | | - ) -> Result<()>; |
| 14 | + /// Error |
| 15 | + /// ------ |
| 16 | + /// - `LapackComputationalFailure { return_code }` when the matrix is singular |
| 17 | + /// - Division by zero will occur if it is used to solve a system of equations |
| 18 | + /// because `U[(return_code-1, return_code-1)]` is exactly zero. |
| 19 | + fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot>; |
| 20 | + |
| 21 | + fn inv(l: MatrixLayout, a: &mut [Self], p: &Pivot) -> Result<()>; |
| 22 | + |
| 23 | + fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>; |
31 | 24 | } |
32 | 25 |
|
33 | 26 | macro_rules! impl_solve { |
34 | | - ($scalar:ty, $getrf:path, $getri:path, $gecon:path, $getrs:path) => { |
| 27 | + ($scalar:ty, $getrf:path, $getri:path, $getrs:path) => { |
35 | 28 | impl Solve_ for $scalar { |
36 | | - unsafe fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot> { |
| 29 | + fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot> { |
37 | 30 | let (row, col) = l.size(); |
| 31 | + assert_eq!(a.len() as i32, row * col); |
| 32 | + if row == 0 || col == 0 { |
| 33 | + // Do nothing for empty matrix |
| 34 | + return Ok(Vec::new()); |
| 35 | + } |
38 | 36 | let k = ::std::cmp::min(row, col); |
39 | 37 | let mut ipiv = vec![0; k as usize]; |
40 | | - $getrf(l.lapacke_layout(), row, col, a, l.lda(), &mut ipiv).as_lapack_result()?; |
| 38 | + let mut info = 0; |
| 39 | + unsafe { $getrf(l.lda(), l.len(), a, l.lda(), &mut ipiv, &mut info) }; |
| 40 | + info.as_lapack_result()?; |
41 | 41 | Ok(ipiv) |
42 | 42 | } |
43 | 43 |
|
44 | | - unsafe fn inv(l: MatrixLayout, a: &mut [Self], ipiv: &Pivot) -> Result<()> { |
| 44 | + fn inv(l: MatrixLayout, a: &mut [Self], ipiv: &Pivot) -> Result<()> { |
45 | 45 | let (n, _) = l.size(); |
46 | | - $getri(l.lapacke_layout(), n, a, l.lda(), ipiv).as_lapack_result()?; |
47 | | - Ok(()) |
48 | | - } |
49 | 46 |
|
50 | | - unsafe fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result<Self::Real> { |
51 | | - let (n, _) = l.size(); |
52 | | - let mut rcond = Self::Real::zero(); |
53 | | - $gecon( |
54 | | - l.lapacke_layout(), |
55 | | - NormType::One as u8, |
56 | | - n, |
57 | | - a, |
58 | | - l.lda(), |
59 | | - anorm, |
60 | | - &mut rcond, |
61 | | - ) |
62 | | - .as_lapack_result()?; |
63 | | - Ok(rcond) |
| 47 | + // calc work size |
| 48 | + let mut info = 0; |
| 49 | + let mut work_size = [Self::zero()]; |
| 50 | + unsafe { $getri(n, a, l.lda(), ipiv, &mut work_size, -1, &mut info) }; |
| 51 | + info.as_lapack_result()?; |
| 52 | + |
| 53 | + // actual |
| 54 | + let lwork = work_size[0].to_usize().unwrap(); |
| 55 | + let mut work = vec![Self::zero(); lwork]; |
| 56 | + unsafe { |
| 57 | + $getri( |
| 58 | + l.len(), |
| 59 | + a, |
| 60 | + l.lda(), |
| 61 | + ipiv, |
| 62 | + &mut work, |
| 63 | + lwork as i32, |
| 64 | + &mut info, |
| 65 | + ) |
| 66 | + }; |
| 67 | + info.as_lapack_result()?; |
| 68 | + |
| 69 | + Ok(()) |
64 | 70 | } |
65 | 71 |
|
66 | | - unsafe fn solve( |
| 72 | + fn solve( |
67 | 73 | l: MatrixLayout, |
68 | 74 | t: Transpose, |
69 | 75 | a: &[Self], |
70 | 76 | ipiv: &Pivot, |
71 | 77 | b: &mut [Self], |
72 | 78 | ) -> Result<()> { |
| 79 | + let t = match l { |
| 80 | + MatrixLayout::C { .. } => match t { |
| 81 | + Transpose::No => Transpose::Transpose, |
| 82 | + Transpose::Transpose | Transpose::Hermite => Transpose::No, |
| 83 | + }, |
| 84 | + _ => t, |
| 85 | + }; |
73 | 86 | let (n, _) = l.size(); |
74 | 87 | let nrhs = 1; |
75 | | - let ldb = 1; |
76 | | - $getrs( |
77 | | - l.lapacke_layout(), |
78 | | - t as u8, |
79 | | - n, |
80 | | - nrhs, |
81 | | - a, |
82 | | - l.lda(), |
83 | | - ipiv, |
84 | | - b, |
85 | | - ldb, |
86 | | - ) |
87 | | - .as_lapack_result()?; |
| 88 | + let ldb = l.lda(); |
| 89 | + let mut info = 0; |
| 90 | + unsafe { $getrs(t as u8, n, nrhs, a, l.lda(), ipiv, b, ldb, &mut info) }; |
| 91 | + info.as_lapack_result()?; |
88 | 92 | Ok(()) |
89 | 93 | } |
90 | 94 | } |
91 | 95 | }; |
92 | 96 | } // impl_solve! |
93 | 97 |
|
94 | | -impl_solve!( |
95 | | - f64, |
96 | | - lapacke::dgetrf, |
97 | | - lapacke::dgetri, |
98 | | - lapacke::dgecon, |
99 | | - lapacke::dgetrs |
100 | | -); |
101 | | -impl_solve!( |
102 | | - f32, |
103 | | - lapacke::sgetrf, |
104 | | - lapacke::sgetri, |
105 | | - lapacke::sgecon, |
106 | | - lapacke::sgetrs |
107 | | -); |
108 | | -impl_solve!( |
109 | | - c64, |
110 | | - lapacke::zgetrf, |
111 | | - lapacke::zgetri, |
112 | | - lapacke::zgecon, |
113 | | - lapacke::zgetrs |
114 | | -); |
115 | | -impl_solve!( |
116 | | - c32, |
117 | | - lapacke::cgetrf, |
118 | | - lapacke::cgetri, |
119 | | - lapacke::cgecon, |
120 | | - lapacke::cgetrs |
121 | | -); |
| 98 | +impl_solve!(f64, lapack::dgetrf, lapack::dgetri, lapack::dgetrs); |
| 99 | +impl_solve!(f32, lapack::sgetrf, lapack::sgetri, lapack::sgetrs); |
| 100 | +impl_solve!(c64, lapack::zgetrf, lapack::zgetri, lapack::zgetrs); |
| 101 | +impl_solve!(c32, lapack::cgetrf, lapack::cgetri, lapack::cgetrs); |
0 commit comments