 
 static inline __device__ int8_t float_to_int8_rn(float x) {
 #ifdef USE_ROCM
-  static const float i8_min =
+  static constexpr auto i8_min =
       static_cast<float>(std::numeric_limits<int8_t>::min());
-  static const float i8_max =
+  static constexpr auto i8_max =
       static_cast<float>(std::numeric_limits<int8_t>::max());
-  // round
+
+  // To match the rounding mode of CUDA, we use nearbyint.
+  // It uses the current rounding mode, which is always FE_TONEAREST on HIP.
+  // If that changes in the future, we may need to set the rounding mode
+  // explicitly, either at runtime or compile time.
   float dst = std::nearbyint(x);
+
   // saturate
   dst = std::clamp(dst, i8_min, i8_max);
   return static_cast<int8_t>(dst);
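The ROCm comment above notes that std::nearbyint follows the current floating-point rounding mode, which HIP defaults to FE_TONEAREST. A minimal host-side sketch (not part of the patch) of how that mode could be pinned explicitly if the default ever changes, and of the round-half-to-even behavior that makes nearbyint match CUDA's cvt.rni and torch.round:

// Host-side sketch only; assumes <cfenv> is available and FE_TONEAREST is the default.
#include <cfenv>
#include <cmath>
#include <cstdio>

int main() {
  // Pin the rounding mode explicitly, as the comment anticipates might be
  // needed if HIP's default ever changes. Returns 0 on success.
  std::fesetround(FE_TONEAREST);

  // Halfway cases round to the nearest even integer, matching PTX cvt.rni
  // and torch.round: 0.5 -> 0, 1.5 -> 2, 2.5 -> 2.
  std::printf("%.1f %.1f %.1f\n", std::nearbyint(0.5f), std::nearbyint(1.5f),
              std::nearbyint(2.5f));
  return 0;
}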
@@ -31,6 +36,59 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
 #endif
 }
 
+static inline __device__ int32_t float_to_int32_rn(float x) {
+#ifdef USE_ROCM
+  // int32_max is not exactly representable as float.
+  // Therefore, we need to be careful and manually return int32_max on overflow.
+  // For symmetry, we also do the same for int32_min, even though it is exactly
+  // representable as float and the conversion should be exact.
+  static constexpr auto i32_min = std::numeric_limits<int32_t>::min();
+  static constexpr auto i32_min_f = static_cast<float>(i32_min);
+  static constexpr auto i32_max = std::numeric_limits<int32_t>::max();
+  static constexpr auto i32_max_f = static_cast<float>(i32_max);
+
+  // To match the rounding mode of CUDA, we use nearbyint.
+  // It uses the current rounding mode, which is always FE_TONEAREST on HIP.
+  // If that changes in the future, we may need to set the rounding mode
+  // explicitly, either at runtime or compile time.
+  float dst = std::nearbyint(x);
+
+  // saturate on the higher end.
+  if (dst >= i32_max_f) {
+    return i32_max;
+  }
+  // saturate on the lower end.
+  if (dst <= i32_min_f) {
+    return i32_min;
+  }
+
+  return static_cast<int32_t>(dst);
+#else
+  // CUDA path
+  uint32_t dst;
+  asm volatile("cvt.rni.sat.s32.f32 %0, %1;" : "=r"(dst) : "f"(x));
+  return reinterpret_cast<const int32_t&>(dst);
+#endif
+}
+
+static inline __device__ int8_t int32_to_int8(int32_t x) {
+#ifdef USE_ROCM
+  static constexpr auto i8_min =
+      static_cast<int32_t>(std::numeric_limits<int8_t>::min());
+  static constexpr auto i8_max =
+      static_cast<int32_t>(std::numeric_limits<int8_t>::max());
+
+  // saturate
+  int32_t dst = std::clamp(x, i8_min, i8_max);
+  return static_cast<int8_t>(dst);
+#else
+  // CUDA path
+  uint32_t dst;
+  asm volatile("cvt.sat.s8.s32 %0, %1;" : "=r"(dst) : "r"(x));
+  return reinterpret_cast<const int8_t&>(dst);
+#endif
+}
+
 namespace vllm {
 
 template <typename scalar_t, typename scale_type>
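A tiny standalone illustration (not part of the patch) of why the new helpers round to a saturating int32 first and only then add the zero point: out-of-range results are clamped to [-128, 127] instead of wrapping around.

// Host sketch of the int32-then-int8 saturation path used by the kernels.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

static int8_t quant_one(float x, float scale, int32_t azp) {
  // Same shape as int32_to_int8(float_to_int32_rn(x / scale) + azp).
  int32_t rounded = static_cast<int32_t>(std::nearbyint(x / scale));
  return static_cast<int8_t>(std::clamp(rounded + azp, -128, 127));
}

int main() {
  // 50.0 / 0.1 rounds to 500; 500 + 10 = 510 saturates to 127 rather than
  // wrapping to a garbage value.
  std::printf("%d\n", quant_one(50.0f, 0.1f, 10));  // prints 127
  return 0;
}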
@@ -47,6 +105,23 @@ __global__ void static_scaled_int8_quant_kernel(
   }
 }
 
+template <typename scalar_t, typename scale_type, typename azp_type>
+__global__ void static_scaled_int8_azp_quant_kernel(
+    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
+    scale_type const* scale_ptr, azp_type const* azp_ptr,
+    const int hidden_size) {
+  int const tid = threadIdx.x;
+  int const token_idx = blockIdx.x;
+  scale_type const scale = *scale_ptr;
+  azp_type const azp = *azp_ptr;
+
+  for (int i = tid; i < hidden_size; i += blockDim.x) {
+    auto const val = static_cast<float>(input[token_idx * hidden_size + i]);
+    auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp);
+    out[token_idx * hidden_size + i] = quant_val;
+  }
+}
+
 template <typename scalar_t, typename scale_type>
 __global__ void dynamic_scaled_int8_quant_kernel(
     scalar_t const* __restrict__ input, int8_t* __restrict__ out,
@@ -80,14 +155,68 @@ __global__ void dynamic_scaled_int8_quant_kernel(
   }
 }
 
+template <typename scalar_t, typename scale_type, typename azp_type>
+__global__ void dynamic_scaled_int8_azp_quant_kernel(
+    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
+    scale_type* scale, azp_type* azp, const int hidden_size) {
+  int const token_idx = blockIdx.x;
+
+  // Scan for the min and max value for this token
+  float max_val = std::numeric_limits<float>::min();
+  float min_val = std::numeric_limits<float>::max();
+  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+    auto val = static_cast<float>(input[token_idx * hidden_size + i]);
+    max_val = std::max(max_val, val);
+    min_val = std::min(min_val, val);
+  }
+
+  // Reduce the max and min values across the block
+  using BlockReduce = cub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStorage;
+  max_val = BlockReduce(reduceStorage).Reduce(max_val, cub::Max{}, blockDim.x);
+  __syncthreads();  // Make sure min doesn't mess with max shared memory
+  min_val = BlockReduce(reduceStorage).Reduce(min_val, cub::Min{}, blockDim.x);
+
+  __shared__ scale_type scale_sh;
+  __shared__ azp_type azp_sh;
+
+  // Compute the scale and zero point and store them, only on the first thread
+  if (threadIdx.x == 0) {
+    float const scale_val = (max_val - min_val) / 255.0f;
+    // Use rounding to even (same as torch.round)
+    auto const azp_float = std::nearbyint(-128.0f - min_val / scale_val);
+    auto const azp_val = static_cast<azp_type>(azp_float);
+
+    // Store the scale and azp into shared and global
+    scale[token_idx] = scale_sh = scale_val;
+    azp[token_idx] = azp_sh = azp_val;
+  }
+
+  // Wait for the scale and azp to be computed
+  __syncthreads();
+
+  float const scale_val = scale_sh;
+  azp_type const azp_val = azp_sh;
+
+  // Quantize the values
+  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+    auto const val = static_cast<float>(input[token_idx * hidden_size + i]);
+    auto const quant_val =
+        int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val);
+    out[token_idx * hidden_size + i] = quant_val;
+  }
+}
+
 }  // namespace vllm
 
 void static_scaled_int8_quant(torch::Tensor& out,          // [..., hidden_size]
                               torch::Tensor const& input,  // [..., hidden_size]
-                              torch::Tensor const& scale) {
+                              torch::Tensor const& scale,
+                              c10::optional<torch::Tensor> const& azp) {
   TORCH_CHECK(input.is_contiguous());
   TORCH_CHECK(out.is_contiguous());
   TORCH_CHECK(scale.numel() == 1);
+  TORCH_CHECK(!azp || azp->numel() == 1);
 
   int const hidden_size = input.size(-1);
   int const num_tokens = input.numel() / hidden_size;
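A side note on the reduction pattern in the dynamic kernel above (an illustrative sketch, not part of the patch): both BlockReduce::Reduce calls share a single TempStorage, so the __syncthreads() between them is what makes reusing that shared memory safe.

// Minimal CUDA sketch of reusing cub::BlockReduce temp storage for two passes.
#include <cub/cub.cuh>

template <int BLOCK_THREADS>
__global__ void minmax_sketch(const float* in, int n, float* out_max,
                              float* out_min) {
  using BlockReduce = cub::BlockReduce<float, BLOCK_THREADS>;
  __shared__ typename BlockReduce::TempStorage temp;

  float v = (threadIdx.x < n) ? in[threadIdx.x] : in[0];

  float max_val = BlockReduce(temp).Reduce(v, cub::Max{});
  __syncthreads();  // temp storage is reused below; sync before the next pass
  float min_val = BlockReduce(temp).Reduce(v, cub::Min{});

  // The reduction result is only valid in thread 0.
  if (threadIdx.x == 0) {
    *out_max = max_val;
    *out_min = min_val;
  }
}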
@@ -96,19 +225,29 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
-        vllm::static_scaled_int8_quant_kernel<scalar_t, float>
-            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
-                                         out.data_ptr<int8_t>(),
-                                         scale.data_ptr<float>(), hidden_size);
+        if (!azp) {
+          vllm::static_scaled_int8_quant_kernel<scalar_t, float>
+              <<<grid, block, 0, stream>>>(
+                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
+                  scale.data_ptr<float>(), hidden_size);
+        } else {
+          vllm::static_scaled_int8_azp_quant_kernel<scalar_t, float, int32_t>
+              <<<grid, block, 0, stream>>>(
+                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
+                  scale.data_ptr<float>(), azp->data_ptr<int32_t>(),
+                  hidden_size);
+        }
       });
 }
 
 void dynamic_scaled_int8_quant(
     torch::Tensor& out,          // [..., hidden_size]
     torch::Tensor const& input,  // [..., hidden_size]
-    torch::Tensor& scales) {
+    torch::Tensor& scales, c10::optional<torch::Tensor> const& azp) {
   TORCH_CHECK(input.is_contiguous());
   TORCH_CHECK(out.is_contiguous());
+  TORCH_CHECK(scales.is_contiguous());
+  TORCH_CHECK(!azp || azp->is_contiguous());
 
   int const hidden_size = input.size(-1);
   int const num_tokens = input.numel() / hidden_size;
@@ -117,9 +256,17 @@ void dynamic_scaled_int8_quant(
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
-        vllm::dynamic_scaled_int8_quant_kernel<scalar_t, float>
-            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
-                                         out.data_ptr<int8_t>(),
-                                         scales.data_ptr<float>(), hidden_size);
+        if (!azp) {
+          vllm::dynamic_scaled_int8_quant_kernel<scalar_t, float>
+              <<<grid, block, 0, stream>>>(
+                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
+                  scales.data_ptr<float>(), hidden_size);
+        } else {
+          vllm::dynamic_scaled_int8_azp_quant_kernel<scalar_t, float, int32_t>
+              <<<grid, block, 0, stream>>>(
+                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
+                  scales.data_ptr<float>(), azp->data_ptr<int32_t>(),
+                  hidden_size);
+        }
       });
 }
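For reference, a plain host-side sketch (not part of the patch) of the per-token asymmetric quantization the new azp kernels implement: the scale spreads the observed [min, max] range over the 255 steps of int8, the zero point shifts min to -128, and dequantization is approximately (q - azp) * scale.

// Plain C++ reference for one token; mirrors dynamic_scaled_int8_azp_quant_kernel.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

void quantize_token_azp(const std::vector<float>& x, std::vector<int8_t>& q,
                        float& scale, int32_t& azp) {
  float min_val = *std::min_element(x.begin(), x.end());
  float max_val = *std::max_element(x.begin(), x.end());

  // Spread [min_val, max_val] across the 255 steps of int8.
  scale = (max_val - min_val) / 255.0f;
  // Choose the zero point so that min_val maps to -128 (round to even).
  azp = static_cast<int32_t>(std::nearbyint(-128.0f - min_val / scale));

  q.resize(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    int32_t v = static_cast<int32_t>(std::nearbyint(x[i] / scale)) + azp;
    q[i] = static_cast<int8_t>(std::clamp(v, -128, 127));
  }
  // Dequantization recovers x[i] ~= (q[i] - azp) * scale.
}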