Commit cdf295b

alexm-redhat authored and LeiWang1999 committed
[Kernel][Core] Add AWQ support to the Marlin kernel (vllm-project#6612)
Signed-off-by: LeiWang1999 <[email protected]>
1 parent 314dda6 commit cdf295b

File tree

20 files changed: +1601 −283 lines


CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -172,6 +172,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
     "csrc/quantization/gptq_marlin/gptq_marlin.cu"
     "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
+    "csrc/quantization/gptq_marlin/awq_marlin_repack.cu"
     "csrc/quantization/fp8/fp8_marlin.cu"
     "csrc/custom_all_reduce.cu"
     "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"

csrc/ops.h

Lines changed: 8 additions & 4 deletions
@@ -89,15 +89,19 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                    int64_t size_k);
 
 torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
-                               torch::Tensor& b_scales, torch::Tensor& g_idx,
-                               torch::Tensor& perm, torch::Tensor& workspace,
-                               int64_t num_bits, int64_t size_m, int64_t size_n,
-                               int64_t size_k, bool is_k_full);
+                               torch::Tensor& b_scales, torch::Tensor& b_zeros,
+                               torch::Tensor& g_idx, torch::Tensor& perm,
+                               torch::Tensor& workspace, int64_t num_bits,
+                               int64_t size_m, int64_t size_n, int64_t size_k,
+                               bool is_k_full, bool has_zp);
 
 torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                  int64_t size_k, int64_t size_n,
                                  int64_t num_bits);
 
+torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
+                                int64_t size_n, int64_t num_bits);
+
 torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                               torch::Tensor& b_scales, torch::Tensor& workspace,
                               int64_t num_bits, int64_t size_m, int64_t size_n,
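
For reference, a minimal caller sketch for the new awq_marlin_repack declaration above. It is not part of the diff; the shape and dtype requirements mirror the checks in the new awq_marlin_repack.cu (int32 AWQ weights of shape [size_k, size_n / pack_factor]), while the concrete sizes and the helper name repack_awq_weight_example are made up for illustration.

#include <torch/torch.h>
#include "ops.h"  // declares awq_marlin_repack(...)

torch::Tensor repack_awq_weight_example() {
  int64_t size_k = 4096, size_n = 4096, num_bits = 4;
  int64_t pack_factor = 32 / num_bits;  // 8 int4 values per int32
  // AWQ checkpoint layout expected by the op: [size_k, size_n / pack_factor].
  auto qweight = torch::zeros({size_k, size_n / pack_factor},
                              torch::dtype(torch::kInt).device(torch::kCUDA));
  // Output layout (see awq_marlin_repack.cu):
  // [size_k / tile_size, size_n * tile_size / pack_factor].
  return awq_marlin_repack(qweight, size_k, size_n, num_bits);
}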

csrc/quantization/fp8/fp8_marlin.cu

Lines changed: 15 additions & 18 deletions
@@ -19,10 +19,10 @@
  * Adapted from https://github.com/IST-DASLab/marlin
  */
 
-#include "../gptq_marlin/gptq_marlin.cuh"
-#include "../gptq_marlin/gptq_marlin_dtypes.cuh"
+#include "../gptq_marlin/marlin.cuh"
+#include "../gptq_marlin/marlin_dtypes.cuh"
 
-using namespace gptq_marlin;
+using namespace marlin;
 
 #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t)               \
   static_assert(std::is_same<scalar_t, half>::value ||          \
@@ -1224,16 +1224,15 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
               ", size_k = ", size_k);
 
   // Verify B
-  TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k,
-              " is not divisible by tile_size = ", gptq_marlin::tile_size);
-  TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0),
+  TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k,
+              " is not divisible by tile_size = ", marlin::tile_size);
+  TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
               "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
-              ", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size);
-  TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0,
+              ", size_k = ", size_k, ", tile_size = ", marlin::tile_size);
+  TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
               "b_q_weight.size(1) = ", b_q_weight.size(1),
-              " is not divisible by tile_size = ", gptq_marlin::tile_size);
-  int actual_size_n =
-      (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor;
+              " is not divisible by tile_size = ", marlin::tile_size);
+  int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor;
   TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
               ", actual_size_n = ", actual_size_n);
 
@@ -1274,11 +1273,9 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   num_groups = b_scales.size(0);
 
   // Verify workspace size
-  TORCH_CHECK(
-      size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n,
-      ", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n);
-  int min_workspace_size =
-      (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par;
+  TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n,
+              ", is not divisible by min_thread_n = ", marlin::min_thread_n);
+  int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
   TORCH_CHECK(workspace.numel() >= min_workspace_size,
               "workspace.numel = ", workspace.numel(),
               " is below min_workspace_size = ", min_workspace_size);
@@ -1290,14 +1287,14 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
         b_scales.data_ptr<at::Half>(), size_m, size_n, size_k,
         workspace.data_ptr(), num_bits, num_groups, group_size, dev,
         at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
-        gptq_marlin::max_par);
+        marlin::max_par);
   } else if (a.scalar_type() == at::ScalarType::BFloat16) {
     fp8_marlin::marlin_mm_f16i4<nv_bfloat16>(
         a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
         c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(), size_m,
         size_n, size_k, workspace.data_ptr(), num_bits, num_groups, group_size,
         dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
-        gptq_marlin::max_par);
+        marlin::max_par);
   } else {
     TORCH_CHECK(false, "fp8_marlin_gemm only supports bfloat16 and float16");
   }
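
The fp8_marlin.cu changes above only swap the gptq_marlin:: constants for the renamed shared marlin:: namespace; the sizing logic itself is unchanged. As a rough host-side sketch of the rule those checks enforce, the snippet below recomputes min_workspace_size. The constant values (min_thread_n = 64, max_par = 16) are assumptions for illustration, not taken from this diff.

#include <cstdio>

int main() {
  const int min_thread_n = 64;  // assumed value of marlin::min_thread_n
  const int max_par = 16;       // assumed value of marlin::max_par
  const int size_n = 4096;      // example width; must be divisible by min_thread_n

  // Mirrors: min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par
  int min_workspace_size = (size_n / min_thread_n) * max_par;
  std::printf("workspace.numel() must be >= %d\n", min_workspace_size);
  return 0;
}
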
csrc/quantization/gptq_marlin/awq_marlin_repack.cu (new file)

Lines changed: 269 additions & 0 deletions
@@ -0,0 +1,269 @@
+#include "marlin.cuh"
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+
+namespace marlin {
+
+template <int const num_threads, int const num_bits, bool const has_perm>
+__global__ void awq_marlin_repack_kernel(
+    uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
+    int size_k, int size_n) {}
+
+}  // namespace marlin
+
+torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
+                                int64_t size_k, int64_t size_n,
+                                int64_t num_bits) {
+  TORCH_CHECK_NOT_IMPLEMENTED(
+      false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
+  return torch::empty({1, 1});
+}
+
+#else
+
+namespace marlin {
+
+template <int const num_threads, int const num_bits>
+__global__ void awq_marlin_repack_kernel(
+    uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
+    int size_k, int size_n) {
+  constexpr int pack_factor = 32 / num_bits;
+
+  int k_tiles = size_k / tile_k_size;
+  int n_tiles = size_n / tile_n_size;
+  int block_k_tiles = div_ceil(k_tiles, gridDim.x);
+
+  int start_k_tile = blockIdx.x * block_k_tiles;
+  if (start_k_tile >= k_tiles) {
+    return;
+  }
+
+  int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);
+
+  // Wait until the next thread tile has been loaded to shared memory.
+  auto wait_for_stage = [&]() {
+    // We only have `stages - 2` active fetches since we are double buffering
+    // and can only issue the next fetch when it is guaranteed that the previous
+    // shared memory load is fully complete (as it may otherwise be
+    // overwritten).
+    cp_async_wait<repack_stages - 2>();
+    __syncthreads();
+  };
+
+  extern __shared__ int4 sh[];
+
+  constexpr int tile_n_ints = tile_n_size / pack_factor;
+
+  constexpr int stage_n_threads = tile_n_ints / 4;
+  constexpr int stage_k_threads = tile_k_size;
+  constexpr int stage_size = stage_k_threads * stage_n_threads;
+
+  auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
+    if (n_tile_id >= n_tiles) {
+      cp_async_fence();
+      return;
+    }
+
+    int first_n = n_tile_id * tile_n_size;
+    int first_n_packed = first_n / pack_factor;
+
+    int4* sh_ptr = sh + stage_size * pipe;
+
+    if (threadIdx.x < stage_size) {
+      int k_id = threadIdx.x / stage_n_threads;
+      int n_id = threadIdx.x % stage_n_threads;
+
+      int first_k = k_tile_id * tile_k_size;
+
+      cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
+                reinterpret_cast<int4 const*>(
+                    &(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) +
+                                     first_n_packed + (n_id * 4)])));
+    }
+
+    cp_async_fence();
+  };
+
+  auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
+    if (n_tile_id >= n_tiles) {
+      return;
+    }
+
+    int warp_id = threadIdx.x / 32;
+    int th_id = threadIdx.x % 32;
+
+    if (warp_id >= 4) {
+      return;
+    }
+
+    int tc_col = th_id / 4;
+    int tc_row = (th_id % 4) * 2;
+
+    constexpr int tc_offsets[4] = {0, 1, 8, 9};
+
+    int cur_n = warp_id * 16 + tc_col;
+    int cur_n_packed = cur_n / pack_factor;
+    int cur_n_pos = cur_n % pack_factor;
+
+    constexpr int sh_stride = tile_n_ints;
+    constexpr uint32_t mask = (1 << num_bits) - 1;
+
+    int4* sh_stage_ptr = sh + stage_size * pipe;
+    uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);
+
+    // Undo interleaving
+    int cur_n_pos_unpacked;
+    if constexpr (num_bits == 4) {
+      constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7};
+      cur_n_pos_unpacked = undo_pack[cur_n_pos];
+    } else {
+      constexpr int undo_pack[4] = {0, 2, 1, 3};
+      cur_n_pos_unpacked = undo_pack[cur_n_pos];
+    }
+
+    uint32_t vals[8];
+  #pragma unroll
+    for (int i = 0; i < 4; i++) {
+      int cur_elem = tc_row + tc_offsets[i];
+
+      int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem];
+      int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) +
+                                          sh_stride * cur_elem];
+
+      vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
+      vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
+    }
+
+    constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
+    int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
+
+    // Result of:
+    // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
+    if constexpr (num_bits == 4) {
+      constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
+
+      uint32_t res = 0;
+  #pragma unroll
+      for (int i = 0; i < 8; i++) {
+        res |= vals[pack_idx[i]] << (i * 4);
+      }
+
+      out_ptr[out_offset + th_id * 4 + warp_id] = res;
+
+    } else {
+      constexpr int pack_idx[4] = {0, 2, 1, 3};
+
+      uint32_t res1 = 0;
+      uint32_t res2 = 0;
+  #pragma unroll
+      for (int i = 0; i < 4; i++) {
+        res1 |= vals[pack_idx[i]] << (i * 8);
+        res2 |= vals[4 + pack_idx[i]] << (i * 8);
+      }
+
+      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
+      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
+    }
+  };
+
+  auto start_pipes = [&](int k_tile_id, int n_tile_id) {
+  #pragma unroll
+    for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
+      fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
+    }
+
+    wait_for_stage();
+  };
+#pragma unroll
+  for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
+    int n_tile_id = 0;
+
+    start_pipes(k_tile_id, n_tile_id);
+
+    while (n_tile_id < n_tiles) {
+  #pragma unroll
+      for (int pipe = 0; pipe < repack_stages; pipe++) {
+        fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
+                        n_tile_id + pipe + repack_stages - 1);
+        repack_tile(pipe, k_tile_id, n_tile_id + pipe);
+        wait_for_stage();
+      }
+      n_tile_id += repack_stages;
+    }
+  }
+}
+
+}  // namespace marlin
+
+#define CALL_IF(NUM_BITS)                                                    \
+  else if (num_bits == NUM_BITS) {                                           \
+    cudaFuncSetAttribute(                                                    \
+        marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \
+        cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);       \
+    marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>      \
+        <<<blocks, marlin::repack_threads, max_shared_mem, stream>>>(       \
+            b_q_weight_ptr, out_ptr, size_k, size_n);                       \
+  }
+
+torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
+                                int64_t size_n, int64_t num_bits) {
+  // Verify compatibility with marlin tile of 16x64
+  TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
+              " is not divisible by tile_k_size = ", marlin::tile_k_size);
+  TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n,
+              " is not divisible by tile_n_size = ", marlin::tile_n_size);
+
+  TORCH_CHECK(num_bits == 4 || num_bits == 8,
+              "num_bits must be 4 or 8. Got = ", num_bits);
+  int const pack_factor = 32 / num_bits;
+
+  // Verify B
+  TORCH_CHECK(b_q_weight.size(0) == size_k,
+              "b_q_weight.size(0) = ", b_q_weight.size(0),
+              " is not size_k = ", size_k);
+  TORCH_CHECK((size_n / pack_factor) == b_q_weight.size(1),
+              "Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1),
+              ", size_n = ", size_n, ", pack_factor = ", pack_factor);
+
+  // Verify device and strides
+  TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
+  TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
+  TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");
+
+  // Alloc buffers
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
+  auto options = torch::TensorOptions()
+                     .dtype(b_q_weight.dtype())
+                     .device(b_q_weight.device());
+  torch::Tensor out = torch::empty(
+      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
+      options);
+
+  // Get ptrs
+  uint32_t const* b_q_weight_ptr =
+      reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
+  uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());
+
+  // Get dev info
+  int dev = b_q_weight.get_device();
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
+  int blocks;
+  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
+
+  int max_shared_mem = 0;
+  cudaDeviceGetAttribute(&max_shared_mem,
+                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
+  TORCH_CHECK(max_shared_mem > 0);
+
+  if (false) {
+  }
+  CALL_IF(4)
+  CALL_IF(8)
+  else {
+    TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits);
+  }
+
+  return out;
+}
+
+#endif
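
To make the "Undo interleaving" step above concrete, here is a small host-side sketch of the 4-bit case. It reuses the kernel's undo_pack table (logical column position within a packed int32 -> nibble where AWQ stores that column) on a made-up packed word; it is an illustration, not code from the diff.

#include <cstdint>
#include <cstdio>

int main() {
  constexpr int num_bits = 4;
  constexpr uint32_t mask = (1u << num_bits) - 1;
  // Same table as in awq_marlin_repack_kernel for num_bits == 4.
  constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7};

  // Example packed word: nibble i stores the value i (0..7).
  uint32_t packed = 0;
  for (int i = 0; i < 8; i++) packed |= static_cast<uint32_t>(i) << (i * 4);

  // Extract each logical column the way the kernel does:
  // value = (packed >> (undo_pack[pos] * num_bits)) & mask.
  for (int pos = 0; pos < 8; pos++) {
    uint32_t val = (packed >> (undo_pack[pos] * num_bits)) & mask;
    std::printf("logical column %d reads nibble %d -> value %u\n", pos,
                undo_pack[pos], val);
  }
  return 0;
}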
