Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions ndarray-linalg/src/eigh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,17 @@ where
{
type EigVal = Array1<A::Real>;

/// Solves the generalized eigenvalue problem.
///
/// # Panics
///
/// Panics if the shapes of the matrices are different.
fn eigh_inplace(&mut self, uplo: UPLO) -> Result<(Self::EigVal, &mut Self)> {
assert_eq!(
self.0.shape(),
self.1.shape(),
"The shapes of the matrices must be identical.",
);
let layout = self.0.square_layout()?;
// XXX Force layout to be Fortran (see #146)
match layout {
Expand Down
64 changes: 64 additions & 0 deletions ndarray-linalg/src/solve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,59 +77,103 @@ pub use lax::{Pivot, Transpose};
pub trait Solve<A: Scalar> {
/// Solves a system of linear equations `A * x = b` where `A` is `self`, `b`
/// is the argument, and `x` is the successful result.
///
/// # Panics
///
/// Panics if the length of `b` is not the equal to the number of columns
/// of `A`.
fn solve<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
let mut b = replicate(b);
self.solve_inplace(&mut b)?;
Ok(b)
}

/// Solves a system of linear equations `A * x = b` where `A` is `self`, `b`
/// is the argument, and `x` is the successful result.
///
/// # Panics
///
/// Panics if the length of `b` is not the equal to the number of columns
/// of `A`.
fn solve_into<S: DataMut<Elem = A>>(
&self,
mut b: ArrayBase<S, Ix1>,
) -> Result<ArrayBase<S, Ix1>> {
self.solve_inplace(&mut b)?;
Ok(b)
}

/// Solves a system of linear equations `A * x = b` where `A` is `self`, `b`
/// is the argument, and `x` is the successful result.
///
/// # Panics
///
/// Panics if the length of `b` is not the equal to the number of columns
/// of `A`.
fn solve_inplace<'a, S: DataMut<Elem = A>>(
&self,
b: &'a mut ArrayBase<S, Ix1>,
) -> Result<&'a mut ArrayBase<S, Ix1>>;

/// Solves a system of linear equations `A^T * x = b` where `A` is `self`, `b`
/// is the argument, and `x` is the successful result.
///
/// # Panics
///
/// Panics if the length of `b` is not the equal to the number of rows of
/// `A`.
fn solve_t<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
let mut b = replicate(b);
self.solve_t_inplace(&mut b)?;
Ok(b)
}

/// Solves a system of linear equations `A^T * x = b` where `A` is `self`, `b`
/// is the argument, and `x` is the successful result.
///
/// # Panics
///
/// Panics if the length of `b` is not the equal to the number of rows of
/// `A`.
fn solve_t_into<S: DataMut<Elem = A>>(
&self,
mut b: ArrayBase<S, Ix1>,
) -> Result<ArrayBase<S, Ix1>> {
self.solve_t_inplace(&mut b)?;
Ok(b)
}

/// Solves a system of linear equations `A^T * x = b` where `A` is `self`, `b`
/// is the argument, and `x` is the successful result.
///
/// # Panics
///
/// Panics if the length of `b` is not the equal to the number of rows of
/// `A`.
fn solve_t_inplace<'a, S: DataMut<Elem = A>>(
&self,
b: &'a mut ArrayBase<S, Ix1>,
) -> Result<&'a mut ArrayBase<S, Ix1>>;

