Commit 196cf02

optimize performance of gptq marlin kernel when n is small
Signed-off-by: Jinzhen Lin <[email protected]>
1 parent f35f8e2 commit 196cf02

File tree

5 files changed: +62 additions, -16 deletions


csrc/quantization/gptq_marlin/gptq_marlin.cu

Lines changed: 44 additions & 14 deletions

@@ -538,6 +538,7 @@ __global__ void Marlin(
     int prob_n,           // output dimension n
     int prob_k,           // reduction dimension k
     int* locks,           // extra global storage for barrier synchronization
+    bool use_atomic_add,  // whether to use atomic add to reduce
     bool use_fp32_reduce  // whether to use fp32 global reduce
 ) {
   // Each threadblock processes one "stripe" of the B matrix with (roughly) the
@@ -1542,7 +1543,17 @@ __global__ void Marlin(
          i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
          i++) {
       if (c_gl_wr < c_gl_wr_end) {
-        C[c_gl_wr] = sh_red[c_sh_rd];
+        if (use_atomic_add) {
+          scalar_t2* C_half2 = reinterpret_cast<scalar_t2*>(&C[c_gl_wr]);
+          scalar_t2* sh_red_half2 =
+              reinterpret_cast<scalar_t2*>(&sh_red[c_sh_rd]);
+  #pragma unroll
+          for (int a = 0; a < 4; a++) {
+            atomicAdd(&C_half2[a], sh_red_half2[a]);
+          }
+        } else {
+          C[c_gl_wr] = sh_red[c_sh_rd];
+        }
         c_gl_wr += c_gl_wr_delta;
         c_sh_rd += c_sh_rd_delta;
       }
@@ -1703,8 +1714,8 @@ __global__ void Marlin(
       }
     }

-    if (slice_count > 1) {  // only globally reduce if there is more than one
-                            // block in a slice
+    if (slice_count > 1 && !use_atomic_add) {
+      // only globally reduce if there is more than one block in a slice
       barrier_acquire(&locks[slice_col], slice_idx);
       if (use_fp32_reduce) {
         global_reduce_fp32(slice_idx == 0, last);
@@ -1713,7 +1724,8 @@ __global__ void Marlin(
       }
       barrier_release(&locks[slice_col], last);
     }
-    if (last)  // only the last block in a slice actually writes the result
+    if (last || use_atomic_add)
+      // only the last block in a slice actually writes the result
       write_result();
     slice_row = 0;
     slice_col_par++;
@@ -1768,7 +1780,8 @@ __global__ void Marlin(
                 HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT>                           \
             <<<blocks, NUM_THREADS, max_shared_mem, stream>>>(               \
                 A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr,    \
-                num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \
+                num_groups, prob_m, prob_n, prob_k, locks, use_atomic_add,   \
+                use_fp32_reduce);                                            \
       }                                                                      \
     }

@@ -2062,7 +2075,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
                vllm::ScalarType const& q_type, bool has_act_order,
                bool is_k_full, bool has_zp, int num_groups, int group_size,
                int dev, cudaStream_t stream, int thread_k, int thread_n,
-               int sms, int max_par, bool use_fp32_reduce, bool is_zp_float) {
+               int sms, int max_par, bool use_atomic_add, bool use_fp32_reduce,
+               bool is_zp_float) {
   if (has_zp) {
     TORCH_CHECK(
         q_type == vllm::kU4 || q_type == vllm::kU8,
@@ -2243,7 +2257,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                torch::Tensor& workspace,
                                vllm::ScalarTypeId const& b_q_type_id,
                                int64_t size_m, int64_t size_n, int64_t size_k,
-                               bool is_k_full, bool has_zp,
+                               bool is_k_full, bool has_zp, bool use_atomic_add,
                                bool use_fp32_reduce, bool is_zp_float) {
   vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
   if (has_zp) {
@@ -2306,19 +2320,34 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   // Alloc buffers
   const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
   auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
-  torch::Tensor c = torch::empty({size_m, size_n}, options);
-  torch::Tensor a_tmp = torch::empty({size_m, size_k}, options);
+  torch::Tensor c;
+  if (use_atomic_add) {
+    c = torch::zeros({size_m, size_n}, options);
+  } else {
+    c = torch::empty({size_m, size_n}, options);
+  }
+
+  torch::Tensor a_tmp;
+  bool has_act_order = g_idx.size(0) != 0;
+  if (has_act_order) {
+    a_tmp = torch::empty({size_m, size_k}, options);
+  } else {
+    a_tmp = torch::empty({0}, options);
+  }

   // Alloc C tmp buffer that is going to be used for the global reduce
+  torch::Tensor c_tmp;
   int reduce_max_m = marlin::determine_reduce_max_m(size_m, marlin::max_par);
   int reduce_n = size_n;
   auto options_fp32 =
       torch::TensorOptions().dtype(at::kFloat).device(a.device());
-  if (!use_fp32_reduce) {
+  if (use_fp32_reduce) {
+    c_tmp = torch::empty({reduce_max_m, reduce_n}, options_fp32);
+  } else {
     reduce_max_m = 0;
     reduce_n = 0;
+    c_tmp = torch::empty({0}, options_fp32);
   }
-  torch::Tensor c_tmp = torch::empty({reduce_max_m, reduce_n}, options_fp32);

   // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
   // auto -1)
@@ -2339,7 +2368,6 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   // Detect groupsize and act_order
   int num_groups = -1;
   int group_size = -1;
-  bool has_act_order = g_idx.size(0) != 0;

   int rank = b_scales.sizes().size();
   TORCH_CHECK(rank == 2, "b_scales rank = ", rank, " is not 2");
@@ -2407,7 +2435,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
         a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
         workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
         num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
-        thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_zp_float);
+        thread_k, thread_n, sms, marlin::max_par, use_atomic_add,
+        use_fp32_reduce, is_zp_float);
   } else if (a.scalar_type() == at::ScalarType::BFloat16) {
     marlin::marlin_mm<nv_bfloat16>(
         a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
@@ -2416,7 +2445,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
         perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
         workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
         num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
-        thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_zp_float);
+        thread_k, thread_n, sms, marlin::max_par, use_atomic_add,
+        use_fp32_reduce, is_zp_float);
   } else {
     TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
   }
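
Note on the kernel change: when use_atomic_add is set, every threadblock in a slice adds its partial tile straight into C with atomicAdd on half2 pairs, and the lock-protected global reduce plus "last block writes" path is skipped; this is also why the host side above switches the output allocation to torch::zeros. A minimal PyTorch sketch of the two reduction strategies, with made-up 16x64 partial tiles standing in for the real Marlin shared-memory results:

import torch

# Hypothetical setup: four threadblocks in one slice each hold a partial sum of
# the same 16x64 output tile (the reduction dimension k is split across them).
partials = [torch.randn(16, 64) for _ in range(4)]

# Lock-based path (use_atomic_add=False): partials are combined in a temporary
# reduce buffer under barrier synchronization, and only the last block in the
# slice writes the final tile into C.
c_lock = torch.empty(16, 64)
reduce_buf = torch.zeros(16, 64)
for p in partials:
    reduce_buf += p          # serialized global reduce across blocks
c_lock.copy_(reduce_buf)     # single writer: the last block in the slice

# Atomic-add path (use_atomic_add=True): C starts zeroed and every block
# accumulates its partial tile directly; the += stands in for the kernel's
# atomicAdd on half2 values.
c_atomic = torch.zeros(16, 64)
for p in partials:
    c_atomic += p

assert torch.allclose(c_lock, c_atomic)

Because every block becomes a writer in the atomic-add path, the output buffer must start at zero; the lock-based path instead has the last block overwrite the uninitialized buffer with the fully reduced result.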

csrc/torch_bindings.cpp

Lines changed: 2 additions & 1 deletion

@@ -255,7 +255,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
       "int b_q_type, "
       "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
-      "bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
+      "bool has_zp, bool use_atomic_add, bool use_fp32_reduce, "
+      "bool is_zp_float) -> Tensor");
   // conditionally compiled so impl registration is in source file

   // gptq_marlin repack from GPTQ.

vllm/_custom_ops.py

Lines changed: 3 additions & 1 deletion

@@ -713,12 +713,14 @@ def gptq_marlin_gemm(a: torch.Tensor,
                      size_k: int,
                      is_k_full: bool,
                      has_zp: bool = False,
+                     use_atomic_add: bool = False,
                      use_fp32_reduce: bool = False,
                      is_zp_float: bool = False) -> torch.Tensor:
     return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
                                          g_idx, perm, workspace, b_q_type.id,
                                          size_m, size_n, size_k, is_k_full,
-                                         has_zp, use_fp32_reduce, is_zp_float)
+                                         has_zp, use_atomic_add,
+                                         use_fp32_reduce, is_zp_float)


 # fp8 marlin

vllm/envs.py

Lines changed: 5 additions & 0 deletions

@@ -96,6 +96,7 @@
     VLLM_DP_SIZE: int = 1
     VLLM_DP_MASTER_IP: str = ""
     VLLM_DP_MASTER_PORT: int = 0
+    VLLM_MARLIN_USE_ATOMIC_ADD: bool = False


 def get_default_cache_root():
@@ -636,6 +637,10 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     # Whether to use S3 path for model loading in CI via RunAI Streamer
     "VLLM_CI_USE_S3":
     lambda: os.environ.get("VLLM_CI_USE_S3", "0") == "1",
+
+    # Whether to use atomicAdd reduce in gptq/awq marlin kernel.
+    "VLLM_MARLIN_USE_ATOMIC_ADD":
+    lambda: os.environ.get("VLLM_MARLIN_USE_ATOMIC_ADD", "0") == "1",
 }

 # end-env-vars-definition
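
The new flag is off by default; it is enabled process-wide by setting VLLM_MARLIN_USE_ATOMIC_ADD=1 in the environment before vllm.envs is imported. A small sketch of the parsing convention used by the entry above (the helper function here is illustrative, not part of vLLM):

import os

# Any value other than the literal string "1" (including an unset variable)
# leaves the atomic-add path disabled, matching the lambda registered above.
os.environ["VLLM_MARLIN_USE_ATOMIC_ADD"] = "1"

def marlin_use_atomic_add() -> bool:
    return os.environ.get("VLLM_MARLIN_USE_ATOMIC_ADD", "0") == "1"

print(marlin_use_atomic_add())  # True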

vllm/model_executor/layers/quantization/utils/marlin_utils.py

Lines changed: 8 additions & 0 deletions

@@ -5,6 +5,7 @@
 import numpy
 import torch

+import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.linear import LinearBase
 from vllm.platforms import current_platform
@@ -303,10 +304,17 @@ def apply_gptq_marlin_linear(
         input_size_per_partition: int,
         is_k_full: bool,
         bias: Optional[torch.Tensor] = None,
+        use_atomic_add: Optional[bool] = None,
         use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
     reshaped_x = input.reshape(-1, input.shape[-1])
     out_shape = input.shape[:-1] + (output_size_per_partition, )

+    if use_atomic_add is None:
+        use_atomic_add = envs.VLLM_MARLIN_USE_ATOMIC_ADD
+
+    if output_size_per_partition > 2048:
+        use_atomic_add = False
+
     output = ops.gptq_marlin_gemm(reshaped_x,
                                   weight,
                                   weight_scale,
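
apply_gptq_marlin_linear resolves the flag in two steps: an explicit argument takes precedence over the environment default, and the flag is then forced off when output_size_per_partition exceeds 2048, since the atomicAdd reduce targets small n. A standalone sketch of that resolution logic (the function name is hypothetical; in vLLM the checks live inline in apply_gptq_marlin_linear):

import os
from typing import Optional

def resolve_use_atomic_add(output_size_per_partition: int,
                           use_atomic_add: Optional[bool] = None) -> bool:
    # No explicit caller choice: fall back to the VLLM_MARLIN_USE_ATOMIC_ADD
    # environment default.
    if use_atomic_add is None:
        use_atomic_add = os.environ.get("VLLM_MARLIN_USE_ATOMIC_ADD",
                                        "0") == "1"
    # The heuristic disables the atomic-add path for wide outputs (n > 2048).
    if output_size_per_partition > 2048:
        use_atomic_add = False
    return use_atomic_add

print(resolve_use_atomic_add(1024))        # follows the environment default
print(resolve_use_atomic_add(4096, True))  # False: forced off for large n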

0 commit comments

Comments
 (0)