Skip to content

Commit 842ad79

Browse files
committed
LuImpl, SolveImpl
1 parent 77235f4 commit 842ad79

File tree

2 files changed

+154
-177
lines changed

2 files changed

+154
-177
lines changed

lax/src/lib.rs

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,14 +90,14 @@ pub mod eigh;
9090
pub mod eigh_generalized;
9191
pub mod least_squares;
9292
pub mod qr;
93+
pub mod solve;
9394
pub mod svd;
9495
pub mod svddc;
9596

9697
mod alloc;
9798
mod cholesky;
9899
mod opnorm;
99100
mod rcond;
100-
mod solve;
101101
mod solveh;
102102
mod triangular;
103103
mod tridiagonal;
@@ -107,7 +107,6 @@ pub use self::flags::*;
107107
pub use self::least_squares::LeastSquaresOwned;
108108
pub use self::opnorm::*;
109109
pub use self::rcond::*;
110-
pub use self::solve::*;
111110
pub use self::solveh::*;
112111
pub use self::svd::{SvdOwned, SvdRef};
113112
pub use self::triangular::*;
@@ -122,7 +121,7 @@ pub type Pivot = Vec<i32>;
122121
#[cfg_attr(doc, katexit::katexit)]
123122
/// Trait for primitive types which implements LAPACK subroutines
124123
pub trait Lapack:
125-
OperatorNorm_ + Solve_ + Solveh_ + Cholesky_ + Triangular_ + Tridiagonal_ + Rcond_
124+
OperatorNorm_ + Solveh_ + Cholesky_ + Triangular_ + Tridiagonal_ + Rcond_
126125
{
127126
/// Compute right eigenvalue and eigenvectors for a general matrix
128127
fn eig(
@@ -181,6 +180,51 @@ pub trait Lapack:
181180
b_layout: MatrixLayout,
182181
b: &mut [Self],
183182
) -> Result<LeastSquaresOwned<Self>>;
183+
184+
/// Computes the LU decomposition of a general $m \times n$ matrix
185+
/// with partial pivoting with row interchanges.
186+
///
187+
/// Output
188+
/// -------
189+
/// - $U$ and $L$ are stored in `a` after LU decomposition has succeeded.
190+
/// - $P$ is returned as [Pivot]
191+
///
192+
/// Error
193+
/// ------
194+
/// - if the matrix is singular
195+
/// - On this case, `return_code` in [Error::LapackComputationalFailure] means
196+
/// `return_code`-th diagonal element of $U$ becomes zero.
197+
///
198+
/// LAPACK correspondance
199+
/// ----------------------
200+
///
201+
/// | f32 | f64 | c32 | c64 |
202+
/// |:-------|:-------|:-------|:-------|
203+
/// | sgetrf | dgetrf | cgetrf | zgetrf |
204+
///
205+
fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot>;
206+
207+
/// Compute inverse matrix $A^{-1}$ from the output of LU-decomposition
208+
///
209+
/// LAPACK correspondance
210+
/// ----------------------
211+
///
212+
/// | f32 | f64 | c32 | c64 |
213+
/// |:-------|:-------|:-------|:-------|
214+
/// | sgetri | dgetri | cgetri | zgetri |
215+
///
216+
fn inv(l: MatrixLayout, a: &mut [Self], p: &Pivot) -> Result<()>;
217+
218+
/// Solve linear equations $Ax = b$ using the output of LU-decomposition
219+
///
220+
/// LAPACK correspondance
221+
/// ----------------------
222+
///
223+
/// | f32 | f64 | c32 | c64 |
224+
/// |:-------|:-------|:-------|:-------|
225+
/// | sgetrs | dgetrs | cgetrs | zgetrs |
226+
///
227+
fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>;
184228
}
185229

186230
macro_rules! impl_lapack {
@@ -276,6 +320,29 @@ macro_rules! impl_lapack {
276320
let work = LeastSquaresWork::<$s>::new(a_layout, b_layout)?;
277321
work.eval(a, b)
278322
}
323+
324+
fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot> {
325+
use solve::*;
326+
LuImpl::lu(l, a)
327+
}
328+
329+
fn inv(l: MatrixLayout, a: &mut [Self], p: &Pivot) -> Result<()> {
330+
use solve::*;
331+
let mut work = InvWork::<$s>::new(l)?;
332+
work.calc(a, p)?;
333+
Ok(())
334+
}
335+
336+
fn solve(
337+
l: MatrixLayout,
338+
t: Transpose,
339+
a: &[Self],
340+
p: &Pivot,
341+
b: &mut [Self],
342+
) -> Result<()> {
343+
use solve::*;
344+
SolveImpl::solve(l, t, a, p, b)
345+
}
279346
}
280347
};
281348
}

lax/src/solve.rs