/// Solves a system of linear equations `A^H * x = b` where `A` is `self`, `b`
/// is the argument, and `x` is the successful result.
///
/// # Panics
///
/// Panics if the length of `b` is not the equal to the number of rows of
/// `A`.
fn solve_h<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
let mut b = replicate(b);
self.solve_h_inplace(&mut b)?;
Ok(b)
}
/// Solves a system of linear equations `A^H * x = b` where `A` is `self`, `b`
/// is the argument, and `x` is the successful result.
///
/// # Panics
///
/// Panics if the length of `b` is not the equal to the number of rows of
/// `A`.
fn solve_h_into<S: DataMut<Elem = A>>(
&self,
mut b: ArrayBase<S, Ix1>,
Expand All @@ -139,6 +183,11 @@ pub trait Solve<A: Scalar> {
}
/// Solves a system of linear equations `A^H * x = b` where `A` is `self`, `b`
/// is the argument, and `x` is the successful result.
///
/// # Panics
///
/// Panics if the length of `b` is not the equal to the number of rows of
/// `A`.
fn solve_h_inplace<'a, S: DataMut<Elem = A>>(
&self,
b: &'a mut ArrayBase<S, Ix1>,
Expand Down Expand Up @@ -167,6 +216,11 @@ where
where
Sb: DataMut<Elem = A>,
{
assert_eq!(
rhs.len(),
self.a.len_of(Axis(1)),
"The length of `rhs` must be compatible with the shape of the factored matrix.",
);
A::solve(
self.a.square_layout()?,
Transpose::No,
Expand All @@ -183,6 +237,11 @@ where
where
Sb: DataMut<Elem = A>,
{
assert_eq!(
rhs.len(),
self.a.len_of(Axis(0)),
"The length of `rhs` must be compatible with the shape of the factored matrix.",
);
A::solve(
self.a.square_layout()?,
Transpose::Transpose,
Expand All @@ -199,6 +258,11 @@ where
where
Sb: DataMut<Elem = A>,
{
assert_eq!(
rhs.len(),
self.a.len_of(Axis(0)),
"The length of `rhs` must be compatible with the shape of the factored matrix.",
);
A::solve(
self.a.square_layout()?,
Transpose::Hermite,
Expand Down
22 changes: 22 additions & 0 deletions ndarray-linalg/src/solveh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,25 +69,42 @@ pub trait SolveH<A: Scalar> {
/// Solves a system of linear equations `A * x = b` with Hermitian (or real
/// symmetric) matrix `A`, where `A` is `self`, `b` is the argument, and
/// `x` is the successful result.
///
/// # Panics
///
/// Panics if the length of `b` is not the equal to the number of columns
/// of `A`.
fn solveh<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
let mut b = replicate(b);
self.solveh_inplace(&mut b)?;
Ok(b)
}

/// Solves a system of linear equations `A * x = b` with Hermitian (or real
/// symmetric) matrix `A`, where `A` is `self`, `b` is the argument, and
/// `x` is the successful result.
///
/// # Panics
///
/// Panics if the length of `b` is not the equal to the number of columns
/// of `A`.
fn solveh_into<S: DataMut<Elem = A>>(
&self,
mut b: ArrayBase<S, Ix1>,
) -> Result<ArrayBase<S, Ix1>> {
self.solveh_inplace(&mut b)?;
Ok(b)
}

/// Solves a system of linear equations `A * x = b` with Hermitian (or real
/// symmetric) matrix `A`, where `A` is `self`, `b` is the argument, and
/// `x` is the successful result. The value of `x` is also assigned to the
/// argument.
///
/// # Panics
///
/// Panics if the length of `b` is not the equal to the number of columns
/// of `A`.
fn solveh_inplace<'a, S: DataMut<Elem = A>>(
&self,
b: &'a mut ArrayBase<S, Ix1>,
Expand All @@ -113,6 +130,11 @@ where
where
Sb: DataMut<Elem = A>,
{
assert_eq!(
rhs.len(),
self.a.len_of(Axis(1)),
"The length of `rhs` must be compatible with the shape of the factored matrix.",
);
A::solveh(
self.a.square_layout()?,
UPLO::Upper,
Expand Down
8 changes: 8 additions & 0 deletions ndarray-linalg/tests/eigh.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
use ndarray::*;
use ndarray_linalg::*;

#[should_panic]
#[test]
fn eigh_generalized_shape_mismatch() {
let a = Array2::<f64>::eye(3);
let b = Array2::<f64>::eye(2);
let _ = (a, b).eigh_inplace(UPLO::Upper);
}

#[test]
fn fixed() {
let a = arr2(&[[3.0, 1.0, 1.0], [1.0, 3.0, 1.0], [1.0, 1.0, 3.0]]);
Expand Down
34 changes: 34 additions & 0 deletions ndarray-linalg/tests/solve.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
use ndarray::*;
use ndarray_linalg::*;

#[should_panic]
#[test]
fn solve_shape_mismatch() {
let a: Array2<f64> = random((3, 3));
let b: Array1<f64> = random(2);
let _ = a.solve_into(b);
}

#[test]
fn solve_random() {
let a: Array2<f64> = random((3, 3));
Expand All @@ -10,6 +18,14 @@ fn solve_random() {
assert_close_l2!(&x, &y, 1e-7);
}

#[should_panic]
#[test]
fn solve_t_shape_mismatch() {
let a: Array2<f64> = random((3, 3).f());
let b: Array1<f64> = random(4);
let _ = a.solve_into(b);
}

#[test]
fn solve_random_t() {
let a: Array2<f64> = random((3, 3).f());
Expand All @@ -19,6 +35,15 @@ fn solve_random_t() {
assert_close_l2!(&x, &y, 1e-7);
}

#[should_panic]
#[test]
fn solve_factorized_shape_mismatch() {
let a: Array2<f64> = random((3, 3));
let b: Array1<f64> = random(4);
let f = a.factorize_into().unwrap();
let _ = f.solve_into(b);
}

#[test]
fn solve_factorized() {
let a: Array2<f64> = random((3, 3));
Expand All @@ -29,6 +54,15 @@ fn solve_factorized() {
assert_close_l2!(&x, &ans, 1e-7);
}

#[should_panic]
#[test]
fn solve_factorized_t_shape_mismatch() {
let a: Array2<f64> = random((3, 3).f());
let b: Array1<f64> = random(4);
let f = a.factorize_into().unwrap();
let _ = f.solve_into(b);
}

#[test]
fn solve_factorized_t() {
let a: Array2<f64> = random((3, 3).f());
Expand Down
34 changes: 34 additions & 0 deletions ndarray-linalg/tests/solveh.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,23 @@
use ndarray::*;
use ndarray_linalg::*;

#[should_panic]
#[test]
fn solveh_shape_mismatch() {
let a: Array2<f64> = random_hpd(3);
let b: Array1<f64> = random(2);
let _ = a.solveh_into(b);
}

#[should_panic]
#[test]
fn factorizeh_solveh_shape_mismatch() {
let a: Array2<f64> = random_hpd(3);
let b: Array1<f64> = random(2);
let f = a.factorizeh_into().unwrap();
let _ = f.solveh_into(b);
}

#[test]
fn solveh_random() {
let a: Array2<f64> = random_hpd(3);
Expand All @@ -15,6 +32,23 @@ fn solveh_random() {
assert_close_l2!(&x, &y, 1e-7);
}

#[should_panic]
#[test]
fn solveh_t_shape_mismatch() {
let a: Array2<f64> = random_hpd(3).reversed_axes();
let b: Array1<f64> = random(2);
let _ = a.solveh_into(b);
}

#[should_panic]
#[test]
fn factorizeh_solveh_t_shape_mismatch() {
let a: Array2<f64> = random_hpd(3).reversed_axes();
let b: Array1<f64> = random(2);
let f = a.factorizeh_into().unwrap();
let _ = f.solveh_into(b);
}

#[test]
fn solveh_random_t() {
let a: Array2<f64> = random_hpd(3).reversed_axes();
Expand Down