Skip to content

Commit dc9148b

Browse files
committed
avx512 panel extraction
1 parent 749016c commit dc9148b

File tree

2 files changed

+222
-0
lines changed

2 files changed

+222
-0
lines changed

linalg/src/x86_64_fma/panel_extract.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,18 @@ panel_extractor!(kernel_packed_32_f16_to_f32 as packed_32_f16_to_f32(
1717
PackedFormat::new(f32::datum_type(), 32, 32)
1818
) where(AVX2));
1919

20+
panel_extractor!(kernel_packed_128_q40_to_f32::kernel as packed_128_q40_to_f32(
21+
Box::new(super::mmm::PQ40_R128),
22+
PackedFormat::new(f32::datum_type(), 128, 32)
23+
) where(AVX512F));
24+
25+
mod kernel_packed_128_q40_to_f32 {
26+
extern_kernel!(fn avx512_packed_128_q40_to_f32(i: *const u8, output: *mut u8, k: usize) -> ());
27+
pub unsafe fn kernel(input: *const u8, output: *mut u8, k: usize) {
28+
avx512_packed_128_q40_to_f32(input, output, k)
29+
}
30+
}
31+
2032
#[target_feature(enable = "avx2")]
2133
unsafe fn kernel_packed_32_q40_to_f32(input: *const u8, output: *mut u8, k: usize) {
2234
debug_assert!(k % 32 == 0);
Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
{% comment %}
2+
// vim: set syntax=asm :
3+
4+
/*
5+
System V ABI:
6+
args: rdi, rsi, rdx, rcx, r8, r9
7+
preserve: rbx, rsp, rbp, r12, r13, r14, r15
8+
scratch: rax, rdi, rsi, rdx, rcx, r8, r9, r10, r11
9+
return: rax (+rdx)
10+
11+
Windows ABI:
12+
args: RCX, RDX, R8, R9
13+
preserve: RBX, RBP, RDI, RSI, RSP, R12, R13, R14, R15, and XMM6-15
14+
scratch: RAX, RCX, RDX, R8, R9, R10, R11, XMM0-5, and the upper portions of ZMM0-15 and ZMM0-15
15+
return: rax (+rdx)
16+
*/
17+
{% endcomment %}
18+
{% if msvc %}
19+
20+
_text segment
21+
avx512_packed_128_q40_to_f32_{{suffix}} proc
22+
23+
{% else %}
24+
25+
.intel_syntax noprefix
26+
.text
27+
.p2align 5
28+
.globl {{G}}avx512_packed_128_q40_to_f32_{{suffix}}
29+
{{G}}avx512_packed_128_q40_to_f32_{{suffix}}:
30+
.cfi_startproc
31+
32+
{% endif %}
33+
34+
push rbp
35+
mov rbp, rsp
36+
37+
{% if family == "windows" %}
38+
// https://www.agner.org/optimize/calling_conventions.pdf xmm6-15 are not scratch
39+
// https://stackoverflow.com/questions/43358429/save-value-of-xmm-registers
40+
and rsp,-16
41+
lea rsp,[rsp-160]
42+
vmovaps [rsp], xmm6
43+
vmovaps [rsp+16*1],xmm7
44+
vmovaps [rsp+16*2],xmm8
45+
vmovaps [rsp+16*3],xmm9
46+
vmovaps [rsp+16*4],xmm10
47+
vmovaps [rsp+16*5],xmm11
48+
vmovaps [rsp+16*6],xmm12
49+
vmovaps [rsp+16*7],xmm13
50+
vmovaps [rsp+16*8],xmm14
51+
vmovaps [rsp+16*9],xmm15
52+
53+
// FIXME calling_conventions
54+
push rdi
55+
push rsi
56+
57+
mov rdi, rcx
58+
59+
{% endif %}
60+
61+
push rbx
62+
push r12
63+
push r13
64+
push r14
65+
push r15
66+
67+
sub rsp, 8
68+
{% if family == "unix" %}
69+
.cfi_def_cfa_offset 64
70+
{% endif %}
71+
stmxcsr [rsp + 4]
72+
{% if msvc %}
73+
mov rax, 1FC0h
74+
{% else %}
75+
mov rax, 0x1FC0
76+
{% endif %}
77+
mov [rsp], eax
78+
ldmxcsr [rsp]
79+
80+
// unix: rdi:input rsi: output, rdx:k
81+
82+
{{L}}q40f32:
83+
// zmm0-7: acc
84+
// zmm8-16: scales
85+
// zmm30: 8
86+
// zmm29: mask
87+
// zmm31: b value
88+
vbroadcastss zmm29, dword ptr [{{offset}} {{L}}q40f32_mask]
89+
vbroadcastss zmm30, dword ptr [{{offset}} {{L}}q40f32_eight]
90+
vmovups zmm28, [{{offset}} {{L}}q40f32_perm]
91+
92+
{{L}}q40f32_outerloop:
93+
// scales
94+
{% for i in (0..7) %}
95+
vmovaps ymm{{i|plus:8}}, [rdi + {{i|times:32}}]
96+
{% endfor %}
97+
{% for i in (0..7) %}
98+
vcvtph2ps zmm{{i|plus:8}}, ymm{{i|plus:8}}
99+
{% endfor %}
100+
add rdi, 256
101+
mov rax, 32
102+
103+
{{L}}q40f32_innerloop:
104+
vmovaps zmm27, [rdi] // 128 nibbles
105+
106+
vpandq zmm26, zmm27, zmm29 // 64 bytes
107+
108+
vpmovzxbd zmm16, xmm26 // 16 u32
109+
vpermt2q zmm26, zmm28, zmm26
110+
vpmovzxbd zmm17, xmm26 // 16 u32
111+
vpermt2q zmm26, zmm28, zmm26
112+
vpmovzxbd zmm18, xmm26 // 16 u32
113+
vpermt2q zmm26, zmm28, zmm26
114+
vpmovzxbd zmm19, xmm26 // 16 u32
115+
116+
vpsrlw zmm27, zmm27, 4
117+
vpandq zmm26, zmm27, zmm29 // 64 bytes
118+
119+
vpmovzxbd zmm20, xmm26 // 16 u32
120+
vpermt2q zmm26, zmm28, zmm26
121+
vpmovzxbd zmm21, xmm26 // 16 u32
122+
vpermt2q zmm26, zmm28, zmm26
123+
vpmovzxbd zmm22, xmm26 // 16 u32
124+
vpermt2q zmm26, zmm28, zmm26
125+
vpmovzxbd zmm23, xmm26 // 16 u32
126+
127+
128+
{% for i in (16..23) %}
129+
vpsubd zmm{{i}}, zmm{{i}}, zmm30
130+
{% endfor %}
131+
132+
{% for i in (16..23) %}
133+
vcvtdq2ps zmm{{i}}, zmm{{i}}
134+
{% endfor %}
135+
136+
{% for i in (0..7) %}
137+
vmulps zmm{{i|plus:16}}, zmm{{i|plus:16}}, zmm{{i|plus:8}}
138+
{% endfor %}
139+
140+
{% for i in (0..7) %}
141+
vmovaps [rsi + {{i|times:64}}], zmm{{i|plus:16}}
142+
{% endfor %}
143+
144+
add rdi, 64
145+
add rsi, 512
146+
sub rax, 1
147+
jnz {{L}}q40f32_innerloop
148+
149+
sub rdx, 32
150+
jnz {{L}}q40f32_outerloop
151+
152+
{{L}}return:
153+
ldmxcsr [rsp + 4]
154+
add rsp, 8
155+
156+
pop r15
157+
pop r14
158+
pop r13
159+
pop r12
160+
pop rbx
161+
162+
{% if family == "windows" %}
163+
pop rsi
164+
pop rdi
165+
166+
vmovaps xmm15, [rsp+16*9]
167+
vmovaps xmm14, [rsp+16*8]
168+
vmovaps xmm13, [rsp+16*7]
169+
vmovaps xmm12, [rsp+16*6]
170+
vmovaps xmm11, [rsp+16*5]
171+
vmovaps xmm10, [rsp+16*4]
172+
vmovaps xmm9, [rsp+16*3]
173+
vmovaps xmm8, [rsp+16*2]
174+
vmovaps xmm7, [rsp+16*1]
175+
vmovaps xmm6, [rsp]
176+
{% endif %}
177+
178+
mov rsp, rbp
179+
pop rbp
180+
ret
181+
182+
{{L}}q40f32_mask:
183+
{% if msvc %}
184+
{{long}} 0F0F0F0Fh
185+
{% else %}
186+
{{long}} 0x0F0F0F0F
187+
{% endif %}
188+
189+
{{L}}q40f32_eight:
190+
{{long}} 8
191+
192+
{{L}}q40f32_perm:
193+
{{quad}} 2
194+
{{quad}} 3
195+
{{quad}} 4
196+
{{quad}} 5
197+
{{quad}} 6
198+
{{quad}} 7
199+
{{quad}} 0 // we dont care what's rolling in from the right
200+
{{quad}} 0
201+
202+
203+
{% if msvc %}
204+
avx512_packed_128_q40_to_f32_{{suffix}} endp
205+
_text ends
206+
end
207+
208+
{% else %}
209+
.cfi_endproc
210+
{% endif %}

0 commit comments

Comments
 (0)