Skip to content

Commit 7ff280e

Browse files
committed
broken wip
1 parent 5abb5df commit 7ff280e

File tree

4 files changed

+214
-5
lines changed

4 files changed

+214
-5
lines changed

linalg/arm64/arm64simd/arm64simd_act_f32_32n.tmpl

Lines changed: 134 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,141 @@
5151
b .ok
5252

5353
.move:
54-
b .unsupported
54+
lsr w7, w6, 16
55+
and w7, w7, 0xff // w7 is dst reg
56+
lsr w6, w6, 24
57+
and w6, w6, 0xff // w6 is src
58+
add w7, w7, w6, LSL#2 // 4bits DDSS
59+
adr x4, .move_jmp_table
60+
add x4, x4, x7, LSL#2
61+
br x4
62+
63+
.move_jmp_table:
64+
b .inner_loop // a to a
65+
b .move_a_b
66+
b .move_a_c
67+
b .unsupported // a <- d
68+
b .move_b_a
69+
b .inner_loop // b <- b
70+
b .move_b_c
71+
b .unsupported // b <- d
72+
b .move_c_a
73+
b .move_c_b
74+
b .inner_loop // c <- c
75+
b .unsupported // c <- d
76+
b .unsupported // a <- d
77+
b .unsupported // b <- d
78+
b .unsupported // c <- d
79+
b .unsupported // d <- d
80+
81+
.move_a_b:
82+
and v0.16b, v8.16b, v8.16b
83+
and v1.16b, v9.16b, v9.16b
84+
and v2.16b, v10.16b, v10.16b
85+
and v3.16b, v11.16b, v11.16b
86+
and v4.16b, v12.16b, v12.16b
87+
and v5.16b, v13.16b, v13.16b
88+
and v6.16b, v14.16b, v14.16b
89+
and v7.16b, v15.16b, v15.16b
90+
b .inner_loop
91+
92+
.move_a_c:
93+
and v0.16b, v16.16b, v16.16b
94+
and v1.16b, v17.16b, v17.16b
95+
and v2.16b, v18.16b, v18.16b
96+
and v3.16b, v19.16b, v19.16b
97+
and v4.16b, v20.16b, v20.16b
98+
and v5.16b, v21.16b, v21.16b
99+
and v6.16b, v22.16b, v22.16b
100+
and v7.16b, v23.16b, v23.16b
101+
b .inner_loop
102+
103+
.move_b_a:
104+
and v8.16b , v0.16b, v0.16b
105+
and v9.16b , v1.16b, v1.16b
106+
and v10.16b, v2.16b, v2.16b
107+
and v11.16b, v3.16b, v3.16b
108+
and v12.16b, v4.16b, v4.16b
109+
and v13.16b, v5.16b, v5.16b
110+
and v14.16b, v6.16b, v6.16b
111+
and v15.16b, v7.16b, v7.16b
112+
b .inner_loop
113+
114+
.move_b_c:
115+
and v8.16b , v16.16b, v16.16b
116+
and v9.16b , v17.16b, v17.16b
117+
and v10.16b, v18.16b, v18.16b
118+
and v11.16b, v19.16b, v19.16b
119+
and v12.16b, v20.16b, v20.16b
120+
and v13.16b, v21.16b, v21.16b
121+
and v14.16b, v22.16b, v22.16b
122+
and v15.16b, v23.16b, v23.16b
123+
b .inner_loop
124+
125+
.move_c_a:
126+
and v16.16b, v0.16b, v0.16b
127+
and v17.16b, v1.16b, v1.16b
128+
and v18.16b, v2.16b, v2.16b
129+
and v19.16b, v3.16b, v3.16b
130+
and v20.16b, v4.16b, v4.16b
131+
and v21.16b, v5.16b, v5.16b
132+
and v22.16b, v6.16b, v6.16b
133+
and v23.16b, v7.16b, v7.16b
134+
b .inner_loop
135+
136+
.move_c_b:
137+
and v16.16b, v8.16b , v8.16b
138+
and v17.16b, v9.16b , v9.16b
139+
and v18.16b, v10.16b, v10.16b
140+
and v19.16b, v11.16b, v11.16b
141+
and v20.16b, v12.16b, v12.16b
142+
and v21.16b, v13.16b, v13.16b
143+
and v22.16b, v14.16b, v14.16b
144+
and v23.16b, v15.16b, v15.16b
145+
b .inner_loop
146+
55147
.load:
56-
b .unsupported
148+
add x5, x5, 4
149+
ins v24.s[0], w3
150+
lsr w7, w6, 16
151+
and w7, w7, 0xff
152+
adr x4, .load_jmp_table
153+
add x4, x4, x7, LSL#2
154+
br x4
155+
.load_jmp_table:
156+
b .load_a
157+
b .load_b
158+
b .load_c
159+
.load_a:
160+
dup v0.4s, v24.s[0]
161+
dup v1.4s, v24.s[0]
162+
dup v2.4s, v24.s[0]
163+
dup v3.4s, v24.s[0]
164+
dup v4.4s, v24.s[0]
165+
dup v5.4s, v24.s[0]
166+
dup v6.4s, v24.s[0]
167+
dup v7.4s, v24.s[0]
168+
b .inner_loop
169+
.load_b:
170+
dup v8.4s, v24.s[0]
171+
dup v9.4s, v24.s[0]
172+
dup v10.4s, v24.s[0]
173+
dup v11.4s, v24.s[0]
174+
dup v12.4s, v24.s[0]
175+
dup v13.4s, v24.s[0]
176+
dup v14.4s, v24.s[0]
177+
dup v15.4s, v24.s[0]
178+
b .inner_loop
179+
.load_c:
180+
dup v16.4s, v24.s[0]
181+
dup v17.4s, v24.s[0]
182+
dup v18.4s, v24.s[0]
183+
dup v19.4s, v24.s[0]
184+
dup v20.4s, v24.s[0]
185+
dup v21.4s, v24.s[0]
186+
dup v22.4s, v24.s[0]
187+
dup v23.4s, v24.s[0]
188+
b .inner_loop
57189
.abs:
58190
b .unsupported
59191
.recip:

