
Commit 7197b79

[Inference] Use cuda core(int8_sq) for m <=4 in gemm_dequant OP (#9707)
1 parent b53f7eb commit 7197b79

File tree

2 files changed: +177 -0 lines changed


csrc/gpu/int8_gemm_with_cutlass/gemm_dequant.cu

Lines changed: 168 additions & 0 deletions
@@ -15,6 +15,131 @@
#include "gemm_dequant.h"
#include "cutlass_helper.h"

template <typename Type, int CtaM, int CtaN, int Threads>
__global__ void int8_sq(int8_t const* act,
                        int8_t const* weight,
                        float const* scale,
                        Type* output,
                        int m,
                        int n,
                        int k) {
  using VecType = int4;
  static constexpr int kStepK = 128 / (8 * sizeof(int8_t));
  static constexpr int CtaK = kStepK * Threads;
  int tile_id_m = blockIdx.x * CtaM;
  int tile_id_n = blockIdx.y * CtaN;
  int tid = threadIdx.x;
  int8_t tile_a[kStepK], tile_w[CtaN * kStepK];
  int acc[CtaM * CtaN];
#pragma unroll
  for (int i = 0; i < CtaM * CtaN; ++i) {
    acc[i] = 0;
  }
  act += tile_id_m * k;
  weight += tile_id_n * k;
  scale += tile_id_n;
  output += tile_id_m * n + tile_id_n;
  // Each thread strides over k in 16-byte vector loads, accumulating
  // int8 dot products with __dp4a.
  for (int idx_k = tid * kStepK; idx_k < k; idx_k += CtaK) {
#pragma unroll
    for (int i = 0; i < CtaN; ++i) {
      reinterpret_cast<VecType*>(tile_w)[i] =
          reinterpret_cast<VecType const*>(weight + i * k + idx_k)[0];
    }
#pragma unroll
    for (int i = 0; i < CtaM; ++i) {
      reinterpret_cast<VecType*>(tile_a)[0] =
          reinterpret_cast<VecType const*>(act + i * k + idx_k)[0];
#pragma unroll
      for (int j = 0; j < CtaN; ++j) {
#pragma unroll
        for (int l = 0; l < kStepK; l += 4) {
          acc[i * CtaN + j] =
              __dp4a(reinterpret_cast<int*>(tile_a + l)[0],
                     reinterpret_cast<int*>(tile_w + j * kStepK + l)[0],
                     acc[i * CtaN + j]);
        }
      }
    }
  }

  static constexpr int kWarpSize = 32;
  static constexpr int kWarpNum = Threads / kWarpSize;
  __shared__ int shmem[CtaM * CtaN * kWarpNum];
  int warp_id = tid / kWarpSize, lane_id = tid % kWarpSize;
#pragma unroll
  for (int i = 0; i < CtaM; ++i) {
#pragma unroll
    for (int j = 0; j < CtaN; ++j) {
      int val = acc[i * CtaN + j];
      // Butterfly reduction of partial sums across the warp.
      val += __shfl_xor_sync(~0, val, 16);
      val += __shfl_xor_sync(~0, val, 8);
      val += __shfl_xor_sync(~0, val, 4);
      val += __shfl_xor_sync(~0, val, 2);
      val += __shfl_xor_sync(~0, val, 1);
      if (lane_id == 0) {
        shmem[i * CtaN + j + warp_id * CtaM * CtaN] = val;
      }
    }
  }
  __syncthreads();
#pragma unroll
  for (int ii = tid; ii < CtaM * CtaN; ii += Threads) {
    int mid = ii / CtaN, nid = ii % CtaN;
    int val = 0;
#pragma unroll
    for (int jj = 0; jj < kWarpNum; ++jj) {
      val += shmem[jj * CtaM * CtaN + ii];
    }
    // Dequantize with the per-output-channel scale and cast to the
    // output type.
    output[mid * n + nid] =
        static_cast<Type>(static_cast<float>(val) * scale[nid]);
  }
}

template <typename InputType,
          typename OutputType,
          int32_t TILE_M,
          int32_t TILE_N,
          int32_t BLOCK_SIZE>
void int8_sq_kernel(GemmDequantParams const& params) {
  dim3 block(BLOCK_SIZE);
  dim3 grid(params.m / TILE_M, params.n / TILE_N);
  int8_sq<OutputType, TILE_M, TILE_N, BLOCK_SIZE>
      <<<grid, block, 0, params.stream>>>(
          reinterpret_cast<InputType const*>(params.act),
          reinterpret_cast<InputType const*>(params.weight),
          reinterpret_cast<float const*>(params.scale),
          reinterpret_cast<OutputType*>(params.output),
          params.m,
          params.n,
          params.k);
}

template <typename InputType,
          typename OutputType,
          int TILE_M,
          int TILE_N,
          int BLOCK_SIZE>
bool int8_sq_kernel_caller(GemmDequantParams const& params) {
  constexpr int cudaCoreGemmTemplateMaxM = 16;
  if (params.m == TILE_M) {
    int8_sq_kernel<InputType, OutputType, TILE_M, TILE_N, BLOCK_SIZE>(params);
    return true;
  }
  // Recurse at compile time until the template TILE_M matches params.m.
  if constexpr (TILE_M < cudaCoreGemmTemplateMaxM) {
    return int8_sq_kernel_caller<InputType,
                                 OutputType,
                                 TILE_M + 1,
                                 TILE_N,
                                 BLOCK_SIZE>(params);
  }
  return false;
}

template <typename InputType, typename OutputType>
bool int8_sq_kernel_launcher(GemmDequantParams const& params) {
  return int8_sq_kernel_caller<InputType, OutputType, 1, 2, 256>(params);
}

template <paddle::DataType D, typename T>
void RunGemmDequant(const int8_t* a,
                    const int8_t* b,  // Transposed
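The hunk above adds a plain CUDA-core GEMM: each thread accumulates 16-element int8 vectors with __dp4a, a warp-level butterfly reduction combines the partial sums, and the epilogue applies the per-output-channel dequantization scale. For intuition, here is a minimal host-side reference of the same math, useful for spot-checking the kernel (a sketch only: gemm_dequant_ref is an illustrative name, not part of this commit, and weight is assumed row-major [n, k], i.e. transposed, as in RunGemmDequant):

#include <cstdint>

// Reference sketch (not in the commit): the result int8_sq is expected
// to produce, computed naively on the host.
template <typename Type>
void gemm_dequant_ref(const int8_t* act, const int8_t* weight,
                      const float* scale, Type* output,
                      int m, int n, int k) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      int32_t acc = 0;
      for (int l = 0; l < k; ++l) {
        acc += static_cast<int32_t>(act[i * k + l]) *
               static_cast<int32_t>(weight[j * k + l]);
      }
      // Per-channel dequantization, matching the kernel's epilogue.
      output[i * n + j] =
          static_cast<Type>(static_cast<float>(acc) * scale[j]);
    }
  }
}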
@@ -114,6 +239,49 @@ std::vector<paddle::Tensor> GemmDequant(const paddle::Tensor& x,
  int64_t m = x_dims[x_dims.size() - 2];
  int64_t k = x_dims[x_dims.size() - 1];
  int64_t n = y_dims[y_dims.size() - 2];

  if (m <= 4) {
    if (out_dtype == "bfloat16") {
      paddle::Tensor out =
          paddle::empty({m, n}, paddle::DataType::BFLOAT16, x.place());
      GemmDequantParams params = {
          reinterpret_cast<const void*>(x.data<int8_t>()),
          reinterpret_cast<const void*>(y.data<int8_t>()),
          reinterpret_cast<const void*>(scale.data<float>()),
          reinterpret_cast<void*>(out.data<paddle::bfloat16>()),
          m,
          n,
          k,
          x.stream()};
      if (!int8_sq_kernel_launcher<int8_t, __nv_bfloat16>(params)) {
        PADDLE_THROW(common::errors::Fatal("gemm dequant kernel run error"));
      }
      return {out};
    } else if (out_dtype == "float16") {
      paddle::Tensor out =
          paddle::empty({m, n}, paddle::DataType::FLOAT16, x.place());
      GemmDequantParams params = {
          reinterpret_cast<const void*>(x.data<int8_t>()),
          reinterpret_cast<const void*>(y.data<int8_t>()),
          reinterpret_cast<const void*>(scale.data<float>()),
          reinterpret_cast<void*>(out.data<paddle::float16>()),
          m,
          n,
          k,
          x.stream()};
      if (!int8_sq_kernel_launcher<int8_t, half>(params)) {
        PADDLE_THROW(common::errors::Fatal("gemm dequant kernel run error"));
      }
      return {out};
    } else {
      PADDLE_THROW(phi::errors::InvalidArgument(
          "only support bfloat16 and float16, but got %s", out_dtype));
    }
  }

  if (out_dtype == "bfloat16") {
    paddle::Tensor out = paddle::empty({m, n}, paddle::DataType::BFLOAT16, x.place());
    RunGemmDequant<paddle::DataType::BFLOAT16, paddle::bfloat16>(x.data<int8_t>(),
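Two implicit shape assumptions are worth noting in this fast path: int8_sq_kernel sizes its grid as params.n / TILE_N with TILE_N fixed at 2 by the launcher, and the kernel walks k in 16-element int4 vector loads, so n should be even and k a multiple of 16. If those invariants were ever in doubt, a guard ahead of the dispatch might look like this sketch (illustrative, not in the commit):

// Illustrative guard (not in the commit): take the CUDA-core path only
// when the kernel's tiling and vector-load assumptions hold.
constexpr int kTileN = 2;      // TILE_N fixed by int8_sq_kernel_launcher
constexpr int kVecElems = 16;  // int8 elements per 16-byte int4 load
bool use_cuda_core = (m <= 4) && (n % kTileN == 0) && (k % kVecElems == 0);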

csrc/gpu/int8_gemm_with_cutlass/gemm_dequant.h

Lines changed: 9 additions & 0 deletions
@@ -1582,3 +1582,12 @@ class GemmDequant {
};

}  // namespace cutlass

typedef struct {
  void const* act;
  void const* weight;
  void const* scale;
  void* output;
  int32_t m, n, k;
  cudaStream_t stream;
} GemmDequantParams;
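GemmDequantParams is a plain aggregate, so the call sites in gemm_dequant.cu initialize it positionally, and the field order above must match the brace-initializer order there. Under C++20, designated initializers make that pairing explicit; a hedged call-site sketch (the pointer variables and shapes are placeholders, not from this commit):

// Illustrative C++20 call site (not in the commit); the null pointers
// stand in for the real device tensor data pointers.
const void* act_ptr = nullptr;     // int8 activations, [m, k]
const void* weight_ptr = nullptr;  // int8 weights, [n, k] (transposed)
const void* scale_ptr = nullptr;   // per-channel float scales, [n]
void* out_ptr = nullptr;           // half/bfloat16 output, [m, n]
GemmDequantParams params{.act = act_ptr,
                         .weight = weight_ptr,
                         .scale = scale_ptr,
                         .output = out_ptr,
                         .m = 4,
                         .n = 4096,
                         .k = 4096,
                         .stream = nullptr};  // nullptr = default stream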
