 #include "gemm_dequant.h"
 #include "cutlass_helper.h"

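+// CUDA-core fallback for skinny int8 GEMM: computes C[m, n] = A[m, k] *
+// B[n, k]^T and dequantizes with a per-output-channel scale. Each block
+// owns a CtaM x CtaN output tile; threads stride over K in 128-bit chunks,
+// accumulate with __dp4a, and combine partial sums across warps at the end.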
+template <typename Type, int CtaM, int CtaN, int Threads>
+__global__ void int8_sq(int8_t const* act,
+                        int8_t const* weight,
+                        float const* scale,
+                        Type* output,
+                        int m,
+                        int n,
+                        int k) {
+  using VecType = int4;
+  // One 128-bit (int4) vector load carries 16 int8 values per thread.
+  static constexpr int kStepK = 128 / (8 * sizeof(int8_t));
+  static constexpr int CtaK = kStepK * Threads;
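+  // Illustrative guard (not part of the original change): the shuffle
+  // reduction below assumes a block made of whole warps, which holds for
+  // the BLOCK_SIZE = 256 used by the launcher.
+  static_assert(Threads % 32 == 0, "block must be a whole number of warps");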
+  int tile_id_m = blockIdx.x * CtaM;
+  int tile_id_n = blockIdx.y * CtaN;
+  int tid = threadIdx.x;
+  int8_t tile_a[kStepK], tile_w[CtaN * kStepK];
+  int acc[CtaM * CtaN];
+#pragma unroll
+  for (int i = 0; i < CtaM * CtaN; ++i) {
+    acc[i] = 0;
+  }
+  // Advance the base pointers to this block's tile.
+  act += tile_id_m * k;
+  weight += tile_id_n * k;
+  scale += tile_id_n;
+  output += tile_id_m * n + tile_id_n;
+  for (int idx_k = tid * kStepK; idx_k < k; idx_k += CtaK) {
+    // Stage one 128-bit slice of each of the CtaN weight rows in registers.
+#pragma unroll
+    for (int i = 0; i < CtaN; ++i) {
+      reinterpret_cast<VecType*>(tile_w)[i] =
+          reinterpret_cast<VecType const*>(weight + i * k + idx_k)[0];
+    }
+#pragma unroll
+    for (int i = 0; i < CtaM; ++i) {
+      reinterpret_cast<VecType*>(tile_a)[0] =
+          reinterpret_cast<VecType const*>(act + i * k + idx_k)[0];
+#pragma unroll
+      for (int j = 0; j < CtaN; ++j) {
+        // __dp4a accumulates four packed int8 products into an int32,
+        // so the inner loop consumes kStepK bytes four at a time.
+#pragma unroll
+        for (int l = 0; l < kStepK; l += 4) {
+          acc[i * CtaN + j] =
+              __dp4a(reinterpret_cast<int*>(tile_a + l)[0],
+                     reinterpret_cast<int*>(tile_w + j * kStepK + l)[0],
+                     acc[i * CtaN + j]);
+        }
+      }
+    }
+  }
+
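+  // Each thread now holds CtaM * CtaN partial sums over its K-slices.
+  // Butterfly shuffles reduce within each warp; lane 0 of every warp then
+  // parks its partials in shared memory for the final cross-warp sum.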
+  static constexpr int kWarpSize = 32;
+  static constexpr int kWarpNum = Threads / kWarpSize;
+  __shared__ int shmem[CtaM * CtaN * kWarpNum];
+  int warp_id = tid / kWarpSize, lane_id = tid % kWarpSize;
+#pragma unroll
+  for (int i = 0; i < CtaM; ++i) {
+#pragma unroll
+    for (int j = 0; j < CtaN; ++j) {
+      int val = acc[i * CtaN + j];
+      val += __shfl_xor_sync(~0u, val, 16);
+      val += __shfl_xor_sync(~0u, val, 8);
+      val += __shfl_xor_sync(~0u, val, 4);
+      val += __shfl_xor_sync(~0u, val, 2);
+      val += __shfl_xor_sync(~0u, val, 1);
+      if (lane_id == 0) {
+        shmem[i * CtaN + j + warp_id * CtaM * CtaN] = val;
+      }
+    }
+  }
+  __syncthreads();
+#pragma unroll
+  for (int ii = tid; ii < CtaM * CtaN; ii += Threads) {
+    int mid = ii / CtaN, nid = ii % CtaN;
+    int val = 0;
+#pragma unroll
+    for (int jj = 0; jj < kWarpNum; ++jj) {
+      val += shmem[jj * CtaM * CtaN + ii];
+    }
+    // Dequantize with the per-column scale and store in the output type.
+    output[mid * n + nid] =
+        static_cast<Type>(static_cast<float>(val) * scale[nid]);
+  }
+}
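+// Shape sketch (illustrative numbers, not from this change): with
+// Threads = 256, CtaK = 16 * 256 = 4096, so k = 8192 needs two trips through
+// the K loop; an (m = 4, n = 4096) problem with CtaM = 4, CtaN = 2 launches
+// a (1, 2048) grid.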
+
+template <typename InputType,
+          typename OutputType,
+          int32_t TILE_M,
+          int32_t TILE_N,
+          int32_t BLOCK_SIZE>
+void int8_sq_kernel(GemmDequantParams const& params) {
+  dim3 block(BLOCK_SIZE);
+  dim3 grid(params.m / TILE_M, params.n / TILE_N);
+  int8_sq<OutputType, TILE_M, TILE_N, BLOCK_SIZE>
+      <<<grid, block, 0, params.stream>>>(
+          reinterpret_cast<InputType const*>(params.act),
+          reinterpret_cast<InputType const*>(params.weight),
+          reinterpret_cast<float const*>(params.scale),
+          reinterpret_cast<OutputType*>(params.output),
+          params.m,
+          params.n,
+          params.k);
+}
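+// Note: the grid above assumes params.m % TILE_M == 0 and
+// params.n % TILE_N == 0; any remainder rows or columns would be silently
+// skipped. The dispatch below only takes this path when params.m == TILE_M,
+// which covers the M side; n must be a multiple of TILE_N (2 here).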
+
+template <typename InputType,
+          typename OutputType,
+          int TILE_M,
+          int TILE_N,
+          int BLOCK_SIZE>
+bool int8_sq_kernel_caller(GemmDequantParams const& params) {
+  constexpr int cudaCoreGemmTemplateMaxM = 16;
+  if (params.m == TILE_M) {
+    int8_sq_kernel<InputType, OutputType, TILE_M, TILE_N, BLOCK_SIZE>(params);
+    return true;
+  }
+  if constexpr (TILE_M < cudaCoreGemmTemplateMaxM) {
+    return int8_sq_kernel_caller<InputType,
+                                 OutputType,
+                                 TILE_M + 1,
+                                 TILE_N,
+                                 BLOCK_SIZE>(params);
+  }
+  return false;
+}
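+// Compile-time dispatch: `if constexpr` unrolls this recursion into kernel
+// instantiations for every TILE_M in [1, cudaCoreGemmTemplateMaxM], and the
+// run-time check selects the one whose tile height equals params.m.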
+
+template <typename InputType, typename OutputType>
+bool int8_sq_kernel_launcher(GemmDequantParams const& params) {
+  return int8_sq_kernel_caller<InputType, OutputType, 1, 2, 256>(params);
+}
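+// Entry point used below: TILE_N = 2 and BLOCK_SIZE = 256 are fixed;
+// TILE_M is chosen from params.m by the caller chain above.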
+
 template <paddle::DataType D, typename T>
 void RunGemmDequant(const int8_t* a,
                     const int8_t* b,  // Transposed
@@ -114,6 +239,49 @@ std::vector<paddle::Tensor> GemmDequant(const paddle::Tensor& x, |
   int64_t m = x_dims[x_dims.size() - 2];
   int64_t k = x_dims[x_dims.size() - 1];
   int64_t n = y_dims[y_dims.size() - 2];
+
+  if (m <= 4) {
+    if (out_dtype == "bfloat16") {
+      paddle::Tensor out =
+          paddle::empty({m, n}, paddle::DataType::BFLOAT16, x.place());
+      GemmDequantParams params = {
+          reinterpret_cast<const void*>(x.data<int8_t>()),
+          reinterpret_cast<const void*>(y.data<int8_t>()),
+          reinterpret_cast<const void*>(scale.data<float>()),
+          reinterpret_cast<void*>(out.data<paddle::bfloat16>()),
+          m,
+          n,
+          k,
+          x.stream()};
+      if (!int8_sq_kernel_launcher<int8_t, __nv_bfloat16>(params)) {
+        PADDLE_THROW(common::errors::Fatal("gemm dequant kernel launch failed"));
+      }
+      return {out};
+    } else if (out_dtype == "float16") {
+      paddle::Tensor out =
+          paddle::empty({m, n}, paddle::DataType::FLOAT16, x.place());
+      GemmDequantParams params = {
+          reinterpret_cast<const void*>(x.data<int8_t>()),
+          reinterpret_cast<const void*>(y.data<int8_t>()),
+          reinterpret_cast<const void*>(scale.data<float>()),
+          reinterpret_cast<void*>(out.data<paddle::float16>()),
+          m,
+          n,
+          k,
+          x.stream()};
+      if (!int8_sq_kernel_launcher<int8_t, half>(params)) {
+        PADDLE_THROW(common::errors::Fatal("gemm dequant kernel launch failed"));
+      }
+      return {out};
+    } else {
+      PADDLE_THROW(phi::errors::InvalidArgument(
+          "Only bfloat16 and float16 are supported, but got %s", out_dtype));
+    }
+  }
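+  // For m > 4, fall through to the existing RunGemmDequant path below.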
+
   if (out_dtype == "bfloat16") {
     paddle::Tensor out = paddle::empty({m, n}, paddle::DataType::BFLOAT16, x.place());
     RunGemmDequant<paddle::DataType::BFLOAT16, paddle::bfloat16>(x.data<int8_t>(),