Skip to content

Commit 08e0c52

Browse files
committed
wip, broken max const
1 parent 79db331 commit 08e0c52

File tree

3 files changed

+73
-9
lines changed

3 files changed

+73
-9
lines changed

linalg/arm64/arm64simd/arm64simd_act_f32_32n.tmpl

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88
.text
99
.align 4
1010

11+
// fn(ops: *const Op, constants: *const $ti, xs: *mut $ti, len: usize) -> usize
12+
13+
// x0 <- ops, x1 <- constants, x2 <- xs, x3 <- len(xs)
14+
1115
.cpu generic+fp+simd
1216
.global {{G}}arm64simd_act_f32_32n_{{suffix}}
1317
{{G}}arm64simd_act_f32_32n_{{suffix}}:
@@ -16,9 +20,61 @@
1620
stp d10, d11, [sp, #-16]!
1721
stp d12, d13, [sp, #-16]!
1822
stp d14, d15, [sp, #-16]!
23+
24+
cmp x3, 0
25+
beq .ok
26+
27+
.outer_loop:
    // Each outer iteration runs the whole op program over one tile of
    // 32 f32 values (128 bytes) held in v0-v7.
    mov     x5, x0                                  // x5 is the "pc": restart the program for every tile

    // Load the tile. The first load post-increments so that v4-v7 receive the
    // second 64 bytes rather than a second copy of the first 64.
    ld1     {v0.4s, v1.4s, v2.4s, v3.4s}, [x2], 64
    ld1     {v4.4s, v5.4s, v6.4s, v7.4s}, [x2]
    sub     x2, x2, 64                              // rewind to the start of the tile for the store below

.inner_loop:
    // The pc advances 4 bytes per op and only bits 0..23 are decoded, so a
    // 32-bit fetch is sufficient and avoids reading 4 bytes past the last op.
    // Writing w6 zero-extends into x6, which the op handlers read.
    ldr     w6, [x5]                                // x6 is the fetched instruction at x5
    and     x7, x6, 0xffff                          // x7 <- opcode (low 16 bits)

    cmp     x7, 0
    beq     .end_of_inner_loop                      // opcode 0 terminates the program
    cmp     x7, 10
    beq     .max_const

    b       .unsupported

.inner_loop_payload_done:
    // Op handlers branch back here once they have updated v0-v7.
    add     x5, x5, 4
    b       .inner_loop

.end_of_inner_loop:
    // Write the tile back in place; the two post-increments leave x2 on the
    // next tile (+128 bytes in total).
    st1     {v0.4s, v1.4s, v2.4s, v3.4s}, [x2], 64
    st1     {v4.4s, v5.4s, v6.4s, v7.4s}, [x2], 64

    // NOTE(review): this only terminates when len is a multiple of 32 --
    // presumably guaranteed by the caller of a "32n" kernel; confirm.
    subs    x3, x3, 32
    bne     .outer_loop

    // All tiles processed: do not fall through into the op handlers below.
    b       .ok
53+
54+
.max_const:
    // Operand decoding: bits 16..23 of the fetched op (x6) are an index into
    // the table of 4-byte constants whose base address is in x1.
    ubfx    x7, x6, #16, #8                         // x7 <- constant index
    add     x7, x1, x7, lsl #2                      // x7 <- address of the selected constant
    ld1r    {v24.4s}, [x7]                          // broadcast the constant to the four lanes of v24

    // Lane-wise max of every tile register against the broadcast constant.
    fmax    v0.4s, v0.4s, v24.4s
    fmax    v1.4s, v1.4s, v24.4s
    fmax    v2.4s, v2.4s, v24.4s
    fmax    v3.4s, v3.4s, v24.4s
    fmax    v4.4s, v4.4s, v24.4s
    fmax    v5.4s, v5.4s, v24.4s
    fmax    v6.4s, v6.4s, v24.4s
    fmax    v7.4s, v7.4s, v24.4s
    b       .inner_loop_payload_done

.unsupported:
    // Opcode not handled by this kernel: set x0 to 1 and go to the epilogue.
    mov     x0, #1
    b       .return
1975

76+
.ok:
2077
mov x0, 0
21-
// b .return
2278

2379
.return:
2480
ldp d14, d15, [sp], #16

linalg/src/frame/activations.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,12 @@ macro_rules! act_impl {
133133
}
134134
}
135135

136-
#[cfg(test)]
137-
act_tests!($cond, $func, $ti);
136+
// Per-kernel test module. The whole module is gated on `cfg(test)` so that
// non-test builds do not compile an empty module whose `use super::*` would
// be an unused import.
#[cfg(test)]
mod [<test_ $func>] {
    use super::*;

    act_tests!($cond, $func, $ti);
}
138142
}
139143
};
140144
}

