Commit b1fe5a8

Merge branch 'jsd-beta' of github.com:Tcc0403/Liger-Kernel into jsd-beta
2 parents b07be0a + 1817598 commit b1fe5a8

File tree: 6 files changed (+546, -50 lines)

README.md

Lines changed: 2 additions & 2 deletions

@@ -272,10 +272,10 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$

  | **Kernel** | **API** |
  |---------------------------------|-------------------------------------------------------------|
  | Embedding | `liger_kernel.transformers.experimental.LigerEmbedding` |
+ | Matmul int2xint8 | `liger_kernel.transformers.experimental.matmul` |

  - **Embedding**: [Embedding](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html) is implemented by fusing embedding lookup and output operations. It achieves a peak speedup of ~1.5x in the forward pass and an overall speedup of ~1.1x.
+ - **Matmul int2xint8**: implemented with cache-tiled matrix multiplication fused with the weight-unpacking step, which achieves a considerable speedup and performs on par with `torch.compile`.
  <!-- TODO: be more specific about batch size -->
  > **Note:**
  > Reported speedups and memory reductions are with respect to the LLaMA 3-8B Hugging Face layer implementations. All models use 4K hidden size and 4K sequence length and are evaluated based on memory usage and wall time for the forward+backward pass on a single NVIDIA A100 80G GPU using small batch sizes. Liger kernels exhibit more efficient scaling to larger batch sizes, detailed further in the [Benchmark](./benchmark) folder.
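The kernel keeps the weights packed four-to-a-byte and unpacks them on the fly inside the matmul. As a tiny standalone sketch (not part of the diff) of that 2-bit layout: weight i of a byte occupies bits [2*i, 2*i + 1]; in the packing helper added below, the four values that share a byte come from rows r, r + K/4, r + 2K/4 and r + 3K/4 of the weight matrix.

# Four ternary weights, shifted from {-1, 0, 1} to {0, 1, 2} so each fits in 2 bits.
weights = [1, -1, 0, 1]
packed = 0
for i, w in enumerate(weights):
    packed |= (w + 1) << (2 * i)          # weight i occupies bits [2*i, 2*i + 1]

recovered = [((packed >> (2 * i)) & 0b11) - 1 for i in range(4)]
assert recovered == weights
print(f"{packed:08b}")                    # 10010010: w3, w2, w1, w0 from high bits to low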
Lines changed: 356 additions & 0 deletions
@@ -0,0 +1,356 @@
import torch

import triton
import triton.language as tl


def unpack_weights(packed: torch.Tensor, bits: int = 2) -> torch.Tensor:
    # Expand a uint8 tensor that stores four 2-bit values per byte (along dim 0)
    # back into int32 weights in {-1, 0, 1}; value i of each byte lives in bits [2*i, 2*i + 1].
    values_per_item = 8 // bits
    packed_shape = packed.shape

    if len(packed_shape) == 1:
        original_row_dim = packed_shape[0] * values_per_item
        unpacked_shape = (original_row_dim,)
    else:
        original_row_dim = packed_shape[0] * values_per_item
        unpacked_shape = (original_row_dim, *packed_shape[1:])

    unpacked = torch.zeros(unpacked_shape, device=packed.device, dtype=torch.uint8)

    for i in range(values_per_item):
        start = i * packed_shape[0]
        end = start + packed_shape[0]
        mask = 3 << (2 * i)
        unpacked[start:end] = (packed & mask) >> (2 * i)

    unpacked = unpacked.to(torch.int32) - 1
    return unpacked


def pack_weights(intweights: torch.Tensor, bits: int = 2) -> torch.Tensor:
    # Pack int weights in {-1, 0, 1} into uint8, four 2-bit values per byte along dim 0.
    # Note: the shift to {0, 1, 2} below modifies the input tensor in place.
    intweights += 1
    original_shape = intweights.shape
    values_per_item = 8 // bits
    row_dim = (original_shape[0] + values_per_item - 1) // values_per_item

    if len(original_shape) == 1:
        packed_tensor_shape = (row_dim,)
    else:
        packed_tensor_shape = (row_dim, *original_shape[1:])

    packed = torch.zeros(
        packed_tensor_shape, device=intweights.device, dtype=torch.uint8
    )
    unpacked = intweights.to(torch.uint8)

    def lshift(t: torch.Tensor, bits: int):
        return t << bits

    it = min(values_per_item, (original_shape[0] // row_dim) + 1)
    for i in range(it):
        start = i * row_dim
        end = min(start + row_dim, original_shape[0])
        packed[: (end - start)] |= lshift(unpacked[start:end], bits * i)

    return packed
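A quick way to sanity-check the two helpers above is a pack/unpack round trip on a random ternary matrix (a sketch, not part of the committed file; a clone is passed because `pack_weights` shifts its argument in place):

w = torch.randint(-1, 2, (64, 32), dtype=torch.int32)   # ternary weights, first dim divisible by 4
packed = pack_weights(w.clone())                        # clone: pack_weights adds 1 in place
assert packed.shape == (16, 32) and packed.dtype == torch.uint8
assert torch.equal(unpack_weights(packed), w)           # exact round trip back to {-1, 0, 1}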
def get_autotune_config():
    # Tiling candidates benchmarked by triton.autotune, keyed on (M, N, K) below.
    return [
        triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 4}, num_stages=4, num_warps=4),
    ]
@triton.autotune(
    configs=get_autotune_config(),
    key=["M", "N", "K"],
)
@triton.jit
def matmul_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M,
    N,
    K: tl.constexpr,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    # We want K / 4 to be divisible by BLOCK_SIZE_K so that the multiplication is aligned
    tl.static_assert(
        K % (4 * BLOCK_SIZE_K) == 0,
        "K / 4 must be divisible by BLOCK_SIZE_K => K divisible by 4*BLOCK_SIZE_K",
    )
    # determine the block id in the 1D grid; pid <=> blockId in CUDA
    pid = tl.program_id(axis=0)
    # number of blocks needed along the M dimension
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    # number of blocks needed along the N dimension
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    # Blocks are grouped along the M dimension. num_pid_in_group is how many blocks form one group,
    # and group_id is the group to which the current block (pid) belongs.
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group

    # pid of the first block in the group that the current block belongs to
    first_pid_m = group_id * GROUP_SIZE_M

    # pid_m: index of the block along the M dimension of the output matrix; pid_n: index along the N dimension.
    # The grid of blocks is 1D, but pid_m and pid_n locate the block's place in the 2D output matrix.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # offs_am holds the element indices of this block of matrix A along the M dimension,
    # offs_bn holds the element indices of this block of matrix B along the N dimension
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    """
    This part of the code generates pointers to the specific blocks of matrices A and B that the current thread block will process.

    As described in the PyTorch documentation, a stride is the step size needed to move from one element to the next along a given dimension:

    For matrix A: stride_am = A.stride(0) = K (stride along the rows), and stride_ak = A.stride(1) = 1 (stride along the columns).
    For matrix B: stride_bk = B.stride(0) = N (stride along the rows), and stride_bn = B.stride(1) = 1 (stride along the columns).

    Breaking down the pointer generation:

    offs_am[:, None] creates a column vector of shape [BLOCK_SIZE_M, 1] holding the row indices of matrix A that this block processes. It is multiplied by K (the number of columns of A) because A is stored in row-major order, so the element at position (i, j) of A is at index i*K + j in memory.
    offs_k[None, :] creates a row vector of the column indices within the block, i.e. a range from 0 to BLOCK_SIZE_K, giving the positions of the columns within the block.
    Combined, the result has shape [BLOCK_SIZE_M, BLOCK_SIZE_K], where entry (i, j) points to element (i, j) of the current block of A.

    The same logic applies to matrix B, but the resulting shape is [BLOCK_SIZE_K, BLOCK_SIZE_N], representing the block of B that the thread block will work on.
    """
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # The accumulator matrix is initialized with zeros and stores the intermediate results of the block matrix multiplication.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)
    """
    The loop is split into two layers. The outer loop runs 4 times, and each iteration focuses on a specific portion of matrix A.

    For example, when i = 0, we only deal with the first K // (4 * BLOCK_SIZE_K) blocks of matrix A, i.e. its columns 0 to K // 4.
    Since matrix B is packed, its first dimension is effectively divided by 4, so while we process the first segment of matrix A
    we still iterate over the entire first dimension of matrix B.

    In each of the 4 iterations of the outer loop we go through the full blocks of matrix B; what changes is the data we extract.
    Each element of matrix B holds 4 weights packed into one uint8, and in each outer iteration we extract a different weight
    with bitwise shifts, so a unique weight is accessed on each pass.
    """
    for i in range(4):
        b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
        for j in range(0, tl.cdiv(K // 4, BLOCK_SIZE_K)):
            k = i * tl.cdiv(K // 4, BLOCK_SIZE_K) + j
            # load the block of matrix A
            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0)
            # load the block of matrix B
            b_uint8 = tl.load(b_ptrs, mask=offs_k[:, None] < K, other=0)
            # when i = 0, for example, we only care about the lowest 2 bits of each element of B, so the mask 00000011 hides the other bits
            mask = 3 << (2 * i)
            # shift the masked bits down to the value range 0..3
            b = (b_uint8 & mask) >> (2 * i)
            # During packing it is easier to store 0, 1, 2 than -1, 0, 1, so 1 is added to the weights at pack time and subtracted here
            tensor_full = tl.full((1,), 1, dtype=tl.int8)
            # Accumulate the block products along the K dimension in int32 to avoid overflow or underflow.
            accumulator += tl.dot(a, (b.to(tl.int8) - tensor_full), out_dtype=tl.int32)
            # advance the pointers: a_ptrs moves horizontally along the second dimension (stride_ak = 1),
            # b_ptrs moves vertically along the rows (stride_bk = N)
            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs += BLOCK_SIZE_K * stride_bk

    c = accumulator
    # Compute the offsets into matrix C where the result of this block's computation is stored.
    # stride_cm = N and stride_cn = 1
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    # boundary check to ensure only elements within the matrix bounds are stored
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
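The grouped ordering computed from pid at the top of the kernel is easier to see outside Triton. The small sketch below (plain Python, not part of the file) reproduces the same arithmetic for a toy grid and prints which output tile each program id handles:

def pid_to_tile(pid, M, N, BLOCK_SIZE_M=2, BLOCK_SIZE_N=2, GROUP_SIZE_M=2):
    # same arithmetic as the kernel prologue, with tl.cdiv written as -(-a // b)
    num_pid_m = -(-M // BLOCK_SIZE_M)
    num_pid_n = -(-N // BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n

# 4x4 tiles of output with GROUP_SIZE_M=2: consecutive pids sweep a 2-tile-tall band
# from left to right before moving down, improving reuse of the A tiles held in cache.
print([pid_to_tile(pid, M=8, N=8) for pid in range(16)])
# [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2), (0, 3), (1, 3), (2, 0), (3, 0), ...]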
def matmul(a, b):
    assert (
        a.shape[1] == b.shape[0] * 4
    ), "Incompatible dimensions: the weight matrix needs to be packed"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    M, K = a.shape
    _, N = b.shape
    # c is int32 to avoid any overflow or underflow
    c = torch.empty((M, N), device=a.device, dtype=torch.int32)
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )
    matmul_kernel[grid](
        a,
        b,
        c,
        M,
        N,
        K,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
    )
    return c
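A rough end-to-end check of the wrapper above (a sketch, not part of the committed file; it assumes a CUDA device, picks K divisible by 4 * 128 so the `tl.static_assert` holds for every autotune candidate, and the first call is slow while the autotuner benchmarks the configs):

M, K, N = 128, 1024, 256
a = torch.randint(-8, 8, (M, K), dtype=torch.int8, device="cuda")     # int8 activations
w = torch.randint(-1, 2, (K, N), dtype=torch.int32, device="cuda")    # ternary weights in {-1, 0, 1}
b = pack_weights(w.clone())                                           # (K // 4, N), uint8; clone because packing shifts in place

c_triton = matmul(a, b)                                               # int32 result, shape (M, N)
c_ref = a.cpu().long() @ unpack_weights(b).cpu().long()               # plain integer matmul on CPU as reference
assert torch.equal(c_triton.cpu().long(), c_ref)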
