Skip to content

Commit 7f2b725

Browse files
committed
Fix Solve::solve_h for complex inputs
1 parent fb9603d commit 7f2b725

File tree

1 file changed

+35
-4
lines changed

1 file changed

+35
-4
lines changed

lax/src/solve.rs

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,18 +75,49 @@ macro_rules! impl_solve {
7575
ipiv: &Pivot,
7676
b: &mut [Self],
7777
) -> Result<()> {
78-
let t = match l {
78+
// If the array has C layout, then it needs to be handled
79+
// specially, since LAPACK expects a Fortran-layout array.
80+
// Reinterpreting a C layout array as Fortran layout is
81+
// equivalent to transposing it. So, we can handle the "no
82+
// transpose" and "transpose" cases by swapping to "transpose"
83+
// or "no transpose", respectively. For the "Hermite" case, we
84+
// can take advantage of the following:
85+
//
86+
// ```text
87+
// A^H x = b
88+
// ⟺ conj(A^T) x = b
89+
// ⟺ conj(conj(A^T) x) = conj(b)
90+
// ⟺ conj(conj(A^T)) conj(x) = conj(b)
91+
// ⟺ A^T conj(x) = conj(b)
92+
// ```
93+
//
94+
// So, we can handle this case by switching to "no transpose"
95+
// (which is equivalent to transposing the array since it will
96+
// be reinterpreted as Fortran layout) and applying the
97+
// elementwise conjugate to `x` and `b`.
98+
let (t, conj) = match l {
7999
MatrixLayout::C { .. } => match t {
80-
Transpose::No => Transpose::Transpose,
81-
Transpose::Transpose | Transpose::Hermite => Transpose::No,
100+
Transpose::No => (Transpose::Transpose, false),
101+
Transpose::Transpose => (Transpose::No, false),
102+
Transpose::Hermite => (Transpose::No, true),
82103
},
83-
_ => t,
104+
MatrixLayout::F { .. } => (t, false),
84105
};
85106
let (n, _) = l.size();
86107
let nrhs = 1;
87108
let ldb = l.lda();
88109
let mut info = 0;
110+
if conj {
111+
for b_elem in &mut *b {
112+
*b_elem = b_elem.conj();
113+
}
114+
}
89115
unsafe { $getrs(t as u8, n, nrhs, a, l.lda(), ipiv, b, ldb, &mut info) };
116+
if conj {
117+
for b_elem in &mut *b {
118+
*b_elem = b_elem.conj();
119+
}
120+
}
90121
info.as_lapack_result()?;
91122
Ok(())
92123
}

0 commit comments

Comments
 (0)