Lines changed: 84 additions & 174 deletions
Original file line numberDiff line numberDiff line change
@@ -17,119 +17,13 @@ use num_traits::{ToPrimitive, Zero};
1717
/// 2. Solve linear equation $Ax = b$ or compute inverse matrix $A^{-1}$
1818
/// using the output of LU decomposition.
1919
///
20-
pub trait Solve_: Scalar + Sized {
21-
/// Computes the LU decomposition of a general $m \times n$ matrix
22-
/// with partial pivoting with row interchanges.
23-
///
24-
/// Output
25-
/// -------
26-
/// - $U$ and $L$ are stored in `a` after LU decomposition has succeeded.
27-
/// - $P$ is returned as [Pivot]
28-
///
29-
/// Error
30-
/// ------
31-
/// - if the matrix is singular
32-
/// - On this case, `return_code` in [Error::LapackComputationalFailure] means
33-
/// `return_code`-th diagonal element of $U$ becomes zero.
34-
///
35-
/// LAPACK correspondance
36-
/// ----------------------
37-
///
38-
/// | f32 | f64 | c32 | c64 |
39-
/// |:-------|:-------|:-------|:-------|
40-
/// | sgetrf | dgetrf | cgetrf | zgetrf |
41-
///
20+
pub trait LuImpl: Scalar {
4221
fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot>;
43-
44-
/// Compute inverse matrix $A^{-1}$ from the output of LU-decomposition
45-
///
46-
/// LAPACK correspondance
47-
/// ----------------------
48-
///
49-
/// | f32 | f64 | c32 | c64 |
50-
/// |:-------|:-------|:-------|:-------|
51-
/// | sgetri | dgetri | cgetri | zgetri |
52-
///
53-
fn inv(l: MatrixLayout, a: &mut [Self], p: &Pivot) -> Result<()>;
54-
55-
/// Solve linear equations $Ax = b$ using the output of LU-decomposition
56-
///
57-
/// LAPACK correspondance
58-
/// ----------------------
59-
///
60-
/// | f32 | f64 | c32 | c64 |
61-
/// |:-------|:-------|:-------|:-------|
62-
/// | sgetrs | dgetrs | cgetrs | zgetrs |
63-
///
64-
fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>;
65-
}
66-
67-
pub struct InvWork<T: Scalar> {
68-
pub layout: MatrixLayout,
69-
pub work: Vec<MaybeUninit<T>>,
70-
}
71-
72-
pub trait InvWorkImpl: Sized {
73-
type Elem: Scalar;
74-
fn new(layout: MatrixLayout) -> Result<Self>;
75-
fn calc(&mut self, a: &mut [Self::Elem], p: &Pivot) -> Result<()>;
76-
}
77-
78-
macro_rules! impl_inv_work {
79-
($s:ty, $tri:path) => {
80-
impl InvWorkImpl for InvWork<$s> {
81-
type Elem = $s;
82-
83-
fn new(layout: MatrixLayout) -> Result<Self> {
84-
let (n, _) = layout.size();
85-
let mut info = 0;
86-
let mut work_size = [Self::Elem::zero()];
87-
unsafe {
88-
$tri(
89-
&n,
90-
std::ptr::null_mut(),
91-
&layout.lda(),
92-
std::ptr::null(),
93-
AsPtr::as_mut_ptr(&mut work_size),
94-
&(-1),
95-
&mut info,
96-
)
97-
};
98-
info.as_lapack_result()?;
99-
let lwork = work_size[0].to_usize().unwrap();
100-
let work = vec_uninit(lwork);
101-
Ok(InvWork { layout, work })
102-
}
103-
104-
fn calc(&mut self, a: &mut [Self::Elem], ipiv: &Pivot) -> Result<()> {
105-
let lwork = self.work.len().to_i32().unwrap();
106-
let mut info = 0;
107-
unsafe {
108-
$tri(
109-
&self.layout.len(),
110-
AsPtr::as_mut_ptr(a),
111-
&self.layout.lda(),
112-
ipiv.as_ptr(),
113-
AsPtr::as_mut_ptr(&mut self.work),
114-
&lwork,
115-
&mut info,
116-
)
117-
};
118-
info.as_lapack_result()?;
119-
Ok(())
120-
}
121-
}
122-
};
12322
}
12423

125-
impl_inv_work!(c64, lapack_sys::zgetri_);
126-
impl_inv_work!(c32, lapack_sys::cgetri_);
127-
impl_inv_work!(f64, lapack_sys::dgetri_);
128-
impl_inv_work!(f32, lapack_sys::sgetri_);
129-
130-
macro_rules! impl_solve {
131-
($scalar:ty, $getrf:path, $getri:path, $getrs:path) => {
132-
impl Solve_ for $scalar {
24+
macro_rules! impl_lu {
25+
($scalar:ty, $getrf:path) => {
26+
impl LuImpl for $scalar {
13327
fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot> {
13428
let (row, col) = l.size();
13529
assert_eq!(a.len() as i32, row * col);
@@ -154,49 +48,22 @@ macro_rules! impl_solve {
15448
let ipiv = unsafe { ipiv.assume_init() };
15549
Ok(ipiv)
15650
}
51+
}
52+
};
53+
}
15754

158-
fn inv(l: MatrixLayout, a: &mut [Self], ipiv: &Pivot) -> Result<()> {
159-
let (n, _) = l.size();
160-
if n == 0 {
161-
// Do nothing for empty matrices.
162-
return Ok(());
163-
}
164-
165-
// calc work size
166-
let mut info = 0;
167-
let mut work_size = [Self::zero()];
168-
unsafe {
169-
$getri(
170-
&n,
171-
AsPtr::as_mut_ptr(a),
172-
&l.lda(),
173-
ipiv.as_ptr(),
174-
AsPtr::as_mut_ptr(&mut work_size),
175-
&(-1),
176-
&mut info,
177-
)
178-
};
179-
info.as_lapack_result()?;
180-
181-
// actual
182-
let lwork = work_size[0].to_usize().unwrap();
183-
let mut work: Vec<MaybeUninit<Self>> = vec_uninit(lwork);
184-
unsafe {
185-
$getri(
186-
&l.len(),
187-
AsPtr::as_mut_ptr(a),
188-
&l.lda(),
189-
ipiv.as_ptr(),
190-
AsPtr::as_mut_ptr(&mut work),
191-
&(lwork as i32),
192-
&mut info,
193-
)
194-
};
195-
info.as_lapack_result()?;
55+
impl_lu!(c64, lapack_sys::zgetrf_);
56+
impl_lu!(c32, lapack_sys::cgetrf_);
57+
impl_lu!(f64, lapack_sys::dgetrf_);
58+
impl_lu!(f32, lapack_sys::sgetrf_);
19659

197-
Ok(())
198-
}
60+
pub trait SolveImpl: Scalar {
61+
fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>;
62+
}
19963

64+
macro_rules! impl_solve {
65+
($scalar:ty, $getrs:path) => {
66+
impl SolveImpl for $scalar {
20067
fn solve(
20168
l: MatrixLayout,
20269
t: Transpose,
@@ -266,27 +133,70 @@ macro_rules! impl_solve {
266133
};
267134
} // impl_solve!
268135

269-
impl_solve!(
270-
f64,
271-
lapack_sys::dgetrf_,
272-
lapack_sys::dgetri_,
273-
lapack_sys::dgetrs_
274-
);
275-
impl_solve!(
276-
f32,
277-
lapack_sys::sgetrf_,
278-
lapack_sys::sgetri_,
279-
lapack_sys::sgetrs_
280-
);
281-
impl_solve!(
282-
c64,
283-
lapack_sys::zgetrf_,
284-
lapack_sys::zgetri_,
285-
lapack_sys::zgetrs_
286-
);
287-
impl_solve!(
288-
c32,
289-
lapack_sys::cgetrf_,
290-
lapack_sys::cgetri_,
291-
lapack_sys::cgetrs_
292-
);
136+
impl_solve!(f64, lapack_sys::dgetrs_);
137+
impl_solve!(f32, lapack_sys::sgetrs_);
138+
impl_solve!(c64, lapack_sys::zgetrs_);
139+
impl_solve!(c32, lapack_sys::cgetrs_);
140+
141+
pub struct InvWork<T: Scalar> {
142+
pub layout: MatrixLayout,
143+
pub work: Vec<MaybeUninit<T>>,
144+
}
145+
146+
pub trait InvWorkImpl: Sized {
147+
type Elem: Scalar;
148+
fn new(layout: MatrixLayout) -> Result<Self>;
149+
fn calc(&mut self, a: &mut [Self::Elem], p: &Pivot) -> Result<()>;
150+
}
151+
152+
macro_rules! impl_inv_work {
153+
($s:ty, $tri:path) => {
154+
impl InvWorkImpl for InvWork<$s> {
155+
type Elem = $s;
156+
157+
fn new(layout: MatrixLayout) -> Result<Self> {
158+
let (n, _) = layout.size();
159+
let mut info = 0;
160+
let mut work_size = [Self::Elem::zero()];
161+
unsafe {
162+
$tri(
163+
&n,
164+
std::ptr::null_mut(),
165+
&layout.lda(),
166+
std::ptr::null(),
167+
AsPtr::as_mut_ptr(&mut work_size),
168+
&(-1),
169+
&mut info,
170+
)
171+
};
172+
info.as_lapack_result()?;
173+
let lwork = work_size[0].to_usize().unwrap();
174+
let work = vec_uninit(lwork);
175+
Ok(InvWork { layout, work })
176+
}
177+
178+
fn calc(&mut self, a: &mut [Self::Elem], ipiv: &Pivot) -> Result<()> {
179+
let lwork = self.work.len().to_i32().unwrap();
180+
let mut info = 0;
181+
unsafe {
182+
$tri(
183+
&self.layout.len(),
184+
AsPtr::as_mut_ptr(a),
185+
&self.layout.lda(),
186+
ipiv.as_ptr(),
187+
AsPtr::as_mut_ptr(&mut self.work),
188+
&lwork,
189+
&mut info,
190+
)
191+
};
192+
info.as_lapack_result()?;
193+
Ok(())
194+
}
195+
}
196+
};
197+
}
198+
199+
impl_inv_work!(c64, lapack_sys::zgetri_);
200+
impl_inv_work!(c32, lapack_sys::cgetri_);
201+
impl_inv_work!(f64, lapack_sys::dgetri_);
202+
impl_inv_work!(f32, lapack_sys::sgetri_);

0 commit comments

Comments
 (0)