@@ -738,4 +738,44 @@ mod tests {
738738 . unwrap( )
739739 . abs_diff_eq( & residual_ssq, 1e-12 ) ) ;
740740 }
741+
742+ ///////////////////////////////////////////////////////////////////////////
743+ /// Testing error cases
744+ ///////////////////////////////////////////////////////////////////////////
745+ use ndarray:: ErrorKind ;
746+ use crate :: layout:: MatrixLayout ;
747+
748+ #[ test]
749+ fn test_incompatible_shape_error_on_mismatching_num_rows ( ) {
750+ let a: Array2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
751+ let b: Array1 < f64 > = array ! [ 1. , 2. ] ;
752+ let res = a. least_squares ( & b) ;
753+ match res {
754+ Err ( err) =>
755+ match err {
756+ LinalgError :: Shape ( shape_error) =>
757+ assert_eq ! ( shape_error. kind( ) , ErrorKind :: IncompatibleShape ) ,
758+ _ => panic ! ( "Expected ShapeError" )
759+ } ,
760+ _ => panic ! ( "Expected Err()" )
761+ }
762+ }
763+
764+ #[ test]
765+ fn test_incompatible_shape_error_on_mismatching_layout ( ) {
766+ let a: Array2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
767+ let b = array ! [ [ 1. ] , [ 2. ] ] . t ( ) . to_owned ( ) ;
768+ assert_eq ! ( b. layout( ) . unwrap( ) , MatrixLayout :: F ( ( 2 , 1 ) ) ) ;
769+
770+ let res = a. least_squares ( & b) ;
771+ match res {
772+ Err ( err) =>
773+ match err {
774+ LinalgError :: Shape ( shape_error) =>
775+ assert_eq ! ( shape_error. kind( ) , ErrorKind :: IncompatibleShape ) ,
776+ _ => panic ! ( "Expected ShapeError" )
777+ } ,
778+ _ => panic ! ( "Expected Err()" )
779+ }
780+ }
741781}
0 commit comments