6060//! // `a` and `b` have been moved, no longer valid
6161//! ```
6262
63- use ndarray:: { s , Array , Array1 , Array2 , ArrayBase , Axis , Data , DataMut , Dimension , Ix0 , Ix1 , Ix2 } ;
63+ use ndarray:: * ;
6464
6565use crate :: error:: * ;
6666use crate :: lapack:: least_squares:: * ;
@@ -352,7 +352,10 @@ where
352352 // we need a new rhs b/c it will be overwritten with the solution
353353 // for which we need `n` entries
354354 let k = rhs. shape ( ) [ 1 ] ;
355- let mut new_rhs = Array2 :: < E > :: zeros ( ( n, k) ) ;
355+ let mut new_rhs = match self . layout ( ) ? {
356+ MatrixLayout :: C { .. } => Array2 :: < E > :: zeros ( ( n, k) ) ,
357+ MatrixLayout :: F { .. } => Array2 :: < E > :: zeros ( ( n, k) . f ( ) ) ,
358+ } ;
356359 new_rhs. slice_mut ( s ! [ 0 ..m, ..] ) . assign ( rhs) ;
357360 compute_least_squares_nrhs ( self , & mut new_rhs)
358361 } else {
@@ -414,117 +417,9 @@ fn compute_residual_array1<E: Scalar, D: Data<Elem = E>>(
414417
415418#[ cfg( test) ]
416419mod tests {
417- use super :: * ;
420+ use crate :: { error :: LinalgError , * } ;
418421 use approx:: AbsDiffEq ;
419- use ndarray:: { ArcArray1 , ArcArray2 , Array1 , Array2 , CowArray } ;
420- use num_complex:: Complex ;
421-
422- //
423- // Test cases taken from the scipy test suite for the scipy lstsq function
424- // https://github.com/scipy/scipy/blob/v1.4.1/scipy/linalg/tests/test_basic.py
425- //
426- #[ test]
427- fn scipy_test_simple_exact ( ) {
428- let a = array ! [ [ 1. , 20. ] , [ -30. , 4. ] ] ;
429- let bs = vec ! [
430- array![ [ 1. , 0. ] , [ 0. , 1. ] ] ,
431- array![ [ 1. ] , [ 0. ] ] ,
432- array![ [ 2. , 1. ] , [ -30. , 4. ] ] ,
433- ] ;
434- for b in & bs {
435- let res = a. least_squares ( b) . unwrap ( ) ;
436- assert_eq ! ( res. rank, 2 ) ;
437- let b_hat = a. dot ( & res. solution ) ;
438- let rssq = ( b - & b_hat) . mapv ( |x| x. powi ( 2 ) ) . sum_axis ( Axis ( 0 ) ) ;
439- assert ! ( res
440- . residual_sum_of_squares
441- . unwrap( )
442- . abs_diff_eq( & rssq, 1e-12 ) ) ;
443- assert ! ( b_hat. abs_diff_eq( & b, 1e-12 ) ) ;
444- }
445- }
446-
447- #[ test]
448- fn scipy_test_simple_overdetermined ( ) {
449- let a: Array2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
450- let b: Array1 < f64 > = array ! [ 1. , 2. , 3. ] ;
451- let res = a. least_squares ( & b) . unwrap ( ) ;
452- assert_eq ! ( res. rank, 2 ) ;
453- let b_hat = a. dot ( & res. solution ) ;
454- let rssq = ( & b - & b_hat) . mapv ( |x| x. powi ( 2 ) ) . sum ( ) ;
455- assert ! ( res. residual_sum_of_squares. unwrap( ) [ ( ) ] . abs_diff_eq( & rssq, 1e-12 ) ) ;
456- assert ! ( res
457- . solution
458- . abs_diff_eq( & array![ -0.428571428571429 , 0.85714285714285 ] , 1e-12 ) ) ;
459- }
460-
461- #[ test]
462- fn scipy_test_simple_overdetermined_f32 ( ) {
463- let a: Array2 < f32 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
464- let b: Array1 < f32 > = array ! [ 1. , 2. , 3. ] ;
465- let res = a. least_squares ( & b) . unwrap ( ) ;
466- assert_eq ! ( res. rank, 2 ) ;
467- let b_hat = a. dot ( & res. solution ) ;
468- let rssq = ( & b - & b_hat) . mapv ( |x| x. powi ( 2 ) ) . sum ( ) ;
469- assert ! ( res. residual_sum_of_squares. unwrap( ) [ ( ) ] . abs_diff_eq( & rssq, 1e-6 ) ) ;
470- assert ! ( res
471- . solution
472- . abs_diff_eq( & array![ -0.428571428571429 , 0.85714285714285 ] , 1e-6 ) ) ;
473- }
474-
475- fn c ( re : f64 , im : f64 ) -> Complex < f64 > {
476- Complex :: new ( re, im)
477- }
478-
479- #[ test]
480- fn scipy_test_simple_overdetermined_complex ( ) {
481- let a: Array2 < c64 > = array ! [
482- [ c( 1. , 2. ) , c( 2. , 0. ) ] ,
483- [ c( 4. , 0. ) , c( 5. , 0. ) ] ,
484- [ c( 3. , 0. ) , c( 4. , 0. ) ]
485- ] ;
486- let b: Array1 < c64 > = array ! [ c( 1. , 0. ) , c( 2. , 4. ) , c( 3. , 0. ) ] ;
487- let res = a. least_squares ( & b) . unwrap ( ) ;
488- assert_eq ! ( res. rank, 2 ) ;
489- let b_hat = a. dot ( & res. solution ) ;
490- let rssq = ( & b_hat - & b) . mapv ( |x| x. powi ( 2 ) . abs ( ) ) . sum ( ) ;
491- assert ! ( res. residual_sum_of_squares. unwrap( ) [ ( ) ] . abs_diff_eq( & rssq, 1e-12 ) ) ;
492- assert ! ( res. solution. abs_diff_eq(
493- & array![
494- c( -0.4831460674157303 , 0.258426966292135 ) ,
495- c( 0.921348314606741 , 0.292134831460674 )
496- ] ,
497- 1e-12
498- ) ) ;
499- }
500-
501- #[ test]
502- fn scipy_test_simple_underdetermined ( ) {
503- let a: Array2 < f64 > = array ! [ [ 1. , 2. , 3. ] , [ 4. , 5. , 6. ] ] ;
504- let b: Array1 < f64 > = array ! [ 1. , 2. ] ;
505- let res = a. least_squares ( & b) . unwrap ( ) ;
506- assert_eq ! ( res. rank, 2 ) ;
507- assert ! ( res. residual_sum_of_squares. is_none( ) ) ;
508- let expected = array ! [ -0.055555555555555 , 0.111111111111111 , 0.277777777777777 ] ;
509- assert ! ( res. solution. abs_diff_eq( & expected, 1e-12 ) ) ;
510- }
511-
512- /// This test case tests the underdetermined case for multiple right hand
513- /// sides. Adapted from scipy lstsq tests.
514- #[ test]
515- fn scipy_test_simple_underdetermined_nrhs ( ) {
516- let a: Array2 < f64 > = array ! [ [ 1. , 2. , 3. ] , [ 4. , 5. , 6. ] ] ;
517- let b: Array2 < f64 > = array ! [ [ 1. , 1. ] , [ 2. , 2. ] ] ;
518- let res = a. least_squares ( & b) . unwrap ( ) ;
519- assert_eq ! ( res. rank, 2 ) ;
520- assert ! ( res. residual_sum_of_squares. is_none( ) ) ;
521- let expected = array ! [
522- [ -0.055555555555555 , -0.055555555555555 ] ,
523- [ 0.111111111111111 , 0.111111111111111 ] ,
524- [ 0.277777777777777 , 0.277777777777777 ]
525- ] ;
526- assert ! ( res. solution. abs_diff_eq( & expected, 1e-12 ) ) ;
527- }
422+ use ndarray:: * ;
528423
529424 //
530425 // Test that the different lest squares traits work as intended on the
@@ -554,23 +449,23 @@ mod tests {
554449 }
555450
556451 #[ test]
557- fn test_least_squares_on_arc ( ) {
452+ fn on_arc ( ) {
558453 let a: ArcArray2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] . into_shared ( ) ;
559454 let b: ArcArray1 < f64 > = array ! [ 1. , 2. , 3. ] . into_shared ( ) ;
560455 let res = a. least_squares ( & b) . unwrap ( ) ;
561456 assert_result ( & a, & b, & res) ;
562457 }
563458
564459 #[ test]
565- fn test_least_squares_on_cow ( ) {
460+ fn on_cow ( ) {
566461 let a = CowArray :: from ( array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ) ;
567462 let b = CowArray :: from ( array ! [ 1. , 2. , 3. ] ) ;
568463 let res = a. least_squares ( & b) . unwrap ( ) ;
569464 assert_result ( & a, & b, & res) ;
570465 }
571466
572467 #[ test]
573- fn test_least_squares_on_view ( ) {
468+ fn on_view ( ) {
574469 let a: Array2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
575470 let b: Array1 < f64 > = array ! [ 1. , 2. , 3. ] ;
576471 let av = a. view ( ) ;
@@ -580,7 +475,7 @@ mod tests {
580475 }
581476
582477 #[ test]
583- fn test_least_squares_on_view_mut ( ) {
478+ fn on_view_mut ( ) {
584479 let mut a: Array2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
585480 let mut b: Array1 < f64 > = array ! [ 1. , 2. , 3. ] ;
586481 let av = a. view_mut ( ) ;
@@ -590,7 +485,7 @@ mod tests {
590485 }
591486
592487 #[ test]
593- fn test_least_squares_into_on_owned ( ) {
488+ fn into_on_owned ( ) {
594489 let a: Array2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
595490 let b: Array1 < f64 > = array ! [ 1. , 2. , 3. ] ;
596491 let ac = a. clone ( ) ;
@@ -600,7 +495,7 @@ mod tests {
600495 }
601496
602497 #[ test]
603- fn test_least_squares_into_on_arc ( ) {
498+ fn into_on_arc ( ) {
604499 let a: ArcArray2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] . into_shared ( ) ;
605500 let b: ArcArray1 < f64 > = array ! [ 1. , 2. , 3. ] . into_shared ( ) ;
606501 let a2 = a. clone ( ) ;
@@ -610,7 +505,7 @@ mod tests {
610505 }
611506
612507 #[ test]
613- fn test_least_squares_into_on_cow ( ) {
508+ fn into_on_cow ( ) {
614509 let a = CowArray :: from ( array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ) ;
615510 let b = CowArray :: from ( array ! [ 1. , 2. , 3. ] ) ;
616511 let a2 = a. clone ( ) ;
@@ -620,7 +515,7 @@ mod tests {
620515 }
621516
622517 #[ test]
623- fn test_least_squares_in_place_on_owned ( ) {
518+ fn in_place_on_owned ( ) {
624519 let a = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
625520 let b = array ! [ 1. , 2. , 3. ] ;
626521 let mut a2 = a. clone ( ) ;
@@ -630,7 +525,7 @@ mod tests {
630525 }
631526
632527 #[ test]
633- fn test_least_squares_in_place_on_cow ( ) {
528+ fn in_place_on_cow ( ) {
634529 let a = CowArray :: from ( array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ) ;
635530 let b = CowArray :: from ( array ! [ 1. , 2. , 3. ] ) ;
636531 let mut a2 = a. clone ( ) ;
@@ -640,7 +535,7 @@ mod tests {
640535 }
641536
642537 #[ test]
643- fn test_least_squares_in_place_on_mut_view ( ) {
538+ fn in_place_on_mut_view ( ) {
644539 let a = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
645540 let b = array ! [ 1. , 2. , 3. ] ;
646541 let mut a2 = a. clone ( ) ;
@@ -651,95 +546,30 @@ mod tests {
651546 assert_result ( & a, & b, & res) ;
652547 }
653548
654- //
655- // Test cases taken from the netlib documentation at
656- // https://www.netlib.org/lapack/lapacke.html#_calling_code_dgels_code
657- //
658- #[ test]
659- fn netlib_lapack_example_for_dgels_1 ( ) {
660- let a: Array2 < f64 > = array ! [
661- [ 1. , 1. , 1. ] ,
662- [ 2. , 3. , 4. ] ,
663- [ 3. , 5. , 2. ] ,
664- [ 4. , 2. , 5. ] ,
665- [ 5. , 4. , 3. ]
666- ] ;
667- let b: Array1 < f64 > = array ! [ -10. , 12. , 14. , 16. , 18. ] ;
668- let expected: Array1 < f64 > = array ! [ 2. , 1. , 1. ] ;
669- let result = a. least_squares ( & b) . unwrap ( ) ;
670- assert ! ( result. solution. abs_diff_eq( & expected, 1e-12 ) ) ;
671-
672- let residual = b - a. dot ( & result. solution ) ;
673- let resid_ssq = result. residual_sum_of_squares . unwrap ( ) ;
674- assert ! ( ( resid_ssq[ ( ) ] - residual. dot( & residual) ) . abs( ) < 1e-12 ) ;
675- }
676-
677- #[ test]
678- fn netlib_lapack_example_for_dgels_2 ( ) {
679- let a: Array2 < f64 > = array ! [
680- [ 1. , 1. , 1. ] ,
681- [ 2. , 3. , 4. ] ,
682- [ 3. , 5. , 2. ] ,
683- [ 4. , 2. , 5. ] ,
684- [ 5. , 4. , 3. ]
685- ] ;
686- let b: Array1 < f64 > = array ! [ -3. , 14. , 12. , 16. , 16. ] ;
687- let expected: Array1 < f64 > = array ! [ 1. , 1. , 2. ] ;
688- let result = a. least_squares ( & b) . unwrap ( ) ;
689- assert ! ( result. solution. abs_diff_eq( & expected, 1e-12 ) ) ;
690-
691- let residual = b - a. dot ( & result. solution ) ;
692- let resid_ssq = result. residual_sum_of_squares . unwrap ( ) ;
693- assert ! ( ( resid_ssq[ ( ) ] - residual. dot( & residual) ) . abs( ) < 1e-12 ) ;
694- }
695-
696- #[ test]
697- fn netlib_lapack_example_for_dgels_nrhs ( ) {
698- let a: Array2 < f64 > = array ! [
699- [ 1. , 1. , 1. ] ,
700- [ 2. , 3. , 4. ] ,
701- [ 3. , 5. , 2. ] ,
702- [ 4. , 2. , 5. ] ,
703- [ 5. , 4. , 3. ]
704- ] ;
705- let b: Array2 < f64 > = array ! [ [ -10. , -3. ] , [ 12. , 14. ] , [ 14. , 12. ] , [ 16. , 16. ] , [ 18. , 16. ] ] ;
706- let expected: Array2 < f64 > = array ! [ [ 2. , 1. ] , [ 1. , 1. ] , [ 1. , 2. ] ] ;
707- let result = a. least_squares ( & b) . unwrap ( ) ;
708- assert ! ( result. solution. abs_diff_eq( & expected, 1e-12 ) ) ;
709-
710- let residual = & b - & a. dot ( & result. solution ) ;
711- let residual_ssq = residual. mapv ( |x| x. powi ( 2 ) ) . sum_axis ( Axis ( 0 ) ) ;
712- assert ! ( result
713- . residual_sum_of_squares
714- . unwrap( )
715- . abs_diff_eq( & residual_ssq, 1e-12 ) ) ;
716- }
717-
718549 //
719550 // Testing error cases
720551 //
721- use crate :: layout:: MatrixLayout ;
722552
723553 #[ test]
724- fn test_incompatible_shape_error_on_mismatching_num_rows ( ) {
554+ fn incompatible_shape_error_on_mismatching_num_rows ( ) {
725555 let a: Array2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
726556 let b: Array1 < f64 > = array ! [ 1. , 2. ] ;
727557 let res = a. least_squares ( & b) ;
728558 match res {
729- Err ( LinalgError :: Lapack ( err) ) if matches ! ( err, lapack :: error:: Error :: InvalidShape ) => { }
559+ Err ( LinalgError :: Lapack ( err) ) if matches ! ( err, lax :: error:: Error :: InvalidShape ) => { }
730560 _ => panic ! ( "Expected Err()" ) ,
731561 }
732562 }
733563
734564 #[ test]
735- fn test_incompatible_shape_error_on_mismatching_layout ( ) {
565+ fn incompatible_shape_error_on_mismatching_layout ( ) {
736566 let a: Array2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
737567 let b = array ! [ [ 1. ] , [ 2. ] ] . t ( ) . to_owned ( ) ;
738568 assert_eq ! ( b. layout( ) . unwrap( ) , MatrixLayout :: F { col: 2 , lda: 1 } ) ;
739569
740570 let res = a. least_squares ( & b) ;
741571 match res {
742- Err ( LinalgError :: Lapack ( err) ) if matches ! ( err, lapack :: error:: Error :: InvalidShape ) => { }
572+ Err ( LinalgError :: Lapack ( err) ) if matches ! ( err, lax :: error:: Error :: InvalidShape ) => { }
743573 _ => panic ! ( "Expected Err()" ) ,
744574 }
745575 }
0 commit comments