33use super :: * ;
44use crate :: { error:: * , layout:: MatrixLayout } ;
55use cauchy:: * ;
6- use num_traits:: Zero ;
6+ use num_traits:: { ToPrimitive , Zero } ;
77
88pub trait Solve_ : Scalar + Sized {
99 /// Computes the LU factorization of a general `m x n` matrix `a` using
@@ -14,59 +14,55 @@ pub trait Solve_: Scalar + Sized {
1414 /// Error
1515 /// ------
1616 /// - `LapackComputationalFailure { return_code }` when the matrix is singular
17- /// - `U[(return_code-1, return_code-1)]` is exactly zero.
18- /// - Division by zero will occur if it is used to solve a system of equations .
17+ /// - Division by zero will occur if it is used to solve a system of equations
18+ /// because `U[(return_code-1, return_code-1)]` is exactly zero .
1919 fn lu ( l : MatrixLayout , a : & mut [ Self ] ) -> Result < Pivot > ;
2020
2121 fn inv ( l : MatrixLayout , a : & mut [ Self ] , p : & Pivot ) -> Result < ( ) > ;
2222
23- /// Estimates the the reciprocal of the condition number of the matrix in 1-norm.
24- ///
25- /// `anorm` should be the 1-norm of the matrix `a`.
26- fn rcond ( l : MatrixLayout , a : & [ Self ] , anorm : Self :: Real ) -> Result < Self :: Real > ;
27-
2823 fn solve ( l : MatrixLayout , t : Transpose , a : & [ Self ] , p : & Pivot , b : & mut [ Self ] ) -> Result < ( ) > ;
2924}
3025
3126macro_rules! impl_solve {
32- ( $scalar: ty, $getrf: path, $getri: path, $gecon : path , $ getrs: path) => {
27+ ( $scalar: ty, $getrf: path, $getri: path, $getrs: path) => {
3328 impl Solve_ for $scalar {
3429 fn lu( l: MatrixLayout , a: & mut [ Self ] ) -> Result <Pivot > {
3530 let ( row, col) = l. size( ) ;
31+ assert_eq!( a. len( ) as i32 , row * col) ;
3632 let k = :: std:: cmp:: min( row, col) ;
3733 let mut ipiv = vec![ 0 ; k as usize ] ;
38- unsafe {
39- $getrf( l. lapacke_layout( ) , row, col, a, l. lda( ) , & mut ipiv)
40- . as_lapack_result( ) ?;
41- }
34+ let mut info = 0 ;
35+ unsafe { $getrf( l. lda( ) , l. len( ) , a, l. lda( ) , & mut ipiv, & mut info) } ;
36+ info. as_lapack_result( ) ?;
4237 Ok ( ipiv)
4338 }
4439
4540 fn inv( l: MatrixLayout , a: & mut [ Self ] , ipiv: & Pivot ) -> Result <( ) > {
4641 let ( n, _) = l. size( ) ;
47- unsafe {
48- $getri( l. lapacke_layout( ) , n, a, l. lda( ) , ipiv) . as_lapack_result( ) ?;
49- }
50- Ok ( ( ) )
51- }
5242
53- fn rcond( l: MatrixLayout , a: & [ Self ] , anorm: Self :: Real ) -> Result <Self :: Real > {
54- let ( n, _) = l. size( ) ;
55- let mut rcond = Self :: Real :: zero( ) ;
43+ // calc work size
44+ let mut info = 0 ;
45+ let mut work_size = [ Self :: zero( ) ] ;
46+ unsafe { $getri( n, a, l. lda( ) , ipiv, & mut work_size, -1 , & mut info) } ;
47+ info. as_lapack_result( ) ?;
48+
49+ // actual
50+ let lwork = work_size[ 0 ] . to_usize( ) . unwrap( ) ;
51+ let mut work = vec![ Self :: zero( ) ; lwork] ;
5652 unsafe {
57- $gecon(
58- l. lapacke_layout( ) ,
59- NormType :: One as u8 ,
60- n,
53+ $getri(
54+ l. len( ) ,
6155 a,
6256 l. lda( ) ,
63- anorm,
64- & mut rcond,
57+ ipiv,
58+ & mut work,
59+ lwork as i32 ,
60+ & mut info,
6561 )
66- }
67- . as_lapack_result( ) ?;
62+ } ;
63+ info . as_lapack_result( ) ?;
6864
69- Ok ( rcond )
65+ Ok ( ( ) )
7066 }
7167
7268 fn solve(
@@ -76,54 +72,26 @@ macro_rules! impl_solve {
7672 ipiv: & Pivot ,
7773 b: & mut [ Self ] ,
7874 ) -> Result <( ) > {
75+ let t = match l {
76+ MatrixLayout :: C { .. } => match t {
77+ Transpose :: No => Transpose :: Transpose ,
78+ Transpose :: Transpose | Transpose :: Hermite => Transpose :: No ,
79+ } ,
80+ _ => t,
81+ } ;
7982 let ( n, _) = l. size( ) ;
8083 let nrhs = 1 ;
81- let ldb = 1 ;
82- unsafe {
83- $getrs(
84- l. lapacke_layout( ) ,
85- t as u8 ,
86- n,
87- nrhs,
88- a,
89- l. lda( ) ,
90- ipiv,
91- b,
92- ldb,
93- )
94- . as_lapack_result( ) ?;
95- }
84+ let ldb = l. lda( ) ;
85+ let mut info = 0 ;
86+ unsafe { $getrs( t as u8 , n, nrhs, a, l. lda( ) , ipiv, b, ldb, & mut info) } ;
87+ info. as_lapack_result( ) ?;
9688 Ok ( ( ) )
9789 }
9890 }
9991 } ;
10092} // impl_solve!
10193
102- impl_solve ! (
103- f64 ,
104- lapacke:: dgetrf,
105- lapacke:: dgetri,
106- lapacke:: dgecon,
107- lapacke:: dgetrs
108- ) ;
109- impl_solve ! (
110- f32 ,
111- lapacke:: sgetrf,
112- lapacke:: sgetri,
113- lapacke:: sgecon,
114- lapacke:: sgetrs
115- ) ;
116- impl_solve ! (
117- c64,
118- lapacke:: zgetrf,
119- lapacke:: zgetri,
120- lapacke:: zgecon,
121- lapacke:: zgetrs
122- ) ;
123- impl_solve ! (
124- c32,
125- lapacke:: cgetrf,
126- lapacke:: cgetri,
127- lapacke:: cgecon,
128- lapacke:: cgetrs
129- ) ;
94+ impl_solve ! ( f64 , lapack:: dgetrf, lapack:: dgetri, lapack:: dgetrs) ;
95+ impl_solve ! ( f32 , lapack:: sgetrf, lapack:: sgetri, lapack:: sgetrs) ;
96+ impl_solve ! ( c64, lapack:: zgetrf, lapack:: zgetri, lapack:: zgetrs) ;
97+ impl_solve ! ( c32, lapack:: cgetrf, lapack:: cgetri, lapack:: cgetrs) ;
0 commit comments