linalg/src/frame/activations/tests.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,24 @@
11
use crate::LADatum;
22

3-
use super::{Program, Op};
3+
use super::{Op, Program};
44
use Op::*;
55

66
pub fn noop<T: LADatum>() -> Program<T> {
77
Program { ops: vec![Done], csts: vec![] }
88
}
99

10+
pub fn max_const<T: LADatum>(c: T) -> Program<T> {
11+
Program { ops: vec![MaxConst(0)], csts: vec![c] }
12+
}
13+
1014
macro_rules! prop_act_e2e {
1115
($cond:expr, $ti: ty, $ker: ty, $name: ident ( $($param:ident),* )) => {
1216
proptest::proptest! {
1317
#[test]
1418
fn $name(
1519
x in proptest::prelude::any::<$ti>(),
1620
repeat in 1usize..4,
17-
$($param in proptest::prelude::any::<$ti>()),*)
21+
$($param in proptest::prelude::any::<$ti>()),*)
1822
{
1923
use crate::frame::activations::ActivationKer;
2024
if $cond {
@@ -39,14 +43,14 @@ macro_rules! prop_act_unit {
3943
fn $name(
4044
x in proptest::prelude::any::<$ti>(),
4145
repeat in 1usize..4,
42-
$($param in proptest::prelude::any::<$ti>()),*)
46+
$($param in proptest::prelude::any::<$ti>()),*)
4347
{
4448
use crate::frame::activations::ActivationKer;
4549
if $cond {
4650
let mut input = tract_data::prelude::Tensor::zero_aligned::<$ti>(&[<$ker>::nr() * repeat], <$ker>::alignment_bytes()).unwrap();
4751
input.fill_t::<$ti>(x).unwrap();
48-
let refer2: fn($ti) -> $ti = $refer;
49-
let expected:Vec<$ti> = input.as_slice::<$ti>().unwrap().iter().cloned().map(refer2).collect();
52+
// let refer2: fn($ti, $($param),*) -> $ti = $refer;
53+
let expected:Vec<$ti> = input.as_slice::<$ti>().unwrap().iter().cloned().map(|x| $refer(x, $($param),*)).collect();
5054
let prog = crate::frame::activations::tests::$name($($param),*);
5155
<$ker>::run(&prog.ops, &prog.csts, &mut input.as_slice_mut::<$ti>().unwrap());
5256

@@ -62,6 +66,7 @@ macro_rules! prop_act_unit {
6266
macro_rules! act_tests {
6367
($cond:expr, $ker:ty, $ti:ty) => {
6468
prop_act_unit!($cond, $ti, $ker, noop(), |x| x);
69+
prop_act_unit!($cond, $ti, $ker, max_const(alpha), |x: $ti, alpha| x.max(alpha));
6570

6671
prop_act_e2e!($cond, $ti, $ker, relu());
6772
prop_act_e2e!($cond, $ti, $ker, affine(alpha, beta));
@@ -75,4 +80,3 @@ macro_rules! act_tests {
7580
*/
7681
};
7782
}
78-

0 commit comments

Comments
 (0)