linalg/src/frame/activations.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use super::element_wise_helper::run_over_slice_with_alignment;
1010
pub mod definitions;
1111
pub mod reference;
1212
#[macro_use]
13+
#[cfg(test)]
1314
pub mod tests;
1415

1516
#[derive(Clone, Debug, PartialEq)]

linalg/src/frame/activations/definitions.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,18 @@ pub fn threshold_relu<T: LADatum>(alpha: T) -> Program<T> {
4141
}
4242
}
4343

44+
pub fn hard_sigmoid<T: LADatum>(alpha: T, beta: T) -> Program<T> {
45+
Program {
46+
#[rustfmt::skip]
47+
ops: vec![
48+
MulConst(alpha),
49+
AddConst(beta),
50+
MinConst(T::one()),
51+
MaxConst(T::zero()),
52+
],
53+
}
54+
}
55+
4456
pub fn softsign<T: LADatum>() -> Program<T> {
4557
Program {
4658
#[rustfmt::skip]
@@ -54,7 +66,7 @@ pub fn softsign<T: LADatum>() -> Program<T> {
5466
}
5567
}
5668

57-
pub fn hardswish<T: LADatum>() -> Program<T> {
69+
pub fn hard_swish<T: LADatum>() -> Program<T> {
5870
let one_sixth = T::one() / (T::one() + T::one() + T::one() + T::one() + T::one() + T::one());
5971
let one_half = T::one() / (T::one() + T::one());
6072
Program {

linalg/src/frame/activations/tests.rs

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
use crate::LADatum;
22

3-
use super::{ActivationKer, Op, Program};
3+
use super::{ActivationKer, Op, Program, RegisterId};
44
use Op::*;
5+
use proptest::prelude::*;
56

67
pub fn noop<T: LADatum>() -> Program<T> {
78
Program { ops: vec![] }
@@ -28,6 +29,14 @@ pub fn run_kernel_test<TI: LADatum, K: ActivationKer<TI>>(
2829
expected.close_enough(&tensor, true).unwrap();
2930
}
3031

32+
impl Arbitrary for RegisterId {
33+
type Parameters = ();
34+
type Strategy = BoxedStrategy<RegisterId>;
35+
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
36+
proptest::prop_oneof![Just(RegisterId::A), Just(RegisterId::B), Just(RegisterId::C)].boxed()
37+
}
38+
}
39+
3140
#[macro_export]
3241
macro_rules! act_tests {
3342
($cond:expr, $ker:ty, $ti:ty) => {
@@ -37,7 +46,8 @@ macro_rules! act_tests {
3746
use $crate::frame::activations::ActivationKer;
3847
use $crate::frame::activations::tests::*;
3948
use $crate::frame::activations::Op::*;
40-
use num_traits::Zero;
49+
use $crate::frame::activations::RegisterId;
50+
use num_traits::{Zero, One};
4151
use proptest::prelude::*;
4252
use proptest::collection::vec;
4353

@@ -56,6 +66,38 @@ macro_rules! act_tests {
5666
}
5767
}
5868

69+
#[test]
70+
fn load_a_prop(x in x_strat(), konst in any::<$ti>()) {
71+
if $cond {
72+
run_kernel_test::<$ti, $ker>(&x, &[Load(RegisterId::A, konst)], |_| konst);
73+
}
74+
}
75+
76+
#[test]
77+
fn load_b_prop(x in x_strat(), konst in any::<$ti>()) {
78+
if $cond {
79+
run_kernel_test::<$ti, $ker>(&x, &[Load(RegisterId::B, konst)], |x| x);
80+
}
81+
}
82+
83+
#[test]
84+
fn load_c_prop(x in x_strat(), konst in any::<$ti>()) {
85+
if $cond {
86+
run_kernel_test::<$ti, $ker>(&x, &[Load(RegisterId::C, konst)], |x| x);
87+
}
88+
}
89+
90+
#[test]
91+
fn move_b_to_a_prop(x in x_strat(), konst in any::<$ti>()) {
92+
if $cond {
93+
run_kernel_test::<$ti, $ker>(
94+
&x,
95+
&[Load(RegisterId::B, konst), Move(RegisterId::A, RegisterId::B)],
96+
|_| konst
97+
);
98+
}
99+
}
100+
59101
#[test]
60102
fn add_const_prop(alpha in any::<$ti>(), x in x_strat()) {
61103
if $cond {
@@ -122,6 +164,28 @@ macro_rules! act_tests {
122164
);
123165
}
124166
}
167+
168+
#[test]
169+
fn hard_sigmoid(x in x_strat(), alpha in any::<$ti>(), beta in any::<$ti>()) {
170+
if $cond {
171+
run_kernel_test::<$ti, $ker>(
172+
&x,
173+
&$crate::frame::activations::definitions::hard_sigmoid(alpha, beta).ops,
174+
|x| (x * alpha + beta).min(<$ti>::one()).max(<$ti>::zero())
175+
);
176+
}
177+
}
178+
179+
#[test]
180+
fn hard_swish(x in x_strat()) {
181+
if $cond {
182+
run_kernel_test::<$ti, $ker>(
183+
&x,
184+
&$crate::frame::activations::definitions::hard_swish().ops,
185+
|x| (x * 1./6. + 0.5).min(<$ti>::one()).max(<$ti>::zero()) * x
186+
);
187+
}
188+
}
125189
}
126190
/*
127191
prop_act_e2e!($cond, $ti, $ker, affine(alpha, beta));

0 commit comments

Comments
 (0)