@@ -817,17 +817,16 @@ mod blas_tests {
817817}
818818
819819#[ allow( dead_code) ]
820- fn general_outer_to_dyn < Sa , Sb , I , F , T > (
820+ fn general_outer_to_dyn < Sa , Sb , F , T > (
821821 a : & ArrayBase < Sa , IxDyn > ,
822- b : & ArrayBase < Sb , I > ,
822+ b : & ArrayBase < Sb , IxDyn > ,
823823 f : F ,
824824) -> ArrayD < T >
825825where
826826 T : Copy ,
827827 Sa : Data < Elem = T > ,
828828 Sb : Data < Elem = T > ,
829- I : Dimension ,
830- F : Fn ( ArrayViewMut < T , IxDyn > , T , & ArrayBase < Sb , I > ) -> ( ) ,
829+ F : Fn ( T , T ) -> T ,
831830{
832831 //Iterators on the shapes, compelted by 1s
833832 let a_shape_iter = a. shape ( ) . iter ( ) . chain ( [ 1 ] . iter ( ) . cycle ( ) ) ;
@@ -843,25 +842,24 @@ where
843842 unsafe {
844843 let mut res: ArrayD < T > = ArrayBase :: uninitialized ( res_dim) ;
845844 let res_chunks = res. exact_chunks_mut ( b. shape ( ) ) ;
846- Zip :: from ( res_chunks) . and ( a) . apply ( |res_chunk, & a_elem| f ( res_chunk, a_elem, b) ) ;
845+ Zip :: from ( res_chunks) . and ( a) . apply ( |res_chunk, & a_elem| {
846+ Zip :: from ( res_chunk)
847+ . and ( b)
848+ . apply ( |res_elem, & b_elem| * res_elem = f ( a_elem, b_elem) )
849+ } ) ;
847850 res
848851 }
849852}
850853
851854#[ allow( dead_code, clippy:: type_repetition_in_bounds) ]
852- fn kron_to_dyn < Sa , I , Sb , T > ( a : & ArrayBase < Sa , IxDyn > , b : & ArrayBase < Sb , I > ) -> Array < T , IxDyn >
855+ fn kron_to_dyn < Sa , Sb , T > ( a : & ArrayBase < Sa , IxDyn > , b : & ArrayBase < Sb , IxDyn > ) -> Array < T , IxDyn >
853856where
854857 T : Copy ,
855858 Sa : Data < Elem = T > ,
856859 Sb : Data < Elem = T > ,
857- I : Dimension ,
858- T : crate :: ScalarOperand + std:: ops:: MulAssign ,
859- for < ' a > & ' a ArrayBase < Sb , I > : std:: ops:: Mul < T , Output = Array < T , I > > ,
860+ T : crate :: ScalarOperand + std:: ops:: Mul < Output = T > ,
860861{
861- general_outer_to_dyn ( a, b, |mut res, x, a| {
862- res. assign ( a) ;
863- res *= x
864- } )
862+ general_outer_to_dyn ( a, b, std:: ops:: Mul :: mul)
865863}
866864
867865#[ allow( dead_code) ]
@@ -875,7 +873,7 @@ where
875873 Sa : Data < Elem = T > ,
876874 Sb : Data < Elem = T > ,
877875 I : Dimension ,
878- F : Fn ( ArrayViewMut < T , I > , T , & ArrayBase < Sb , I > ) -> ( ) ,
876+ F : Fn ( T , T ) -> T ,
879877{
880878 let mut res_dim = a. raw_dim ( ) ;
881879 let mut res_dim_view = res_dim. as_array_view_mut ( ) ;
@@ -884,7 +882,11 @@ where
884882 unsafe {
885883 let mut res: Array < T , I > = ArrayBase :: uninitialized ( res_dim) ;
886884 let res_chunks = res. exact_chunks_mut ( b. raw_dim ( ) ) ;
887- Zip :: from ( res_chunks) . and ( a) . apply ( |r_c, & x| f ( r_c, x, b) ) ;
885+ Zip :: from ( res_chunks) . and ( a) . apply ( |res_chunk, & a_elem| {
886+ Zip :: from ( res_chunk)
887+ . and ( b)
888+ . apply ( |r_elem, & b_elem| * r_elem = f ( a_elem, b_elem) )
889+ } ) ;
888890 res
889891 }
890892}
@@ -896,13 +898,9 @@ where
896898 Sa : Data < Elem = T > ,
897899 Sb : Data < Elem = T > ,
898900 I : Dimension ,
899- T : crate :: ScalarOperand + std:: ops:: MulAssign ,
900- for < ' a > & ' a ArrayBase < Sb , I > : std:: ops:: Mul < T , Output = Array < T , I > > ,
901+ T : crate :: ScalarOperand + std:: ops:: Mul < Output = T > ,
901902{
902- general_outer_same_size ( a, b, |mut res, x, a| {
903- res. assign ( & a) ;
904- res *= x
905- } )
903+ general_outer_same_size ( a, b, std:: ops:: Mul :: mul)
906904}
907905
908906#[ cfg( test) ]
@@ -922,7 +920,7 @@ mod kron_test {
922920 [ [ 110 , 0 , 7 ] , [ 523 , 21 , -12 ] ]
923921 ] ;
924922 let res1 = kron_same_size ( & a, & b) ;
925- let res2 = kron_to_dyn ( & a. clone ( ) . into_dyn ( ) , & b) ;
923+ let res2 = kron_to_dyn ( & a. clone ( ) . into_dyn ( ) , & b. clone ( ) . into_dyn ( ) ) ;
926924 assert_eq ! ( res1. clone( ) . into_dyn( ) , res2) ;
927925 for a0 in 0 ..a. len_of ( Axis ( 0 ) ) {
928926 for a1 in 0 ..a. len_of ( Axis ( 1 ) ) {
0 commit comments