@@ -17,119 +17,13 @@ use num_traits::{ToPrimitive, Zero};
1717/// 2. Solve linear equation $Ax = b$ or compute inverse matrix $A^{-1}$
1818/// using the output of LU decomposition.
1919///
20- pub trait Solve_ : Scalar + Sized {
21- /// Computes the LU decomposition of a general $m \times n$ matrix
22- /// with partial pivoting with row interchanges.
23- ///
24- /// Output
25- /// -------
26- /// - $U$ and $L$ are stored in `a` after LU decomposition has succeeded.
27- /// - $P$ is returned as [Pivot]
28- ///
29- /// Error
30- /// ------
31- /// - if the matrix is singular
32- /// - On this case, `return_code` in [Error::LapackComputationalFailure] means
33- /// `return_code`-th diagonal element of $U$ becomes zero.
34- ///
35- /// LAPACK correspondance
36- /// ----------------------
37- ///
38- /// | f32 | f64 | c32 | c64 |
39- /// |:-------|:-------|:-------|:-------|
40- /// | sgetrf | dgetrf | cgetrf | zgetrf |
41- ///
20+ pub trait LuImpl : Scalar {
4221 fn lu ( l : MatrixLayout , a : & mut [ Self ] ) -> Result < Pivot > ;
43-
44- /// Compute inverse matrix $A^{-1}$ from the output of LU-decomposition
45- ///
46- /// LAPACK correspondance
47- /// ----------------------
48- ///
49- /// | f32 | f64 | c32 | c64 |
50- /// |:-------|:-------|:-------|:-------|
51- /// | sgetri | dgetri | cgetri | zgetri |
52- ///
53- fn inv ( l : MatrixLayout , a : & mut [ Self ] , p : & Pivot ) -> Result < ( ) > ;
54-
55- /// Solve linear equations $Ax = b$ using the output of LU-decomposition
56- ///
57- /// LAPACK correspondance
58- /// ----------------------
59- ///
60- /// | f32 | f64 | c32 | c64 |
61- /// |:-------|:-------|:-------|:-------|
62- /// | sgetrs | dgetrs | cgetrs | zgetrs |
63- ///
64- fn solve ( l : MatrixLayout , t : Transpose , a : & [ Self ] , p : & Pivot , b : & mut [ Self ] ) -> Result < ( ) > ;
65- }
66-
67- pub struct InvWork < T : Scalar > {
68- pub layout : MatrixLayout ,
69- pub work : Vec < MaybeUninit < T > > ,
70- }
71-
72- pub trait InvWorkImpl : Sized {
73- type Elem : Scalar ;
74- fn new ( layout : MatrixLayout ) -> Result < Self > ;
75- fn calc ( & mut self , a : & mut [ Self :: Elem ] , p : & Pivot ) -> Result < ( ) > ;
76- }
77-
78- macro_rules! impl_inv_work {
79- ( $s: ty, $tri: path) => {
80- impl InvWorkImpl for InvWork <$s> {
81- type Elem = $s;
82-
83- fn new( layout: MatrixLayout ) -> Result <Self > {
84- let ( n, _) = layout. size( ) ;
85- let mut info = 0 ;
86- let mut work_size = [ Self :: Elem :: zero( ) ] ;
87- unsafe {
88- $tri(
89- & n,
90- std:: ptr:: null_mut( ) ,
91- & layout. lda( ) ,
92- std:: ptr:: null( ) ,
93- AsPtr :: as_mut_ptr( & mut work_size) ,
94- & ( -1 ) ,
95- & mut info,
96- )
97- } ;
98- info. as_lapack_result( ) ?;
99- let lwork = work_size[ 0 ] . to_usize( ) . unwrap( ) ;
100- let work = vec_uninit( lwork) ;
101- Ok ( InvWork { layout, work } )
102- }
103-
104- fn calc( & mut self , a: & mut [ Self :: Elem ] , ipiv: & Pivot ) -> Result <( ) > {
105- let lwork = self . work. len( ) . to_i32( ) . unwrap( ) ;
106- let mut info = 0 ;
107- unsafe {
108- $tri(
109- & self . layout. len( ) ,
110- AsPtr :: as_mut_ptr( a) ,
111- & self . layout. lda( ) ,
112- ipiv. as_ptr( ) ,
113- AsPtr :: as_mut_ptr( & mut self . work) ,
114- & lwork,
115- & mut info,
116- )
117- } ;
118- info. as_lapack_result( ) ?;
119- Ok ( ( ) )
120- }
121- }
122- } ;
12322}
12423
125- impl_inv_work ! ( c64, lapack_sys:: zgetri_) ;
126- impl_inv_work ! ( c32, lapack_sys:: cgetri_) ;
127- impl_inv_work ! ( f64 , lapack_sys:: dgetri_) ;
128- impl_inv_work ! ( f32 , lapack_sys:: sgetri_) ;
129-
130- macro_rules! impl_solve {
131- ( $scalar: ty, $getrf: path, $getri: path, $getrs: path) => {
132- impl Solve_ for $scalar {
24+ macro_rules! impl_lu {
25+ ( $scalar: ty, $getrf: path) => {
26+ impl LuImpl for $scalar {
13327 fn lu( l: MatrixLayout , a: & mut [ Self ] ) -> Result <Pivot > {
13428 let ( row, col) = l. size( ) ;
13529 assert_eq!( a. len( ) as i32 , row * col) ;
@@ -154,49 +48,22 @@ macro_rules! impl_solve {
15448 let ipiv = unsafe { ipiv. assume_init( ) } ;
15549 Ok ( ipiv)
15650 }
51+ }
52+ } ;
53+ }
15754
158- fn inv( l: MatrixLayout , a: & mut [ Self ] , ipiv: & Pivot ) -> Result <( ) > {
159- let ( n, _) = l. size( ) ;
160- if n == 0 {
161- // Do nothing for empty matrices.
162- return Ok ( ( ) ) ;
163- }
164-
165- // calc work size
166- let mut info = 0 ;
167- let mut work_size = [ Self :: zero( ) ] ;
168- unsafe {
169- $getri(
170- & n,
171- AsPtr :: as_mut_ptr( a) ,
172- & l. lda( ) ,
173- ipiv. as_ptr( ) ,
174- AsPtr :: as_mut_ptr( & mut work_size) ,
175- & ( -1 ) ,
176- & mut info,
177- )
178- } ;
179- info. as_lapack_result( ) ?;
180-
181- // actual
182- let lwork = work_size[ 0 ] . to_usize( ) . unwrap( ) ;
183- let mut work: Vec <MaybeUninit <Self >> = vec_uninit( lwork) ;
184- unsafe {
185- $getri(
186- & l. len( ) ,
187- AsPtr :: as_mut_ptr( a) ,
188- & l. lda( ) ,
189- ipiv. as_ptr( ) ,
190- AsPtr :: as_mut_ptr( & mut work) ,
191- & ( lwork as i32 ) ,
192- & mut info,
193- )
194- } ;
195- info. as_lapack_result( ) ?;
55+ impl_lu ! ( c64, lapack_sys:: zgetrf_) ;
56+ impl_lu ! ( c32, lapack_sys:: cgetrf_) ;
57+ impl_lu ! ( f64 , lapack_sys:: dgetrf_) ;
58+ impl_lu ! ( f32 , lapack_sys:: sgetrf_) ;
19659
197- Ok ( ( ) )
198- }
60+ pub trait SolveImpl : Scalar {
61+ fn solve ( l : MatrixLayout , t : Transpose , a : & [ Self ] , p : & Pivot , b : & mut [ Self ] ) -> Result < ( ) > ;
62+ }
19963
64+ macro_rules! impl_solve {
65+ ( $scalar: ty, $getrs: path) => {
66+ impl SolveImpl for $scalar {
20067 fn solve(
20168 l: MatrixLayout ,
20269 t: Transpose ,
@@ -266,27 +133,70 @@ macro_rules! impl_solve {
266133 } ;
267134} // impl_solve!
268135
269- impl_solve ! (
270- f64 ,
271- lapack_sys:: dgetrf_,
272- lapack_sys:: dgetri_,
273- lapack_sys:: dgetrs_
274- ) ;
275- impl_solve ! (
276- f32 ,
277- lapack_sys:: sgetrf_,
278- lapack_sys:: sgetri_,
279- lapack_sys:: sgetrs_
280- ) ;
281- impl_solve ! (
282- c64,
283- lapack_sys:: zgetrf_,
284- lapack_sys:: zgetri_,
285- lapack_sys:: zgetrs_
286- ) ;
287- impl_solve ! (
288- c32,
289- lapack_sys:: cgetrf_,
290- lapack_sys:: cgetri_,
291- lapack_sys:: cgetrs_
292- ) ;
136+ impl_solve ! ( f64 , lapack_sys:: dgetrs_) ;
137+ impl_solve ! ( f32 , lapack_sys:: sgetrs_) ;
138+ impl_solve ! ( c64, lapack_sys:: zgetrs_) ;
139+ impl_solve ! ( c32, lapack_sys:: cgetrs_) ;
140+
141+ pub struct InvWork < T : Scalar > {
142+ pub layout : MatrixLayout ,
143+ pub work : Vec < MaybeUninit < T > > ,
144+ }
145+
146+ pub trait InvWorkImpl : Sized {
147+ type Elem : Scalar ;
148+ fn new ( layout : MatrixLayout ) -> Result < Self > ;
149+ fn calc ( & mut self , a : & mut [ Self :: Elem ] , p : & Pivot ) -> Result < ( ) > ;
150+ }
151+
152+ macro_rules! impl_inv_work {
153+ ( $s: ty, $tri: path) => {
154+ impl InvWorkImpl for InvWork <$s> {
155+ type Elem = $s;
156+
157+ fn new( layout: MatrixLayout ) -> Result <Self > {
158+ let ( n, _) = layout. size( ) ;
159+ let mut info = 0 ;
160+ let mut work_size = [ Self :: Elem :: zero( ) ] ;
161+ unsafe {
162+ $tri(
163+ & n,
164+ std:: ptr:: null_mut( ) ,
165+ & layout. lda( ) ,
166+ std:: ptr:: null( ) ,
167+ AsPtr :: as_mut_ptr( & mut work_size) ,
168+ & ( -1 ) ,
169+ & mut info,
170+ )
171+ } ;
172+ info. as_lapack_result( ) ?;
173+ let lwork = work_size[ 0 ] . to_usize( ) . unwrap( ) ;
174+ let work = vec_uninit( lwork) ;
175+ Ok ( InvWork { layout, work } )
176+ }
177+
178+ fn calc( & mut self , a: & mut [ Self :: Elem ] , ipiv: & Pivot ) -> Result <( ) > {
179+ let lwork = self . work. len( ) . to_i32( ) . unwrap( ) ;
180+ let mut info = 0 ;
181+ unsafe {
182+ $tri(
183+ & self . layout. len( ) ,
184+ AsPtr :: as_mut_ptr( a) ,
185+ & self . layout. lda( ) ,
186+ ipiv. as_ptr( ) ,
187+ AsPtr :: as_mut_ptr( & mut self . work) ,
188+ & lwork,
189+ & mut info,
190+ )
191+ } ;
192+ info. as_lapack_result( ) ?;
193+ Ok ( ( ) )
194+ }
195+ }
196+ } ;
197+ }
198+
199+ impl_inv_work ! ( c64, lapack_sys:: zgetri_) ;
200+ impl_inv_work ! ( c32, lapack_sys:: cgetri_) ;
201+ impl_inv_work ! ( f64 , lapack_sys:: dgetri_) ;
202+ impl_inv_work ! ( f32 , lapack_sys:: sgetri_) ;
0 commit comments