#include "marlin.cuh"

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800

namespace marlin {

template <int const num_threads, int const num_bits>
__global__ void awq_marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
    int size_k, int size_n) {}

}  // namespace marlin

torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
                                int64_t size_n, int64_t num_bits) {
  TORCH_CHECK_NOT_IMPLEMENTED(
      false, "awq_marlin_repack(..) requires CUDA_ARCH >= 8.0");
  return torch::empty({1, 1});
}

#else

namespace marlin {

template <int const num_threads, int const num_bits>
__global__ void awq_marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
    int size_k, int size_n) {
  constexpr int pack_factor = 32 / num_bits;

  int k_tiles = size_k / tile_k_size;
  int n_tiles = size_n / tile_n_size;
  int block_k_tiles = div_ceil(k_tiles, gridDim.x);

  int start_k_tile = blockIdx.x * block_k_tiles;
  if (start_k_tile >= k_tiles) {
    return;
  }

  int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);

  // Wait until the next thread tile has been loaded to shared memory.
  auto wait_for_stage = [&]() {
    // We only have `stages - 2` active fetches since we are double buffering
    // and can only issue the next fetch when it is guaranteed that the previous
    // shared memory load is fully complete (as it may otherwise be
    // overwritten).
    cp_async_wait<repack_stages - 2>();
    __syncthreads();
  };

  extern __shared__ int4 sh[];

  constexpr int tile_n_ints = tile_n_size / pack_factor;

  constexpr int stage_n_threads = tile_n_ints / 4;
  constexpr int stage_k_threads = tile_k_size;
  constexpr int stage_size = stage_k_threads * stage_n_threads;
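  // Each pipeline stage buffers one tile_k_size x tile_n_size tile of packed
  // weights in shared memory: stage_k_threads rows, with stage_n_threads
  // threads per row, each copying a single int4 (four packed 32-bit words)
  // via cp.async.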

  auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
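    // Tiles past the end of the matrix still commit an (empty) cp.async group
    // so the fence/wait bookkeeping in wait_for_stage() stays balanced.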
    if (n_tile_id >= n_tiles) {
      cp_async_fence();
      return;
    }

    int first_n = n_tile_id * tile_n_size;
    int first_n_packed = first_n / pack_factor;

    int4* sh_ptr = sh + stage_size * pipe;

    if (threadIdx.x < stage_size) {
      int k_id = threadIdx.x / stage_n_threads;
      int n_id = threadIdx.x % stage_n_threads;

      int first_k = k_tile_id * tile_k_size;

      cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
                reinterpret_cast<int4 const*>(
                    &(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) +
                                     first_n_packed + (n_id * 4)])));
    }

    cp_async_fence();
  };

  auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      return;
    }

    int warp_id = threadIdx.x / 32;
    int th_id = threadIdx.x % 32;

    if (warp_id >= 4) {
      return;
    }

    int tc_col = th_id / 4;
    int tc_row = (th_id % 4) * 2;

    constexpr int tc_offsets[4] = {0, 1, 8, 9};
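    // Each warp repacks a 16 x 16 sub-tile: thread th_id holds the eight
    // values at rows (tc_row + tc_offsets[i]) of columns tc_col and
    // tc_col + 8, matching the tensor-core fragment layout consumed by the
    // marlin GEMM kernel.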

    int cur_n = warp_id * 16 + tc_col;
    int cur_n_packed = cur_n / pack_factor;
    int cur_n_pos = cur_n % pack_factor;

    constexpr int sh_stride = tile_n_ints;
    constexpr uint32_t mask = (1 << num_bits) - 1;

    int4* sh_stage_ptr = sh + stage_size * pipe;
    uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);

    // Undo interleaving
    int cur_n_pos_unpacked;
    if constexpr (num_bits == 4) {
      constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7};
      cur_n_pos_unpacked = undo_pack[cur_n_pos];
    } else {
      constexpr int undo_pack[4] = {0, 2, 1, 3};
      cur_n_pos_unpacked = undo_pack[cur_n_pos];
    }
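    // undo_pack is the inverse of AWQ's interleaved packing order within each
    // 32-bit word, so cur_n_pos_unpacked is the bit-slot that actually holds
    // logical column cur_n_pos.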

    uint32_t vals[8];
  #pragma unroll
    for (int i = 0; i < 4; i++) {
      int cur_elem = tc_row + tc_offsets[i];

      int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem];
      int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) +
                                          sh_stride * cur_elem];

      vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
      vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
    }

    constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
    int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;

    // Result of:
    // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
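    // The pack_idx permutations below re-interleave the values so that the
    // fast int4/int8 -> fp16 conversion referenced above recovers them in
    // logical order.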
    if constexpr (num_bits == 4) {
      constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};

      uint32_t res = 0;
  #pragma unroll
      for (int i = 0; i < 8; i++) {
        res |= vals[pack_idx[i]] << (i * 4);
      }

      out_ptr[out_offset + th_id * 4 + warp_id] = res;

    } else {
      constexpr int pack_idx[4] = {0, 2, 1, 3};

      uint32_t res1 = 0;
      uint32_t res2 = 0;
  #pragma unroll
      for (int i = 0; i < 4; i++) {
        res1 |= vals[pack_idx[i]] << (i * 8);
        res2 |= vals[4 + pack_idx[i]] << (i * 8);
      }

      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
    }
  };

  auto start_pipes = [&](int k_tile_id, int n_tile_id) {
  #pragma unroll
    for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
      fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
    }

    wait_for_stage();
  };
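
  // Software-pipelined main loop: start_pipes() preloads repack_stages - 1
  // n-tiles, then each steady-state iteration prefetches the tile
  // (repack_stages - 1) positions ahead while repacking the tile already
  // resident in shared memory.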
  #pragma unroll
  for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
    int n_tile_id = 0;

    start_pipes(k_tile_id, n_tile_id);

    while (n_tile_id < n_tiles) {
  #pragma unroll
      for (int pipe = 0; pipe < repack_stages; pipe++) {
        fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
                        n_tile_id + pipe + repack_stages - 1);
        repack_tile(pipe, k_tile_id, n_tile_id + pipe);
        wait_for_stage();
      }
      n_tile_id += repack_stages;
    }
  }
}

}  // namespace marlin

  #define CALL_IF(NUM_BITS)                                                    \
    else if (num_bits == NUM_BITS) {                                           \
      cudaFuncSetAttribute(                                                    \
          marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>,  \
          cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);        \
      marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>       \
          <<<blocks, marlin::repack_threads, max_shared_mem, stream>>>(        \
              b_q_weight_ptr, out_ptr, size_k, size_n);                        \
    }

torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
                                int64_t size_n, int64_t num_bits) {
  // Verify compatibility with marlin tile of 16x64
  TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
              " is not divisible by tile_k_size = ", marlin::tile_k_size);
  TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n,
              " is not divisible by tile_n_size = ", marlin::tile_n_size);

  TORCH_CHECK(num_bits == 4 || num_bits == 8,
              "num_bits must be 4 or 8. Got = ", num_bits);
  int const pack_factor = 32 / num_bits;

  // Verify B
  TORCH_CHECK(b_q_weight.size(0) == size_k,
              "b_q_weight.size(0) = ", b_q_weight.size(0),
              " is not size_k = ", size_k);
  TORCH_CHECK((size_n / pack_factor) == b_q_weight.size(1),
              "Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1),
              ", size_n = ", size_n, ", pack_factor = ", pack_factor);

  // Verify device and strides
  TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
  TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");

  // Alloc buffers
  const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
  auto options = torch::TensorOptions()
                     .dtype(b_q_weight.dtype())
                     .device(b_q_weight.device());
  torch::Tensor out = torch::empty(
      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
      options);
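  // Same total number of packed words as the input
  // (size_k * size_n / pack_factor), but laid out with one row per
  // marlin::tile_size rows of the original matrix, in the tile order the
  // kernel writes.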

  // Get ptrs
  uint32_t const* b_q_weight_ptr =
      reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
  uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());

  // Get dev info
  int dev = b_q_weight.get_device();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
  int blocks;
  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);

  int max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  TORCH_CHECK(max_shared_mem > 0);

  // Dispatch on num_bits. The empty if (false) block exists only so each
  // CALL_IF expansion can begin with `else if`.
  if (false) {
  }
  CALL_IF(4)
  CALL_IF(8)
  else {
    TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits);
  }

  return out;
}

#endif
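
// Example (sketch) of calling the repack op from C++. Shapes are
// illustrative; awq_marlin_repack expects an int32 CUDA tensor of shape
// [size_k, size_n / (32 / num_bits)]:
//
//   int64_t size_k = 4096, size_n = 4096, num_bits = 4;
//   auto opts = torch::dtype(torch::kInt32).device(torch::kCUDA);
//   torch::Tensor b_q_weight =
//       torch::zeros({size_k, size_n / (32 / num_bits)}, opts);
//   torch::Tensor marlin_qweight =
//       awq_marlin_repack(b_q_weight, size_k, size_n, num_bits);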