|
1 | 1 | //! Cholesky decomposition |
2 | 2 |
|
3 | 3 | use super::*; |
4 | | -use crate::{error::*, layout::MatrixLayout}; |
| 4 | +use crate::{error::*, layout::*}; |
5 | 5 | use cauchy::*; |
6 | 6 |
|
7 | 7 | pub trait Cholesky_: Sized { |
8 | 8 | /// Cholesky: wrapper of `*potrf` |
9 | 9 | /// |
10 | 10 | /// **Warning: Only the portion of `a` corresponding to `UPLO` is written.** |
11 | | - unsafe fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; |
| 11 | + fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; |
| 12 | + |
12 | 13 | /// Wrapper of `*potri` |
13 | 14 | /// |
14 | 15 | /// **Warning: Only the portion of `a` corresponding to `UPLO` is written.** |
15 | | - unsafe fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; |
| 16 | + fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; |
| 17 | + |
16 | 18 | /// Wrapper of `*potrs` |
17 | | - unsafe fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) |
18 | | - -> Result<()>; |
| 19 | + fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) -> Result<()>; |
19 | 20 | } |
20 | 21 |
|
21 | 22 | macro_rules! impl_cholesky { |
22 | 23 | ($scalar:ty, $trf:path, $tri:path, $trs:path) => { |
23 | 24 | impl Cholesky_ for $scalar { |
24 | | - unsafe fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { |
| 25 | + fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { |
25 | 26 | let (n, _) = l.size(); |
26 | | - $trf(l.lapacke_layout(), uplo as u8, n, a, n).as_lapack_result()?; |
| 27 | + if matches!(l, MatrixLayout::C { .. }) { |
| 28 | + square_transpose(l, a); |
| 29 | + } |
| 30 | + let mut info = 0; |
| 31 | + unsafe { |
| 32 | + $trf(uplo as u8, n, a, n, &mut info); |
| 33 | + } |
| 34 | + info.as_lapack_result()?; |
| 35 | + if matches!(l, MatrixLayout::C { .. }) { |
| 36 | + square_transpose(l, a); |
| 37 | + } |
27 | 38 | Ok(()) |
28 | 39 | } |
29 | 40 |
|
30 | | - unsafe fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { |
| 41 | + fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { |
31 | 42 | let (n, _) = l.size(); |
32 | | - $tri(l.lapacke_layout(), uplo as u8, n, a, l.lda()).as_lapack_result()?; |
| 43 | + if matches!(l, MatrixLayout::C { .. }) { |
| 44 | + square_transpose(l, a); |
| 45 | + } |
| 46 | + let mut info = 0; |
| 47 | + unsafe { |
| 48 | + $tri(uplo as u8, n, a, l.lda(), &mut info); |
| 49 | + } |
| 50 | + info.as_lapack_result()?; |
| 51 | + if matches!(l, MatrixLayout::C { .. }) { |
| 52 | + square_transpose(l, a); |
| 53 | + } |
33 | 54 | Ok(()) |
34 | 55 | } |
35 | 56 |
|
36 | | - unsafe fn solve_cholesky( |
| 57 | + fn solve_cholesky( |
37 | 58 | l: MatrixLayout, |
38 | | - uplo: UPLO, |
| 59 | + mut uplo: UPLO, |
39 | 60 | a: &[Self], |
40 | 61 | b: &mut [Self], |
41 | 62 | ) -> Result<()> { |
42 | 63 | let (n, _) = l.size(); |
43 | 64 | let nrhs = 1; |
44 | | - let ldb = 1; |
45 | | - $trs(l.lapacke_layout(), uplo as u8, n, nrhs, a, l.lda(), b, ldb) |
46 | | - .as_lapack_result()?; |
| 65 | + let mut info = 0; |
| 66 | + if matches!(l, MatrixLayout::C { .. }) { |
| 67 | + uplo = uplo.t(); |
| 68 | + for val in b.iter_mut() { |
| 69 | + *val = val.conj(); |
| 70 | + } |
| 71 | + } |
| 72 | + unsafe { |
| 73 | + $trs(uplo as u8, n, nrhs, a, l.lda(), b, n, &mut info); |
| 74 | + } |
| 75 | + info.as_lapack_result()?; |
| 76 | + if matches!(l, MatrixLayout::C { .. }) { |
| 77 | + for val in b.iter_mut() { |
| 78 | + *val = val.conj(); |
| 79 | + } |
| 80 | + } |
47 | 81 | Ok(()) |
48 | 82 | } |
49 | 83 | } |
50 | 84 | }; |
51 | 85 | } // end macro_rules |
52 | 86 |
|
53 | | -impl_cholesky!(f64, lapacke::dpotrf, lapacke::dpotri, lapacke::dpotrs); |
54 | | -impl_cholesky!(f32, lapacke::spotrf, lapacke::spotri, lapacke::spotrs); |
55 | | -impl_cholesky!(c64, lapacke::zpotrf, lapacke::zpotri, lapacke::zpotrs); |
56 | | -impl_cholesky!(c32, lapacke::cpotrf, lapacke::cpotri, lapacke::cpotrs); |
| 87 | +impl_cholesky!(f64, lapack::dpotrf, lapack::dpotri, lapack::dpotrs); |
| 88 | +impl_cholesky!(f32, lapack::spotrf, lapack::spotri, lapack::spotrs); |
| 89 | +impl_cholesky!(c64, lapack::zpotrf, lapack::zpotri, lapack::zpotrs); |
| 90 | +impl_cholesky!(c32, lapack::cpotrf, lapack::cpotri, lapack::cpotrs); |
0 commit comments