55use super :: * ;
66use crate :: { error:: * , layout:: MatrixLayout } ;
77use cauchy:: * ;
8+ use num_traits:: { ToPrimitive , Zero } ;
89
910pub trait Solveh_ : Sized {
1011 /// Bunch-Kaufman: wrapper of `*sytrf` and `*hetrf`
@@ -28,13 +29,39 @@ macro_rules! impl_solveh {
2829 let ( n, _) = l. size( ) ;
2930 let mut ipiv = vec![ 0 ; n as usize ] ;
3031 if n == 0 {
31- // Work around bug in LAPACKE functions.
32- Ok ( ipiv)
33- } else {
34- $trf( l. lapacke_layout( ) , uplo as u8 , n, a, l. lda( ) , & mut ipiv)
35- . as_lapack_result( ) ?;
36- Ok ( ipiv)
32+ return Ok ( Vec :: new( ) ) ;
3733 }
34+
35+ // calc work size
36+ let mut info = 0 ;
37+ let mut work_size = [ Self :: zero( ) ] ;
38+ $trf(
39+ uplo as u8 ,
40+ n,
41+ a,
42+ l. lda( ) ,
43+ & mut ipiv,
44+ & mut work_size,
45+ -1 ,
46+ & mut info,
47+ ) ;
48+ info. as_lapack_result( ) ?;
49+
50+ // actual
51+ let lwork = work_size[ 0 ] . to_usize( ) . unwrap( ) ;
52+ let mut work = vec![ Self :: zero( ) ; lwork] ;
53+ $trf(
54+ uplo as u8 ,
55+ n,
56+ a,
57+ l. lda( ) ,
58+ & mut ipiv,
59+ & mut work,
60+ lwork as i32 ,
61+ & mut info,
62+ ) ;
63+ info. as_lapack_result( ) ?;
64+ Ok ( ipiv)
3865 }
3966
4067 unsafe fn invh(
@@ -44,7 +71,10 @@ macro_rules! impl_solveh {
4471 ipiv: & Pivot ,
4572 ) -> Result <( ) > {
4673 let ( n, _) = l. size( ) ;
47- $tri( l. lapacke_layout( ) , uplo as u8 , n, a, l. lda( ) , ipiv) . as_lapack_result( ) ?;
74+ let mut info = 0 ;
75+ let mut work = vec![ Self :: zero( ) ; n as usize ] ;
76+ $tri( uplo as u8 , n, a, l. lda( ) , ipiv, & mut work, & mut info) ;
77+ info. as_lapack_result( ) ?;
4878 Ok ( ( ) )
4979 }
5080
@@ -56,30 +86,16 @@ macro_rules! impl_solveh {
5686 b: & mut [ Self ] ,
5787 ) -> Result <( ) > {
5888 let ( n, _) = l. size( ) ;
59- let nrhs = 1 ;
60- let ldb = match l {
61- MatrixLayout :: C { .. } => 1 ,
62- MatrixLayout :: F { .. } => n,
63- } ;
64- $trs(
65- l. lapacke_layout( ) ,
66- uplo as u8 ,
67- n,
68- nrhs,
69- a,
70- l. lda( ) ,
71- ipiv,
72- b,
73- ldb,
74- )
75- . as_lapack_result( ) ?;
89+ let mut info = 0 ;
90+ $trs( uplo as u8 , n, 1 , a, l. lda( ) , ipiv, b, n, & mut info) ;
91+ info. as_lapack_result( ) ?;
7692 Ok ( ( ) )
7793 }
7894 }
7995 } ;
8096} // impl_solveh!
8197
82- impl_solveh ! ( f64 , lapacke :: dsytrf, lapacke :: dsytri, lapacke :: dsytrs) ;
83- impl_solveh ! ( f32 , lapacke :: ssytrf, lapacke :: ssytri, lapacke :: ssytrs) ;
84- impl_solveh ! ( c64, lapacke :: zhetrf, lapacke :: zhetri, lapacke :: zhetrs) ;
85- impl_solveh ! ( c32, lapacke :: chetrf, lapacke :: chetri, lapacke :: chetrs) ;
98+ impl_solveh ! ( f64 , lapack :: dsytrf, lapack :: dsytri, lapack :: dsytrs) ;
99+ impl_solveh ! ( f32 , lapack :: ssytrf, lapack :: ssytri, lapack :: ssytrs) ;
100+ impl_solveh ! ( c64, lapack :: zhetrf, lapack :: zhetri, lapack :: zhetrs) ;
101+ impl_solveh ! ( c32, lapack :: chetrf, lapack :: chetri, lapack :: chetrs) ;
0 commit comments