Skip to content

Commit 749016c

Browse files
committed
128x3 for q40f32 avx512 kit
1 parent 3fb1893 commit 749016c

File tree

5 files changed

+165
-8
lines changed

5 files changed

+165
-8
lines changed

linalg/src/frame/mmm/tests/fuse.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ where
172172
let v = c.to_vec();
173173
let c = mmm_stride_storage(&v, ker.nr());
174174
let mut ops = ops.to_vec();
175-
ops.insert(0, FusedKerSpec::AddUnicast(c));
175+
ops.insert(0, FusedKerSpec::AddUnicast(c)); // FIXME
176176
ops.insert(0, FusedKerSpec::Clear);
177177
ops.push(FusedKerSpec::Store(c));
178178
ops.push(FusedKerSpec::Done);

linalg/src/x86_64_fma/mmm.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ MMMExternKernel!(fma_mmm_f32_32x3<f32>(32,3)@(32,4) where(FMA)
3939
MMMExternKernel!(avx512_mmm_f32_128x1<f32>(128, 1)@(64,4) where (AVX512F)
4040
packing[1] = q40f32 => |k| k.with_packing_a(pq40_r128());
4141
);
42+
MMMExternKernel!(avx512_mmm_f32_128x3<f32>(128, 3)@(64,4) where (AVX512F));
43+
4244
MMMExternKernel!(avx512_mmm_f32_16x1 <f32>( 16, 1)@(64,4) where (AVX512F));
4345
MMMExternKernel!(avx512_mmm_f32_16x12<f32>( 16,12)@(64,4) where (AVX512F));
4446
MMMExternKernel!(avx512_mmm_f32_16x8 <f32>( 16, 8)@(64,4) where (AVX512F));
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
{% comment %}
2+
// vim: set syntax=asm :
3+
4+
/* mmm 128 x 3:
5+
6+
zmm0 zmm8 zmm16
7+
zmm1 zmm9 zmm17
8+
zmm2 zmm10 zmm18
9+
zmm3 zmm11 zmm19
10+
zmm4 zmm12 zmm20
11+
zmm5 zmm13 zmm21
12+
zmm6 zmm14 zmm22
13+
zmm7 zmm15 zmm23
14+
15+
16+
System V ABI:
17+
args: rdi, rsi, rdx, rcx, r8, r9
18+
preserve: rbx, rsp, rbp, r12, r13, r14, r15
19+
scratch: rax, rdi, rsi, rdx, rcx, r8, r9, r10, r11
20+
return: rax (+rdx)
21+
22+
Windows ABI:
23+
args: RCX, RDX, R8, R9
24+
preserve: RBX, RBP, RDI, RSI, RSP, R12, R13, R14, R15, and XMM6-15
25+
scratch: RAX, RCX, RDX, R8, R9, R10, R11, XMM0-5, and the upper portions of YMM0-15 and ZMM0-15
26+
return: rax (+rdx)
27+
*/
28+
{% endcomment %}
29+
30+
{% include "preamble.tmpliq" size:"128x3", suffix:suffix, G:G, arch:"avx512" %}
31+
32+
// clear: zero the 24 accumulator registers zmm0..zmm23, then jump back to
// non_linear_loop (defined outside this file; presumably the op dispatcher).
{{L}}clear:
33+
// vzeroall only covers zmm0..zmm15, so zmm16..zmm23 are cleared below by
// copying the already-zero zmm0 into each of them.
vzeroall
34+
{% for i in (16..23) %}
35+
vmovapd zmm{{i}}, zmm0
36+
{% endfor %}
37+
jmp {{L}}non_linear_loop
38+
39+
// add_mat_mul: accumulate k rank-1 updates of a 128x3 tile into zmm0..zmm23.
// Reads from the spec pointed to by rdi: k at +8, A at +16, B at +24.
// Register roles inside the loop:
//   rax = current A position, rbx = current B position, rcx = remaining k
//   zmm29/zmm30/zmm31 = the three B values of this step, each broadcast
//   zmm28 = one 64-byte (16 x f32) slice of the A column
//   zmm0..7 / zmm8..15 / zmm16..23 = accumulators for columns 0 / 1 / 2
// NOTE(review): rbx is listed as callee-saved for both ABIs in the header
// comment; presumably the preamble saves it -- confirm.
{{L}}add_mat_mul:
40+
mov rbx, [rdi + 24] // B
41+
mov rax, [rdi + 16] // A
42+
43+
mov rcx, [rdi + 8] // k
44+
// k == 0: nothing to accumulate, skip the loop entirely
test rcx, rcx
45+
jz {{L}}non_linear_loop
46+
47+
{{L}}main_loop_packed_packed:
48+
vbroadcastss zmm29, dword ptr [rbx]
49+
vbroadcastss zmm30, dword ptr [rbx+4]
50+
vbroadcastss zmm31, dword ptr [rbx+8]
51+
52+
// eight 64-byte slices of A, each fused-multiply-added into the matching
// row band of all three column accumulators
{% for i in (0..7) %}
53+
vmovups zmm28, zmmword ptr [rax+{{i | times:64}}]
54+
vfmadd231ps zmm{{i}}, zmm28, zmm29
55+
vfmadd231ps zmm{{i | plus: 8}}, zmm28, zmm30
56+
vfmadd231ps zmm{{i | plus: 16}}, zmm28, zmm31
57+
{% endfor %}
58+
59+
// advance to the next k step: 3 f32 of B (12 bytes), 128 f32 of A (512 bytes)
add rbx, 12
60+
add rax, 512
61+
62+
dec rcx
63+
jnz {{L}}main_loop_packed_packed
64+
65+
jmp {{L}}non_linear_loop
66+
67+
// Shared op templates, each instantiated over registers 0..23, which is the
// accumulator range this kernel uses (zmm0..zmm23).
{% include "f32_scalars.tmpliq" from:0, to:23 %}
68+
{% include "f32_per_rows.tmpliq" mr:128, from:0, to:23 %}
69+
{% include "f32_per_cols.tmpliq" mr:128, from:0, to:23 %}
70+
{% include "avx512_mmm_load_tile.tmpliq" from:0, to:23 %}
71+
72+
// range_0_16: constant table holding the integers 0..15, one entry per
// template iteration, emitted with the directive held in the "long" variable.
// add_unicast loads it as a 16-lane zmm vector of row indices.
// NOTE(review): presumably 32-bit entries, since it is consumed by a 64-byte
// load followed by vpmulld -- confirm against the "long" definition.
{{L}}range_0_16:
73+
{% for i in (0..15) %}
74+
{{long}} {{i}}
75+
{% endfor %}
76+
77+
// add_unicast: gather a strided 128x3 tile of f32 from memory and add it to
// the accumulators zmm0..zmm23.
// Reads from the spec pointed to by rdi: c ptr at +8, row stride at +16,
// col stride at +24.
// Register roles:
//   r10   = base address of the column currently being gathered
//   rsi   = row stride (later multiplied by 16), rbx = col stride
//   zmm26 = per-lane offsets from r10 for 16 consecutive rows
//   zmm27 = 16 * row stride, the step from one 16-row band to the next
//   zmm24 = gathered values, k1 = gather mask
// NOTE(review): offsets are built from the low 32 bits of the row stride and
// used as dword gather indices -- assumes the tile's offsets fit in 32 bits.
{{L}}add_unicast:
78+
79+
mov r10, [rdi + 8] // c ptr
80+
mov rsi, [rdi + 16] // row stride
81+
mov rbx, [rdi + 24] // col stride
82+
83+
vbroadcastss zmm29, dword ptr [rdi+16] // row stride (aka esi)
84+
vmovups zmm26, [{{offset}} {{L}}range_0_16]
85+
// zmm26 = (0..15) * row stride: offsets of rows 0..15 within a column
vpmulld zmm26, zmm26, zmm29
86+
87+
// first 16-row band: one gather per column, r10 stepping by the col stride.
// k1 is set to all ones before every gather because the gather instruction
// clears its mask as lanes complete.
{% for i in (0..2) %}
88+
kxnorw k1,k1,k1
89+
vgatherdps zmm24{k1}, [ r10 + zmm26 ]
90+
add r10, rbx
91+
vaddps zmm{{i | times: 8}}, zmm{{i | times: 8}}, zmm24
92+
{% endfor %}
93+
94+
imul esi, 16
95+
vpbroadcastd zmm27, esi
96+
97+
// remaining seven bands: rewind r10 to the first column, move the offsets
// down by 16 rows, then gather and add the three columns again
{% for j in (1..7) %}
98+
mov r10, [rdi + 8]
99+
vpaddd zmm26, zmm26, zmm27
100+
101+
{% for i in (0..2) %}
102+
kxnorw k1,k1,k1
103+
vgatherdps zmm24{k1}, [ r10 + zmm26 ]
104+
add r10, rbx
105+
vaddps zmm{{i | times: 8 | plus: j}}, zmm{{i | times: 8 | plus: j}}, zmm24
106+
{% endfor %}
107+
{% endfor %}
108+
109+
jmp {{L}}non_linear_loop
110+
111+
// add_row_col_products: add one rank-1 update (a 128-element column times a
// 3-element row) to the accumulators zmm0..zmm23.
// rax = pointer loaded from rdi + 8, read as 128 contiguous f32 (8 x 64 bytes)
// rbx = pointer loaded from rdi + 16, read as 3 contiguous f32
// Same inner body as one iteration of add_mat_mul's main loop.
{{L}}add_row_col_products:
112+
mov rax, [ rdi + 8 ]
113+
mov rbx, [ rdi + 16 ]
114+
115+
// broadcast the three per-column multipliers
vbroadcastss zmm29, dword ptr [rbx]
116+
vbroadcastss zmm30, dword ptr [rbx+4]
117+
vbroadcastss zmm31, dword ptr [rbx+8]
118+
119+
{% for i in (0..7) %}
120+
vmovups zmm28, zmmword ptr [rax+{{i | times:64}}]
121+
vfmadd231ps zmm{{i}}, zmm28, zmm29
122+
vfmadd231ps zmm{{i | plus: 8}}, zmm28, zmm30
123+
vfmadd231ps zmm{{i | plus: 16}}, zmm28, zmm31
124+
{% endfor %}
125+
126+
jmp {{L}}non_linear_loop
127+
128+
// store: write the 128x3 accumulator tile (zmm0..zmm23) to strided memory,
// one f32 at a time.
// Reads from the spec pointed to by rdi: c ptr at +8, row stride at +16,
// col stride at +24.
// r8 / r9 / r10 are running write cursors for columns 0 / 1 / 2; each one
// advances by the row stride (rsi) after every element written.
{{L}}store:
129+
mov r8, [rdi + 8] // c ptr
130+
mov rsi, [rdi + 16] // row stride
131+
mov rbx, [rdi + 24] // col stride
132+
133+
// tops of cols
134+
lea r9, [ r8 + rbx ]
135+
lea r10, [ r8 + 2 * rbx ]
136+
// NOTE(review): r11 would be the top of a fourth column, but the loops below
// only address r8..r10 -- this lea looks unused for a 3-column tile.
lea r11, [ r10 + rbx ]
137+
138+
// For each of the 8 accumulators per column, and each 128-bit quarter of it:
// pull that quarter of all three columns into xmm24..xmm26, then emit its
// four lanes row by row, writing the three columns in turn.
{% for word in (0..7) %}
139+
{% for quarter in (0..3) %}
140+
{% for r in (0..2) %}
141+
vextractf32x4 xmm{{r | plus: 24}}, zmm{{r | times: 8 | plus: word}}, {{quarter}}
142+
{% endfor %}
143+
{% for row in (0..3) %}
144+
{% for i in (0..2) %}
145+
vextractps dword ptr [r{{i | plus: 8}}], xmm{{i | plus: 24}}, {{row}}
146+
add r{{i | plus: 8}}, rsi
147+
{% endfor %}
148+
{% endfor %}
149+
{% endfor %}
150+
{% endfor %}
151+
152+
jmp {{L}}non_linear_loop
153+
154+
{% include "postamble.tmpliq" size:"128x3", suffix:suffix, G:G, L:L, arch:"avx512" %}
155+

linalg/x86_64/avx512/f32_scalars.tmpliq

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@
1010
{{L}}leaky_relu:
1111
// scratch registers: zmm30 (all zero) and zmm31 (alpha)
1212
// zmm31 <- alpha
13-
vbroadcastss zmm15, dword ptr [rdi + 8]
13+
vbroadcastss zmm31, dword ptr [rdi + 8]
1414
// zmm30 <- all zero
15-
vpxorq zmm14, zmm14, zmm14
15+
vpxorq zmm30, zmm30, zmm30
1616

1717
{% for reg in (from..to) %}
18-
vcmpps k1, zmm{{reg}}, zmm14, 1 // 1 means LT
18+
vcmpps k1, zmm{{reg}}, zmm30, 1 // 1 means LT
1919
// masked multiply: lanes where x < 0 become alpha * x, others keep x
20-
vmulps zmm{{reg}} {k1}, zmm{{reg}}, zmm15
20+
vmulps zmm{{reg}} {k1}, zmm{{reg}}, zmm31
2121
{% endfor %}
2222
// select multiplied or original
2323

linalg/x86_64/avx512/zmm_scalar.tmpliq

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
// vim: set syntax=asm :
22

33
{{L}}{{label}}:
4-
vbroadcastss zmm12, dword ptr [rdi + 8]
4+
vbroadcastss zmm31, dword ptr [rdi + 8]
55
{% if flipped %}
66
{% for reg in (from..to) %}
7-
{{op}} zmm{{reg}}, zmm{{reg}}, zmm12
7+
{{op}} zmm{{reg}}, zmm{{reg}}, zmm31
88
{% endfor %}
99
{% else %}
1010
{% for reg in (from..to) %}
11-
{{op}} zmm{{reg}}, zmm12, zmm{{reg}}
11+
{{op}} zmm{{reg}}, zmm31, zmm{{reg}}
1212
{% endfor %}
1313
{% endif %}
1414

0 commit comments

Comments
 (0)