
Commit 5384cce

[Kernel] AQ AZP 3/4: Asymmetric quantization kernels (vllm-project#7270)

ProExpertProg authored and garg-amit committed
Signed-off-by: Amit Garg <[email protected]>

1 parent aa5d84d · commit 5384cce

File tree

9 files changed: +339 −57 lines changed


csrc/cpu/quant.cpp

Lines changed: 6 additions & 3 deletions

@@ -257,11 +257,13 @@ void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major
 // static-per-tensor quantization.
 void static_scaled_int8_quant(torch::Tensor& out,          // [..., hidden_size]
                               const torch::Tensor& input,  // [..., hidden_size]
-                              const torch::Tensor& scale) {
+                              const torch::Tensor& scale,
+                              c10::optional<torch::Tensor> const& azp) {
   CPU_KERNEL_GUARD_IN(static_scaled_int8_quant)
   TORCH_CHECK(input.is_contiguous());
   TORCH_CHECK(out.is_contiguous());
   TORCH_CHECK(scale.numel() == 1);
+  TORCH_CHECK(!azp.has_value(), "Zero point is not supported on CPU.");

   const int hidden_size = input.size(-1);
   const int num_tokens = input.numel() / hidden_size;
@@ -277,11 +279,12 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
 void dynamic_scaled_int8_quant(
     torch::Tensor& out,          // [..., hidden_size]
     const torch::Tensor& input,  // [..., hidden_size]
-    torch::Tensor& scale  // [..., 1]
-) {
+    torch::Tensor& scale,  // [..., 1]
+    c10::optional<torch::Tensor> const& azp) {
   CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant)
   TORCH_CHECK(input.is_contiguous());
   TORCH_CHECK(out.is_contiguous());
+  TORCH_CHECK(!azp.has_value(), "Zero point is not supported on CPU.");

   int const hidden_size = input.size(-1);
   int const num_tokens = input.numel() / hidden_size;
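On the CPU side this commit only widens the signatures: the optional azp argument keeps the schema in sync with the CUDA build, and any non-empty value trips the new TORCH_CHECK. A minimal host-side sketch of the resulting calling convention (illustrative, not from the commit; it assumes a declaration with the signature above is visible to the caller):

// Illustrative only: symmetric path on CPU, zero point left empty.
#include <torch/all.h>

void cpu_static_quant_example(torch::Tensor const& input) {
  auto out = torch::empty_like(input, torch::kInt8);  // int8 destination
  auto scale = torch::tensor({0.05f});                // per-tensor scale, numel() == 1

  // Symmetric quantization: pass an empty optional for the zero point.
  static_scaled_int8_quant(out, input, scale, c10::nullopt);

  // Passing a real tensor as azp here would hit the new check:
  //   "Zero point is not supported on CPU."
}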

csrc/cpu/torch_bindings.cpp

Lines changed: 5 additions & 4 deletions

@@ -94,13 +94,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
 #ifdef __AVX512F__
   // Compute int8 quantized tensor for given scaling factor.
   ops.def(
-      "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> "
-      "()");
+      "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
+      "Tensor? azp) -> ()");
   ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant);
+
   // Compute int8 quantized tensor and scaling factor
   ops.def(
-      "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
-      "()");
+      "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
+      "Tensor!? azp) -> ()");
   ops.impl("dynamic_scaled_int8_quant", torch::kCPU,
            &dynamic_scaled_int8_quant);
   // W8A8 GEMM, supporting symmetric per-tensor or per-row/column

csrc/ops.h

Lines changed: 4 additions & 2 deletions

@@ -184,10 +184,12 @@ torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
 #endif

 void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
-                              torch::Tensor const& scale);
+                              torch::Tensor const& scale,
+                              c10::optional<torch::Tensor> const& azp);

 void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
-                               torch::Tensor& scales);
+                               torch::Tensor& scales,
+                               c10::optional<torch::Tensor> const& azp);

 torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
                         torch::Tensor b_gptq_qzeros,
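To make the new declarations concrete, here is a hedged sketch of a GPU-side call into the static asymmetric path. The tensor values and helper name are illustrative, not from the commit; the diff only fixes that azp is optional, that the static op expects a single-element zero point, and that the kernel reads it as int32:

// Sketch only: assumes the declarations above and a CUDA build of the extension.
#include <torch/all.h>

void gpu_static_azp_quant_example(torch::Tensor const& input) {
  // input: fp16/bf16/fp32 activations on a CUDA device, shape [..., hidden_size]
  auto out = torch::empty_like(input, torch::kInt8);
  auto scale = torch::tensor({0.05f}, input.options().dtype(torch::kFloat32));
  auto azp = torch::tensor({12}, input.options().dtype(torch::kInt32));

  // Computes q = saturate_int8(round_to_int32(x / scale) + azp) per element.
  static_scaled_int8_quant(out, input, scale, azp);
}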

csrc/quantization/compressed_tensors/int8_quant_kernels.cu

Lines changed: 160 additions & 13 deletions

@@ -14,12 +14,17 @@

 static inline __device__ int8_t float_to_int8_rn(float x) {
 #ifdef USE_ROCM
-  static const float i8_min =
+  static constexpr auto i8_min =
       static_cast<float>(std::numeric_limits<int8_t>::min());
-  static const float i8_max =
+  static constexpr auto i8_max =
       static_cast<float>(std::numeric_limits<int8_t>::max());
-  // round
+
+  // To match the rounding mode of CUDA, we use nearbyint.
+  // It uses the current rounding mode, which is always FE_TONEAREST on HIP.
+  // If that changes in the future, we may need to set the rounding mode
+  // explicitly, either at runtime or compile time.
   float dst = std::nearbyint(x);
+
   // saturate
   dst = std::clamp(dst, i8_min, i8_max);
   return static_cast<int8_t>(dst);
@@ -31,6 +36,59 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
 #endif
 }

+static inline __device__ int32_t float_to_int32_rn(float x) {
+#ifdef USE_ROCM
+  // int32_max is not exactly representable as float.
+  // Therefore, we need to be careful and manually return int32_max on overflow.
+  // For symmetry, we also do the same for int32_min, even though it is exactly
+  // representable as float and the conversion should be exact.
+  static constexpr auto i32_min = std::numeric_limits<int32_t>::min();
+  static constexpr auto i32_min_f = static_cast<float>(i32_min);
+  static constexpr auto i32_max = std::numeric_limits<int32_t>::max();
+  static constexpr auto i32_max_f = static_cast<float>(i32_max);
+
+  // To match the rounding mode of CUDA, we use nearbyint.
+  // It uses the current rounding mode, which is always FE_TONEAREST on HIP.
+  // If that changes in the future, we may need to set the rounding mode
+  // explicitly, either at runtime or compile time.
+  float dst = std::nearbyint(x);
+
+  // saturate on the higher end.
+  if (dst >= i32_max_f) {
+    return i32_max;
+  }
+  // saturate on the lower end.
+  if (dst <= i32_min_f) {
+    return i32_min;
+  }
+
+  return static_cast<int32_t>(dst);
+#else
+  // CUDA path
+  uint32_t dst;
+  asm volatile("cvt.rni.sat.s32.f32 %0, %1;" : "=r"(dst) : "f"(x));
+  return reinterpret_cast<const int32_t&>(dst);
+#endif
+}
+
+static inline __device__ int8_t int32_to_int8(int32_t x) {
+#ifdef USE_ROCM
+  static constexpr auto i8_min =
+      static_cast<int32_t>(std::numeric_limits<int8_t>::min());
+  static constexpr auto i8_max =
+      static_cast<int32_t>(std::numeric_limits<int8_t>::max());
+
+  // saturate
+  int32_t dst = std::clamp(x, i8_min, i8_max);
+  return static_cast<int8_t>(dst);
+#else
+  // CUDA path
+  uint32_t dst;
+  asm volatile("cvt.sat.s8.s32 %0, %1;" : "=r"(dst) : "r"(x));
+  return reinterpret_cast<const int8_t&>(dst);
+#endif
+}
+
 namespace vllm {

 template <typename scalar_t, typename scale_type>
@@ -47,6 +105,23 @@ __global__ void static_scaled_int8_quant_kernel(
   }
 }

+template <typename scalar_t, typename scale_type, typename azp_type>
+__global__ void static_scaled_int8_azp_quant_kernel(
+    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
+    scale_type const* scale_ptr, azp_type const* azp_ptr,
+    const int hidden_size) {
+  int const tid = threadIdx.x;
+  int const token_idx = blockIdx.x;
+  scale_type const scale = *scale_ptr;
+  azp_type const azp = *azp_ptr;
+
+  for (int i = tid; i < hidden_size; i += blockDim.x) {
+    auto const val = static_cast<float>(input[token_idx * hidden_size + i]);
+    auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp);
+    out[token_idx * hidden_size + i] = quant_val;
+  }
+}
+
 template <typename scalar_t, typename scale_type>
 __global__ void dynamic_scaled_int8_quant_kernel(
     scalar_t const* __restrict__ input, int8_t* __restrict__ out,
@@ -80,14 +155,68 @@ __global__ void dynamic_scaled_int8_quant_kernel(
   }
 }

+template <typename scalar_t, typename scale_type, typename azp_type>
+__global__ void dynamic_scaled_int8_azp_quant_kernel(
+    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
+    scale_type* scale, azp_type* azp, const int hidden_size) {
+  int const token_idx = blockIdx.x;
+
+  // Scan for the min and max value for this token
+  float max_val = std::numeric_limits<float>::min();
+  float min_val = std::numeric_limits<float>::max();
+  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+    auto val = static_cast<float>(input[token_idx * hidden_size + i]);
+    max_val = std::max(max_val, val);
+    min_val = std::min(min_val, val);
+  }
+
+  // Reduce the max and min values across the block
+  using BlockReduce = cub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStorage;
+  max_val = BlockReduce(reduceStorage).Reduce(max_val, cub::Max{}, blockDim.x);
+  __syncthreads();  // Make sure min doesn't mess with max shared memory
+  min_val = BlockReduce(reduceStorage).Reduce(min_val, cub::Min{}, blockDim.x);
+
+  __shared__ scale_type scale_sh;
+  __shared__ azp_type azp_sh;
+
+  // Compute the scale and zero point and store them, only on the first thread
+  if (threadIdx.x == 0) {
+    float const scale_val = (max_val - min_val) / 255.0f;
+    // Use rounding to even (same as torch.round)
+    auto const azp_float = std::nearbyint(-128.0f - min_val / scale_val);
+    auto const azp_val = static_cast<azp_type>(azp_float);
+
+    // Store the scale and azp into shared and global
+    scale[token_idx] = scale_sh = scale_val;
+    azp[token_idx] = azp_sh = azp_val;
+  }
+
+  // Wait for the scale and azp to be computed
+  __syncthreads();
+
+  float const scale_val = scale_sh;
+  azp_type const azp_val = azp_sh;
+
+  // Quantize the values
+  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+    auto const val = static_cast<float>(input[token_idx * hidden_size + i]);
+    auto const quant_val =
+        int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val);
+    out[token_idx * hidden_size + i] = quant_val;
+  }
+}
+
 }  // namespace vllm

 void static_scaled_int8_quant(torch::Tensor& out,          // [..., hidden_size]
                               torch::Tensor const& input,  // [..., hidden_size]
-                              torch::Tensor const& scale) {
+                              torch::Tensor const& scale,
+                              c10::optional<torch::Tensor> const& azp) {
   TORCH_CHECK(input.is_contiguous());
   TORCH_CHECK(out.is_contiguous());
   TORCH_CHECK(scale.numel() == 1);
+  TORCH_CHECK(!azp || azp->numel() == 1);

   int const hidden_size = input.size(-1);
   int const num_tokens = input.numel() / hidden_size;
@@ -96,19 +225,29 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
-        vllm::static_scaled_int8_quant_kernel<scalar_t, float>
-            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
-                                         out.data_ptr<int8_t>(),
-                                         scale.data_ptr<float>(), hidden_size);
+        if (!azp) {
+          vllm::static_scaled_int8_quant_kernel<scalar_t, float>
+              <<<grid, block, 0, stream>>>(
+                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
+                  scale.data_ptr<float>(), hidden_size);
+        } else {
+          vllm::static_scaled_int8_azp_quant_kernel<scalar_t, float, int32_t>
+              <<<grid, block, 0, stream>>>(
+                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
+                  scale.data_ptr<float>(), azp->data_ptr<int32_t>(),
+                  hidden_size);
+        }
       });
 }

 void dynamic_scaled_int8_quant(
     torch::Tensor& out,          // [..., hidden_size]
     torch::Tensor const& input,  // [..., hidden_size]
-    torch::Tensor& scales) {
+    torch::Tensor& scales, c10::optional<torch::Tensor> const& azp) {
   TORCH_CHECK(input.is_contiguous());
   TORCH_CHECK(out.is_contiguous());
+  TORCH_CHECK(scales.is_contiguous());
+  TORCH_CHECK(!azp || azp->is_contiguous());

   int const hidden_size = input.size(-1);
   int const num_tokens = input.numel() / hidden_size;
@@ -117,9 +256,17 @@ void dynamic_scaled_int8_quant(
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
-        vllm::dynamic_scaled_int8_quant_kernel<scalar_t, float>
-            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
-                                         out.data_ptr<int8_t>(),
-                                         scales.data_ptr<float>(), hidden_size);
+        if (!azp) {
+          vllm::dynamic_scaled_int8_quant_kernel<scalar_t, float>
+              <<<grid, block, 0, stream>>>(
+                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
+                  scales.data_ptr<float>(), hidden_size);
+        } else {
+          vllm::dynamic_scaled_int8_azp_quant_kernel<scalar_t, float, int32_t>
+              <<<grid, block, 0, stream>>>(
+                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
+                  scales.data_ptr<float>(), azp->data_ptr<int32_t>(),
+                  hidden_size);
+        }
       });
 }
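Outside CUDA, the per-token math of the new dynamic kernel is compact: the scale spreads the observed range over the 255 steps of int8, the zero point shifts the minimum onto -128, and each value is rounded, shifted, and saturated. A host-side reference sketch (my paraphrase of the kernel above, not code from the commit):

// Reference for one token; the kernel does the same math with a block-wide
// min/max reduction and saturating conversions. Assumes max > min.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

void dynamic_azp_quant_reference(const std::vector<float>& x,
                                 std::vector<int8_t>& q,
                                 float& scale, int32_t& azp) {
  // Per-token range.
  float min_val = *std::min_element(x.begin(), x.end());
  float max_val = *std::max_element(x.begin(), x.end());

  // Map [min_val, max_val] onto the 255 steps of int8, anchoring min_val at -128.
  scale = (max_val - min_val) / 255.0f;
  azp = static_cast<int32_t>(std::nearbyint(-128.0f - min_val / scale));

  q.resize(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    // Round to nearest, add the zero point, saturate to int8.
    int32_t v = static_cast<int32_t>(std::nearbyint(x[i] / scale)) + azp;
    q[i] = static_cast<int8_t>(std::clamp(v, -128, 127));
  }
}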

csrc/torch_bindings.cpp

Lines changed: 4 additions & 4 deletions

@@ -336,14 +336,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {

   // Compute int8 quantized tensor for given scaling factor.
   ops.def(
-      "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> "
-      "()");
+      "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
+      "Tensor? azp) -> ()");
   ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant);

   // Compute int8 quantized tensor and scaling factor
   ops.def(
-      "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
-      "()");
+      "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
+      "Tensor!? azp) -> ()");
   ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
            &dynamic_scaled_int8_quant);
 }
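A note on the schema strings: "Tensor? azp" makes the argument optional, so a Python None arrives in C++ as an empty c10::optional<torch::Tensor>, while the "!" in "Tensor!?" additionally marks it as written to, matching the dynamic op storing the computed zero points. The sketch below registers a hypothetical toy_ops::toy_static_quant (not part of vLLM) just to show that mapping, with plain ATen calls standing in for the real kernels:

// Illustrative registration under an assumed "toy_ops" namespace.
#include <torch/all.h>
#include <torch/library.h>

void toy_static_quant(torch::Tensor& out, torch::Tensor const& input,
                      torch::Tensor const& scale,
                      c10::optional<torch::Tensor> const& azp) {
  // Same math as the kernels, expressed with ATen tensor ops.
  auto q = (input / scale).round().to(torch::kInt32);
  if (azp.has_value()) {
    q = q + azp->to(torch::kInt32);  // empty optional == symmetric path
  }
  out.copy_(q.clamp(-128, 127).to(torch::kInt8));
}

TORCH_LIBRARY(toy_ops, m) {
  m.def(
      "toy_static_quant(Tensor! out, Tensor input, Tensor scale,"
      "Tensor? azp) -> ()");
  m.impl("toy_static_quant", torch::kCPU, &toy_static_quant);
}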

0 commit comments