@@ -538,6 +538,7 @@ __global__ void Marlin(
     int prob_n,            // output dimension n
     int prob_k,            // reduction dimension k
     int* locks,            // extra global storage for barrier synchronization
+    bool use_atomic_add,   // whether to use atomic add to reduce
     bool use_fp32_reduce   // whether to use fp32 global reduce
 ) {
   // Each threadblock processes one "stripe" of the B matrix with (roughly) the
@@ -1542,7 +1543,17 @@ __global__ void Marlin(
          i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
          i++) {
       if (c_gl_wr < c_gl_wr_end) {
-        C[c_gl_wr] = sh_red[c_sh_rd];
+        if (use_atomic_add) {
+          scalar_t2* C_half2 = reinterpret_cast<scalar_t2*>(&C[c_gl_wr]);
+          scalar_t2* sh_red_half2 =
+              reinterpret_cast<scalar_t2*>(&sh_red[c_sh_rd]);
+  #pragma unroll
+          for (int a = 0; a < 4; a++) {
+            atomicAdd(&C_half2[a], sh_red_half2[a]);
+          }
+        } else {
+          C[c_gl_wr] = sh_red[c_sh_rd];
+        }
         c_gl_wr += c_gl_wr_delta;
         c_sh_rd += c_sh_rd_delta;
       }
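Aside (not part of the diff): the added branch accumulates each threadblock's partial tile directly into C with atomicAdd on scalar_t2 (half2/bfloat162) pairs instead of the lock-protected global reduce. Below is a minimal standalone CUDA sketch of that accumulation pattern, assuming an sm_60+ device, a hypothetical kernel name accumulate_partials, and a zero-initialized output buffer (just as the host code later in this diff switches to torch::zeros).

#include <cuda_fp16.h>
#include <cstdio>

// Hypothetical standalone kernel: every block adds its own partial row into
// the same zero-initialized output, two halves per atomicAdd, mirroring the
// per-tile accumulation the hunk above performs with scalar_t2.
__global__ void accumulate_partials(const __half* __restrict__ partial,
                                    __half* __restrict__ out, int n) {
  const __half2* src =
      reinterpret_cast<const __half2*>(partial + blockIdx.x * n);
  __half2* dst = reinterpret_cast<__half2*>(out);
  for (int i = threadIdx.x; i < n / 2; i += blockDim.x) {
    atomicAdd(&dst[i], src[i]);  // half2 atomicAdd needs sm_60 or newer
  }
}

int main() {
  const int num_blocks = 4, n = 256;
  __half *partial, *out;
  cudaMallocManaged(&partial, num_blocks * n * sizeof(__half));
  cudaMallocManaged(&out, n * sizeof(__half));
  for (int i = 0; i < num_blocks * n; i++) partial[i] = __float2half(1.0f);
  cudaMemset(out, 0, n * sizeof(__half));  // zero-init, like torch::zeros
  accumulate_partials<<<num_blocks, 128>>>(partial, out, n);
  cudaDeviceSynchronize();
  printf("out[0] = %f (expected %d)\n", __half2float(out[0]), num_blocks);
  cudaFree(partial);
  cudaFree(out);
  return 0;
}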
@@ -1703,8 +1714,8 @@ __global__ void Marlin(
         }
       }
 
-    if (slice_count > 1) {  // only globally reduce if there is more than one
-                            // block in a slice
+    if (slice_count > 1 && !use_atomic_add) {
+      // only globally reduce if there is more than one block in a slice
       barrier_acquire(&locks[slice_col], slice_idx);
       if (use_fp32_reduce) {
         global_reduce_fp32(slice_idx == 0, last);
@@ -1713,7 +1724,8 @@ __global__ void Marlin(
       }
       barrier_release(&locks[slice_col], last);
     }
-    if (last)  // only the last block in a slice actually writes the result
+    if (last || use_atomic_add)
+      // only the last block in a slice actually writes the result
       write_result();
     slice_row = 0;
     slice_col_par++;
@@ -1768,7 +1780,8 @@ __global__ void Marlin(
             HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT>                            \
         <<<blocks, NUM_THREADS, max_shared_mem, stream>>>(                \
             A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr,     \
-            num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce);  \
+            num_groups, prob_m, prob_n, prob_k, locks, use_atomic_add,    \
+            use_fp32_reduce);                                             \
   }                                                                       \
 }
 
@@ -2062,7 +2075,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
                vllm::ScalarType const& q_type, bool has_act_order,
                bool is_k_full, bool has_zp, int num_groups, int group_size,
                int dev, cudaStream_t stream, int thread_k, int thread_n,
-               int sms, int max_par, bool use_fp32_reduce, bool is_zp_float) {
+               int sms, int max_par, bool use_atomic_add, bool use_fp32_reduce,
+               bool is_zp_float) {
   if (has_zp) {
     TORCH_CHECK(
         q_type == vllm::kU4 || q_type == vllm::kU8,
@@ -2243,7 +2257,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                torch::Tensor& workspace,
                                vllm::ScalarTypeId const& b_q_type_id,
                                int64_t size_m, int64_t size_n, int64_t size_k,
-                               bool is_k_full, bool has_zp,
+                               bool is_k_full, bool has_zp, bool use_atomic_add,
                                bool use_fp32_reduce, bool is_zp_float) {
   vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
   if (has_zp) {
@@ -2306,19 +2320,34 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   // Alloc buffers
   const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
   auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
-  torch::Tensor c = torch::empty({size_m, size_n}, options);
-  torch::Tensor a_tmp = torch::empty({size_m, size_k}, options);
+  torch::Tensor c;
+  if (use_atomic_add) {
+    c = torch::zeros({size_m, size_n}, options);
+  } else {
+    c = torch::empty({size_m, size_n}, options);
+  }
+
+  torch::Tensor a_tmp;
+  bool has_act_order = g_idx.size(0) != 0;
+  if (has_act_order) {
+    a_tmp = torch::empty({size_m, size_k}, options);
+  } else {
+    a_tmp = torch::empty({0}, options);
+  }
 
   // Alloc C tmp buffer that is going to be used for the global reduce
+  torch::Tensor c_tmp;
   int reduce_max_m = marlin::determine_reduce_max_m(size_m, marlin::max_par);
   int reduce_n = size_n;
   auto options_fp32 =
       torch::TensorOptions().dtype(at::kFloat).device(a.device());
-  if (!use_fp32_reduce) {
+  if (use_fp32_reduce) {
+    c_tmp = torch::empty({reduce_max_m, reduce_n}, options_fp32);
+  } else {
     reduce_max_m = 0;
     reduce_n = 0;
+    c_tmp = torch::empty({0}, options_fp32);
   }
-  torch::Tensor c_tmp = torch::empty({reduce_max_m, reduce_n}, options_fp32);
 
   // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
   // auto -1)
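Aside (not part of the diff): the allocation change above is what makes the atomic-add path sound. Every block adds its partial result into c, so the buffer must start at zero (torch::zeros), whereas the lock-based overwrite path can keep the cheaper torch::empty because the final value is written exactly once. A small host-only sketch of that distinction, using plain libtorch and hypothetical names:

#include <torch/torch.h>
#include <iostream>

int main() {
  auto options = torch::TensorOptions().dtype(torch::kFloat32);
  // Four fake "blocks", each holding a partial result for eight outputs.
  auto partials = torch::ones({4, 8}, options);

  // Atomic-add style: partial results are summed into a zeroed buffer.
  auto c_accum = torch::zeros({8}, options);
  for (int block = 0; block < 4; block++) {
    c_accum += partials[block];  // stands in for atomicAdd from each block
  }

  // Overwrite style: only the fully reduced result is ever stored, so an
  // uninitialized buffer is fine.
  auto c_overwrite = torch::empty({8}, options);
  c_overwrite.copy_(partials.sum(0));  // stands in for the last block's write

  std::cout << c_accum.equal(c_overwrite) << std::endl;  // prints 1
  return 0;
}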
@@ -2339,7 +2368,6 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   // Detect groupsize and act_order
   int num_groups = -1;
   int group_size = -1;
-  bool has_act_order = g_idx.size(0) != 0;
 
   int rank = b_scales.sizes().size();
   TORCH_CHECK(rank == 2, "b_scales rank = ", rank, " is not 2");
@@ -2407,7 +2435,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
         a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
         workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
         num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
-        thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_zp_float);
+        thread_k, thread_n, sms, marlin::max_par, use_atomic_add,
+        use_fp32_reduce, is_zp_float);
   } else if (a.scalar_type() == at::ScalarType::BFloat16) {
     marlin::marlin_mm<nv_bfloat16>(
         a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
@@ -2416,7 +2445,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
         perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
         workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
         num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
-        thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_zp_float);
+        thread_k, thread_n, sms, marlin::max_par, use_atomic_add,
+        use_fp32_reduce, is_zp_float);
   } else {
     TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
   }