55use ndarray:: linalg:: general_mat_mul;
66use ndarray:: linalg:: kron;
77use ndarray:: prelude:: * ;
8+ #[ cfg( feature = "approx" ) ]
9+ use ndarray:: Order ;
810use ndarray:: { rcarr1, rcarr2} ;
911use ndarray:: { Data , LinalgScalar } ;
1012use ndarray:: { Ix , Ixs } ;
11- use num_traits :: Zero ;
13+ use ndarray_gen :: array_builder :: ArrayBuilder ;
1214
1315use approx:: assert_abs_diff_eq;
1416use defmac:: defmac;
17+ use num_traits:: Num ;
18+ use num_traits:: Zero ;
1519
1620fn test_oper ( op : & str , a : & [ f32 ] , b : & [ f32 ] , c : & [ f32 ] )
1721{
@@ -271,31 +275,20 @@ fn product()
271275 }
272276}
273277
274- fn range_mat ( m : Ix , n : Ix ) -> Array2 < f32 >
278+ fn range_mat < A : Num + Copy > ( m : Ix , n : Ix ) -> Array2 < A >
275279{
276- Array :: linspace ( 0. , ( m * n) as f32 - 1. , m * n)
277- . into_shape_with_order ( ( m, n) )
278- . unwrap ( )
279- }
280-
281- fn range_mat64 ( m : Ix , n : Ix ) -> Array2 < f64 >
282- {
283- Array :: linspace ( 0. , ( m * n) as f64 - 1. , m * n)
284- . into_shape_with_order ( ( m, n) )
285- . unwrap ( )
280+ ArrayBuilder :: new ( ( m, n) ) . build ( )
286281}
287282
288283#[ cfg( feature = "approx" ) ]
289284fn range1_mat64 ( m : Ix ) -> Array1 < f64 >
290285{
291- Array :: linspace ( 0. , m as f64 - 1. , m )
286+ ArrayBuilder :: new ( m ) . build ( )
292287}
293288
294289fn range_i32 ( m : Ix , n : Ix ) -> Array2 < i32 >
295290{
296- Array :: from_iter ( 0 ..( m * n) as i32 )
297- . into_shape_with_order ( ( m, n) )
298- . unwrap ( )
291+ ArrayBuilder :: new ( ( m, n) ) . build ( )
299292}
300293
301294// simple, slow, correct (hopefully) mat mul
@@ -332,8 +325,8 @@ where
332325fn mat_mul ( )
333326{
334327 let ( m, n, k) = ( 8 , 8 , 8 ) ;
335- let a = range_mat ( m, n) ;
336- let b = range_mat ( n, k) ;
328+ let a = range_mat :: < f32 > ( m, n) ;
329+ let b = range_mat :: < f32 > ( n, k) ;
337330 let mut b = b / 4. ;
338331 {
339332 let mut c = b. column_mut ( 0 ) ;
@@ -351,8 +344,8 @@ fn mat_mul()
351344 assert_eq ! ( ab, af. dot( & bf) ) ;
352345
353346 let ( m, n, k) = ( 10 , 5 , 11 ) ;
354- let a = range_mat ( m, n) ;
355- let b = range_mat ( n, k) ;
347+ let a = range_mat :: < f32 > ( m, n) ;
348+ let b = range_mat :: < f32 > ( n, k) ;
356349 let mut b = b / 4. ;
357350 {
358351 let mut c = b. column_mut ( 0 ) ;
@@ -370,8 +363,8 @@ fn mat_mul()
370363 assert_eq ! ( ab, af. dot( & bf) ) ;
371364
372365 let ( m, n, k) = ( 10 , 8 , 1 ) ;
373- let a = range_mat ( m, n) ;
374- let b = range_mat ( n, k) ;
366+ let a = range_mat :: < f32 > ( m, n) ;
367+ let b = range_mat :: < f32 > ( n, k) ;
375368 let mut b = b / 4. ;
376369 {
377370 let mut c = b. column_mut ( 0 ) ;
@@ -395,8 +388,8 @@ fn mat_mul()
395388fn mat_mul_order ( )
396389{
397390 let ( m, n, k) = ( 8 , 8 , 8 ) ;
398- let a = range_mat ( m, n) ;
399- let b = range_mat ( n, k) ;
391+ let a = range_mat :: < f32 > ( m, n) ;
392+ let b = range_mat :: < f32 > ( n, k) ;
400393 let mut af = Array :: zeros ( a. dim ( ) . f ( ) ) ;
401394 let mut bf = Array :: zeros ( b. dim ( ) . f ( ) ) ;
402395 af. assign ( & a) ;
@@ -415,8 +408,8 @@ fn mat_mul_order()
415408fn mat_mul_shape_mismatch ( )
416409{
417410 let ( m, k, k2, n) = ( 8 , 8 , 9 , 8 ) ;
418- let a = range_mat ( m, k) ;
419- let b = range_mat ( k2, n) ;
411+ let a = range_mat :: < f32 > ( m, k) ;
412+ let b = range_mat :: < f32 > ( k2, n) ;
420413 a. dot ( & b) ;
421414}
422415
@@ -426,9 +419,9 @@ fn mat_mul_shape_mismatch()
426419fn mat_mul_shape_mismatch_2 ( )
427420{
428421 let ( m, k, k2, n) = ( 8 , 8 , 8 , 8 ) ;
429- let a = range_mat ( m, k) ;
430- let b = range_mat ( k2, n) ;
431- let mut c = range_mat ( m, n + 1 ) ;
422+ let a = range_mat :: < f32 > ( m, k) ;
423+ let b = range_mat :: < f32 > ( k2, n) ;
424+ let mut c = range_mat :: < f32 > ( m, n + 1 ) ;
432425 general_mat_mul ( 1. , & a, & b, 1. , & mut c) ;
433426}
434427
@@ -438,7 +431,7 @@ fn mat_mul_shape_mismatch_2()
438431fn mat_mul_broadcast ( )
439432{
440433 let ( m, n, k) = ( 16 , 16 , 16 ) ;
441- let a = range_mat ( m, n) ;
434+ let a = range_mat :: < f32 > ( m, n) ;
442435 let x1 = 1. ;
443436 let x = Array :: from ( vec ! [ x1] ) ;
444437 let b0 = x. broadcast ( ( n, k) ) . unwrap ( ) ;
@@ -458,8 +451,8 @@ fn mat_mul_broadcast()
458451fn mat_mul_rev ( )
459452{
460453 let ( m, n, k) = ( 16 , 16 , 16 ) ;
461- let a = range_mat ( m, n) ;
462- let b = range_mat ( n, k) ;
454+ let a = range_mat :: < f32 > ( m, n) ;
455+ let b = range_mat :: < f32 > ( n, k) ;
463456 let mut rev = Array :: zeros ( b. dim ( ) ) ;
464457 let mut rev = rev. slice_mut ( s ! [ ..; -1 , ..] ) ;
465458 rev. assign ( & b) ;
@@ -488,8 +481,8 @@ fn mat_mut_zero_len()
488481 }
489482 }
490483 } ) ;
491- mat_mul_zero_len ! ( range_mat) ;
492- mat_mul_zero_len ! ( range_mat64 ) ;
484+ mat_mul_zero_len ! ( range_mat:: < f32 > ) ;
485+ mat_mul_zero_len ! ( range_mat :: < f64 > ) ;
493486 mat_mul_zero_len ! ( range_i32) ;
494487}
495488
@@ -528,9 +521,9 @@ fn scaled_add_2()
528521 for & s1 in & [ 1 , 2 , -1 , -2 ] {
529522 for & s2 in & [ 1 , 2 , -1 , -2 ] {
530523 for & ( m, k, n, q) in & sizes {
531- let mut a = range_mat64 ( m, k) ;
524+ let mut a = range_mat :: < f64 > ( m, k) ;
532525 let mut answer = a. clone ( ) ;
533- let c = range_mat64 ( n, q) ;
526+ let c = range_mat :: < f64 > ( n, q) ;
534527
535528 {
536529 let mut av = a. slice_mut ( s ! [ ..; s1, ..; s2] ) ;
@@ -570,7 +563,7 @@ fn scaled_add_3()
570563 for & s1 in & [ 1 , 2 , -1 , -2 ] {
571564 for & s2 in & [ 1 , 2 , -1 , -2 ] {
572565 for & ( m, k, n, q) in & sizes {
573- let mut a = range_mat64 ( m, k) ;
566+ let mut a = range_mat :: < f64 > ( m, k) ;
574567 let mut answer = a. clone ( ) ;
575568 let cdim = if n == 1 { vec ! [ q] } else { vec ! [ n, q] } ;
576569 let cslice: Vec < SliceInfoElem > = if n == 1 {
@@ -582,7 +575,7 @@ fn scaled_add_3()
582575 ]
583576 } ;
584577
585- let c = range_mat64 ( n, q) . into_shape_with_order ( cdim) . unwrap ( ) ;
578+ let c = range_mat :: < f64 > ( n, q) . into_shape_with_order ( cdim) . unwrap ( ) ;
586579
587580 {
588581 let mut av = a. slice_mut ( s ! [ ..; s1, ..; s2] ) ;
@@ -619,9 +612,9 @@ fn gen_mat_mul()
619612 for & s1 in & [ 1 , 2 , -1 , -2 ] {
620613 for & s2 in & [ 1 , 2 , -1 , -2 ] {
621614 for & ( m, k, n) in & sizes {
622- let a = range_mat64 ( m, k) ;
623- let b = range_mat64 ( k, n) ;
624- let mut c = range_mat64 ( m, n) ;
615+ let a = range_mat :: < f64 > ( m, k) ;
616+ let b = range_mat :: < f64 > ( k, n) ;
617+ let mut c = range_mat :: < f64 > ( m, n) ;
625618 let mut answer = c. clone ( ) ;
626619
627620 {
@@ -645,11 +638,11 @@ fn gen_mat_mul()
645638#[ test]
646639fn gemm_64_1_f ( )
647640{
648- let a = range_mat64 ( 64 , 64 ) . reversed_axes ( ) ;
641+ let a = range_mat :: < f64 > ( 64 , 64 ) . reversed_axes ( ) ;
649642 let ( m, n) = a. dim ( ) ;
650643 // m x n times n x 1 == m x 1
651- let x = range_mat64 ( n, 1 ) ;
652- let mut y = range_mat64 ( m, 1 ) ;
644+ let x = range_mat :: < f64 > ( n, 1 ) ;
645+ let mut y = range_mat :: < f64 > ( m, 1 ) ;
653646 let answer = reference_mat_mul ( & a, & x) + & y;
654647 general_mat_mul ( 1.0 , & a, & x, 1.0 , & mut y) ;
655648 approx:: assert_relative_eq!( y, answer, epsilon = 1e-12 , max_relative = 1e-7 ) ;
@@ -728,11 +721,8 @@ fn gen_mat_vec_mul()
728721 for & s1 in & [ 1 , 2 , -1 , -2 ] {
729722 for & s2 in & [ 1 , 2 , -1 , -2 ] {
730723 for & ( m, k) in & sizes {
731- for & rev in & [ false , true ] {
732- let mut a = range_mat64 ( m, k) ;
733- if rev {
734- a = a. reversed_axes ( ) ;
735- }
724+ for order in [ Order :: C , Order :: F ] {
725+ let a = ArrayBuilder :: new ( ( m, k) ) . memory_order ( order) . build ( ) ;
736726 let ( m, k) = a. dim ( ) ;
737727 let b = range1_mat64 ( k) ;
738728 let mut c = range1_mat64 ( m) ;
@@ -794,11 +784,8 @@ fn vec_mat_mul()
794784 for & s1 in & [ 1 , 2 , -1 , -2 ] {
795785 for & s2 in & [ 1 , 2 , -1 , -2 ] {
796786 for & ( m, n) in & sizes {
797- for & rev in & [ false , true ] {
798- let mut b = range_mat64 ( m, n) ;
799- if rev {
800- b = b. reversed_axes ( ) ;
801- }
787+ for order in [ Order :: C , Order :: F ] {
788+ let b = ArrayBuilder :: new ( ( m, n) ) . memory_order ( order) . build ( ) ;
802789 let ( m, n) = b. dim ( ) ;
803790 let a = range1_mat64 ( m) ;
804791 let mut c = range1_mat64 ( n) ;
0 commit comments