Skip to content

Commit de1829d

Browse files
committed
sigmoid
1 parent be9d8cc commit de1829d

File tree

4 files changed

+222
-136
lines changed

4 files changed

+222
-136
lines changed

linalg/arm64/arm64simd/arm64simd_act_f32_32n.tmpl

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -339,8 +339,41 @@
339339
fmax v7.4s, v7.4s, v24.4s
340340
b .inner_loop
341341

342-
.fma:
343-
b .unsupported
342+
.fma:
343+
// a <- a * b + k
344+
// vfma a,b,c does a <- a + b * c
345+
// mov d,a ; mov a,#k ; vfma a, b, d
346+
347+
and v24.16b, v0.16b, v0.16b
348+
and v25.16b, v1.16b, v1.16b
349+
and v26.16b, v2.16b, v2.16b
350+
and v27.16b, v3.16b, v3.16b
351+
and v28.16b, v4.16b, v4.16b
352+
and v29.16b, v5.16b, v5.16b
353+
and v30.16b, v6.16b, v6.16b
354+
and v31.16b, v7.16b, v7.16b
355+
356+
ins v0.s[0], w3
357+
add x5, x5, 4
358+
dup v0.4s, v0.s[0]
359+
dup v1.4s, v0.s[0]
360+
dup v2.4s, v0.s[0]
361+
dup v3.4s, v0.s[0]
362+
dup v4.4s, v0.s[0]
363+
dup v5.4s, v0.s[0]
364+
dup v6.4s, v0.s[0]
365+
dup v7.4s, v0.s[0]
366+
367+
fmla v0.4s, v24.4s, v8.4s
368+
fmla v1.4s, v25.4s, v9.4s
369+
fmla v2.4s, v26.4s, v10.4s
370+
fmla v3.4s, v27.4s, v11.4s
371+
fmla v4.4s, v28.4s, v12.4s
372+
fmla v5.4s, v29.4s, v13.4s
373+
fmla v6.4s, v30.4s, v14.4s
374+
fmla v7.4s, v31.4s, v15.4s
375+
376+
b .inner_loop
344377

345378
.if_pos_then_else:
346379
fcmge v0.4s, v0.4s, #0.0
@@ -362,7 +395,34 @@
362395
b .inner_loop
363396

364397
.swap_b_c:
365-
b .unsupported
398+
// move d <- b
399+
and v24.16b, v8.16b , v8.16b
400+
and v25.16b, v9.16b , v9.16b
401+
and v26.16b, v10.16b, v10.16b
402+
and v27.16b, v11.16b, v11.16b
403+
and v28.16b, v12.16b, v12.16b
404+
and v29.16b, v13.16b, v13.16b
405+
and v30.16b, v14.16b, v14.16b
406+
and v31.16b, v15.16b, v15.16b
407+
// move b <- c
408+
and v8.16b , v16.16b, v16.16b
409+
and v9.16b , v17.16b, v17.16b
410+
and v10.16b, v18.16b, v18.16b
411+
and v11.16b, v19.16b, v19.16b
412+
and v12.16b, v20.16b, v20.16b
413+
and v13.16b, v21.16b, v21.16b
414+
and v14.16b, v22.16b, v22.16b
415+
and v15.16b, v23.16b, v23.16b
416+
// move c <- d
417+
and v16.16b, v24.16b, v24.16b
418+
and v17.16b, v25.16b, v25.16b
419+
and v18.16b, v26.16b, v26.16b
420+
and v19.16b, v27.16b, v27.16b
421+
and v20.16b, v28.16b, v28.16b
422+
and v21.16b, v29.16b, v29.16b
423+
and v22.16b, v30.16b, v30.16b
424+
and v23.16b, v31.16b, v31.16b
425+
b .inner_loop
366426

367427
.floor:
368428
b .unsupported

linalg/benches/activations.rs

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
use criterion::{black_box, criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion};
22
use tract_linalg::frame::activations::{definitions, reference, ActivationKer, Program};
33

4+
const SIZES:&[i32] = &[32, 256, 1024, 8192];
5+
46
fn crit(c: &mut Criterion, name: &str, r: impl Fn(f32) -> f32, prog: &Program<f32>) {
57
let mut group = c.benchmark_group(name);
6-
for size in [1i32, 32, 256, 1024, 8192].iter() {
8+
for size in SIZES {
79
group.throughput(criterion::Throughput::Elements(*size as u64));
810
group.bench_with_input(BenchmarkId::new("Reference", size), size, |b, size| {
911
b.iter_batched(
@@ -14,7 +16,7 @@ fn crit(c: &mut Criterion, name: &str, r: impl Fn(f32) -> f32, prog: &Program<f3
1416
}
1517
},
1618
BatchSize::LargeInput,
17-
)
19+
)
1820
});
1921
#[allow(unused_mut)]
2022
let mut vms = vec!(tract_linalg::generic::activations::SActivations::act());
@@ -29,7 +31,17 @@ fn crit(c: &mut Criterion, name: &str, r: impl Fn(f32) -> f32, prog: &Program<f3
2931
|| vec![1.0f32; *size as usize],
3032
|mut v| vm.run(prog, &mut v),
3133
BatchSize::LargeInput,
32-
)
34+
)
35+
});
36+
}
37+
if name == "sigmoid" {
38+
let sigmoid = (tract_linalg::ops().sigmoid_f32)();
39+
group.bench_with_input(BenchmarkId::new("handcrafted", size), size, |b, size| {
40+
b.iter_batched(
41+
|| vec![1.0f32; *size as usize],
42+
|mut v| sigmoid.run(&mut v),
43+
BatchSize::LargeInput,
44+
)
3345
});
3446
}
3547
}
@@ -38,11 +50,12 @@ fn crit(c: &mut Criterion, name: &str, r: impl Fn(f32) -> f32, prog: &Program<f3
3850
fn criterion_benchmark(c: &mut Criterion) {
3951
crit(c, "relu", reference::relu, &definitions::relu());
4052
crit(c, "hardswish", reference::hardswish, &definitions::hard_swish());
41-
/*
42-
crit(c, "exp2f", reference::exp2f, &definitions::exp2f());
4353
crit(c, "sigmoid", reference::sigmoid, &definitions::sigmoid());
44-
*/
45-
}
4654

47-
criterion_group!(benches, criterion_benchmark);
48-
criterion_main!(benches);
55+
/*
56+
crit(c, "exp2f", reference::exp2f, &definitions::exp2f());
57+
*/
58+
}
59+
60+
criterion_group!(benches, criterion_benchmark);
61+
criterion_main!(benches);

0 commit comments

Comments
 (0)