Commit 303a886

jiahanc authored and yangw-dev committed
[Hardware][NVIDIA][kernel] Fp4 MOE quant kernel optimization (vllm-project#19500)
Signed-off-by: Yang Wang <[email protected]>
1 parent 6062ff6 · commit 303a886

1 file changed

csrc/quantization/fp4/nvfp4_experts_quant.cu

Lines changed: 226 additions & 48 deletions
@@ -231,7 +231,7 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
 }

 // Use UE4M3 by default.
-template <class Type, bool UE8M0_SF = false>
+template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
 __global__ void
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
 __launch_bounds__(512, 4) cvt_fp16_to_fp4(
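
Note: the added SMALL_NUM_EXPERTS template parameter selects the expert-offset search strategy at compile time, so each kernel instantiation carries no runtime branch for the path it does not use. A minimal standalone C++ sketch of that if-constexpr dispatch mechanism (illustrative only, not code from this commit):

    #include <cstdio>

    // A compile-time bool picks one code path; the other branch is discarded
    // during instantiation, mirroring how the kernel switches between its
    // linear-scan and chunked-register offset searches.
    template <bool SMALL_NUM_EXPERTS>
    const char* expert_search_strategy() {
      if constexpr (SMALL_NUM_EXPERTS) {
        return "linear scan over expert offsets";
      } else {
        return "chunked, vectorized offset search";
      }
    }

    int main() {
      std::puts(expert_search_strategy<true>());   // instantiation for few experts
      std::puts(expert_search_strategy<false>());  // instantiation for many experts
      return 0;
    }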
@@ -240,58 +240,191 @@ cvt_fp16_to_fp4(
 #endif
     int32_t numRows, int32_t numCols, Type const* in, float const* SFScale,
     uint32_t* out, uint32_t* SFout, uint32_t* input_offset_by_experts,
-    uint32_t* output_scale_offset_by_experts, int n_experts) {
+    uint32_t* output_scale_offset_by_experts, int n_experts, bool low_latency) {
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
   using PackedVec = PackedVec<Type>;
   static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
       (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
   static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
                 "Vec size is not matched.");

-  // Input tensor row/col loops.
-  for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
-    for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD;
-         colIdx += blockDim.x) {
-      int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
-      PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
-      // Get the output tensor offset.
-      // Same as inOffset because 8 elements are packed into one uint32_t.
-      int64_t outOffset = inOffset;
-      auto& out_pos = out[outOffset];
-
-      // Find index within the experts.
-      int rowIdx_in_expert = 0;
-      int expert_idx = 0;
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
+
+  // Each global thread processes one element
+  for (int globalIdx = tid; globalIdx < numRows * colsPerRow;
+       globalIdx += gridDim.x * blockDim.x) {
+    // Calculate which row and column this global thread should process
+    int rowIdx = globalIdx / colsPerRow;
+    int colIdx = globalIdx % colsPerRow;
+
+    int64_t inOffset = rowIdx * colsPerRow + colIdx;
+    PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
+    // Get the output tensor offset.
+    // Same as inOffset because 8 elements are packed into one uint32_t.
+    int64_t outOffset = inOffset;
+    auto& out_pos = out[outOffset];
+
+    // Find index within the experts using different strategies based on expert
+    // count
+    int rowIdx_in_expert = 0;
+    int expert_idx = 0;
+
+    if constexpr (SMALL_NUM_EXPERTS) {
       for (int i = 0; i < n_experts; i++) {
-        if (rowIdx >= input_offset_by_experts[i] &&
-            rowIdx < input_offset_by_experts[i + 1]) {
-          rowIdx_in_expert = rowIdx - input_offset_by_experts[i];
+        uint32_t current_offset = __ldca(&input_offset_by_experts[i]);
+        uint32_t next_offset = __ldca(&input_offset_by_experts[i + 1]);
+        if (rowIdx >= current_offset && rowIdx < next_offset) {
+          rowIdx_in_expert = rowIdx - current_offset;
           expert_idx = i;
           break;
         }
       }
+    } else {
+      // Load input offsets into registers first, then do the computation.
+      // Local array size set to 17 because of register limit.
+      uint32_t local_offsets[17];
+      for (int chunk_start = 0; chunk_start < n_experts; chunk_start += 16) {
+        *reinterpret_cast<int4*>(local_offsets) =
+            __ldca(reinterpret_cast<const int4*>(
+                &input_offset_by_experts[chunk_start]));
+        *reinterpret_cast<int4*>(local_offsets + 4) =
+            __ldca(reinterpret_cast<const int4*>(
+                &input_offset_by_experts[chunk_start + 4]));
+        *reinterpret_cast<int4*>(local_offsets + 8) =
+            __ldca(reinterpret_cast<const int4*>(
+                &input_offset_by_experts[chunk_start + 8]));
+        *reinterpret_cast<int4*>(local_offsets + 12) =
+            __ldca(reinterpret_cast<const int4*>(
+                &input_offset_by_experts[chunk_start + 12]));
+        local_offsets[16] = __ldca(&input_offset_by_experts[chunk_start + 16]);
+
+        // Check against the 16 loaded offsets
+#pragma unroll
+        for (int i = 0; i < 16; i++) {
+          if (rowIdx >= local_offsets[i] && rowIdx < local_offsets[i + 1]) {
+            rowIdx_in_expert = rowIdx - local_offsets[i];
+            expert_idx = chunk_start + i;
+            break;
+          }
+        }
+      }
+    }
+
+    // Get the global scaling factor, which will be applied to the SF.
+    // Note SFScale is the same as next GEMM's alpha, which is
+    // (448.f / (Alpha_A / 6.f)).
+    float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
+
+    int factor = CVT_FP4_SF_VEC_SIZE * 4;
+    // The actual output_scales dim is computed from the padded numCols.
+    int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
+    int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
+    uint32_t* SFout_in_expert =
+        SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;
+
+    auto sf_out =
+        cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
+                                           CVT_FP4_NUM_THREADS_PER_SF>(
+            rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
+
+    out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
+  }
+#endif
+}

-      // Get the global scaling factor, which will be applied to the SF.
-      // Note SFScale is the same as next GEMM's alpha, which is
-      // (448.f / (Alpha_A / 6.f)).
-      float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
-
-      int factor = CVT_FP4_SF_VEC_SIZE * 4;
-      // The actual output_scales dim is computed from the padded numCols.
-      int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
-      int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
-      uint32_t* SFout_in_expert =
-          SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;
-
-      auto sf_out =
-          cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
-                                             CVT_FP4_NUM_THREADS_PER_SF>(
-              rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
-
-      out_pos =
-          cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
+// Kernel for LARGE_M_TOPK = true (large m_topk optimized version)
+template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
+__global__ void
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+__launch_bounds__(1024, 4) cvt_fp16_to_fp4(
+#else
+cvt_fp16_to_fp4(
+#endif
+    int32_t numRows, int32_t numCols, Type const* in, float const* SFScale,
+    uint32_t* out, uint32_t* SFout, uint32_t* input_offset_by_experts,
+    uint32_t* output_scale_offset_by_experts, int n_experts) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+  using PackedVec = PackedVec<Type>;
+  static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
+      (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
+  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
+                "Vec size is not matched.");
+  extern __shared__ uint32_t shared_input_offsets[];
+
+  // Load input offsets into shared memory.
+  // If n_experts is larger than 4, use vectorized int4 to save instructions.
+  // If n_experts is smaller than 4, read directly.
+  if constexpr (SMALL_NUM_EXPERTS) {
+    for (int i = threadIdx.x; i < n_experts + 1; i += blockDim.x) {
+      shared_input_offsets[i] = input_offset_by_experts[i];
+    }
+  } else {
+    for (int i = threadIdx.x * 4; i < n_experts; i += blockDim.x * 4) {
+      *reinterpret_cast<int4*>(&shared_input_offsets[i]) =
+          *reinterpret_cast<const int4*>(&input_offset_by_experts[i]);
+    }
+    if (threadIdx.x == 0) {
+      shared_input_offsets[n_experts] = input_offset_by_experts[n_experts];
     }
   }
+
+  __syncthreads();
+
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
+
+  // Each global thread processes one element
+  for (int globalIdx = tid; globalIdx < numRows * colsPerRow;
+       globalIdx += gridDim.x * blockDim.x) {
+    // Calculate which row and column this global thread should process
+    int rowIdx = globalIdx / colsPerRow;
+    int colIdx = globalIdx % colsPerRow;
+
+    int64_t inOffset = rowIdx * colsPerRow + colIdx;
+    PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
+    int64_t outOffset = inOffset;
+    auto& out_pos = out[outOffset];
+
+    // Find expert using binary search for better performance with large m_topk
+    int rowIdx_in_expert = 0;
+    int expert_idx = 0;
+
+    // Binary search through experts using shared memory
+    int left = 0, right = n_experts - 1;
+    while (left <= right) {
+      int mid = (left + right) / 2;
+      // Get offsets: shared_input_offsets[i] corresponds to
+      // input_offset_by_experts[i]
+      uint32_t mid_offset = shared_input_offsets[mid];
+      uint32_t next_offset = shared_input_offsets[mid + 1];
+
+      if (rowIdx >= mid_offset && rowIdx < next_offset) {
+        rowIdx_in_expert = rowIdx - mid_offset;
+        expert_idx = mid;
+        break;
+      } else if (rowIdx < mid_offset) {
+        right = mid - 1;
+      } else {
+        left = mid + 1;
+      }
+    }
+
+    float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
+
+    int factor = CVT_FP4_SF_VEC_SIZE * 4;
+    int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
+    int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
+    uint32_t* SFout_in_expert =
+        SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;
+
+    auto sf_out =
+        cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
+                                           CVT_FP4_NUM_THREADS_PER_SF>(
+            rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
+
+    out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
+  }
 #endif
 }
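
Note: both kernel variants map a row to its expert by searching the prefix-sum array input_offset_by_experts (n_experts + 1 entries, expert e owning rows [offsets[e], offsets[e + 1])). The low-latency kernel scans it linearly, either directly or in register-resident chunks of 16, while the large-m_topk kernel above binary-searches a copy staged in shared memory. A small host-side C++ sketch of that lookup, illustrative only and not part of the diff:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Linear scan over the prefix sums, as in the SMALL_NUM_EXPERTS path.
    int find_expert_linear(const std::vector<uint32_t>& offsets, uint32_t row) {
      for (size_t i = 0; i + 1 < offsets.size(); ++i) {
        if (row >= offsets[i] && row < offsets[i + 1]) return static_cast<int>(i);
      }
      return -1;  // row not owned by any expert
    }

    // Binary search over the same prefix sums, as in the large-m_topk kernel.
    int find_expert_binary(const std::vector<uint32_t>& offsets, uint32_t row) {
      int left = 0, right = static_cast<int>(offsets.size()) - 2;
      while (left <= right) {
        int mid = (left + right) / 2;
        if (row >= offsets[mid] && row < offsets[mid + 1]) return mid;
        if (row < offsets[mid]) right = mid - 1;
        else left = mid + 1;
      }
      return -1;
    }

    int main() {
      // Four experts owning 3, 0, 5, and 2 rows (exclusive prefix sums).
      std::vector<uint32_t> offsets = {0, 3, 3, 8, 10};
      for (uint32_t row = 0; row < 10; ++row) {
        assert(find_expert_linear(offsets, row) == find_expert_binary(offsets, row));
      }
      return 0;
    }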

@@ -309,18 +442,63 @@ void quant_impl(void* output, void* output_scale, void* input,

   // Grid, Block size.
   // Each thread converts 8 values.
-  dim3 block(std::min(int(k / ELTS_PER_THREAD), 512));
+  int const workSizePerRow = k / ELTS_PER_THREAD;
+  int const totalWorkSize = m_topk * workSizePerRow;
+  dim3 block(std::min(workSizePerRow, 512));
   // Get number of blocks per SM (assume we can fully utilize the SM).
   int const numBlocksPerSM = 2048 / block.x;
-  dim3 grid(std::min(int(m_topk), multiProcessorCount * numBlocksPerSM));
-
-  cvt_fp16_to_fp4<T, false><<<grid, block, 0, stream>>>(
-      m_topk, k, reinterpret_cast<T*>(input),
-      reinterpret_cast<float*>(input_global_scale),
-      reinterpret_cast<uint32_t*>(output),
-      reinterpret_cast<uint32_t*>(output_scale),
-      reinterpret_cast<uint32_t*>(input_offset_by_experts),
-      reinterpret_cast<uint32_t*>(output_scale_offset_by_experts), n_experts);
+  dim3 grid(std::min(static_cast<int>((totalWorkSize + block.x - 1) / block.x),
+                     multiProcessorCount * numBlocksPerSM));
+  while (grid.x <= multiProcessorCount && block.x > 64) {
+    grid.x *= 2;
+    block.x = (block.x + 1) / 2;
+  }
+
+  int const blockRepeat =
+      (totalWorkSize + block.x * grid.x - 1) / (block.x * grid.x);
+  if (blockRepeat > 1) {
+    size_t shared_mem_size = (n_experts + 1) * sizeof(uint32_t);
+    if (n_experts >= 4) {
+      cvt_fp16_to_fp4<T, false, false>
+          <<<grid, block, shared_mem_size, stream>>>(
+              m_topk, k, reinterpret_cast<T*>(input),
+              reinterpret_cast<float*>(input_global_scale),
+              reinterpret_cast<uint32_t*>(output),
+              reinterpret_cast<uint32_t*>(output_scale),
+              reinterpret_cast<uint32_t*>(input_offset_by_experts),
+              reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
+              n_experts);
+    } else {
+      cvt_fp16_to_fp4<T, false, true><<<grid, block, shared_mem_size, stream>>>(
+          m_topk, k, reinterpret_cast<T*>(input),
+          reinterpret_cast<float*>(input_global_scale),
+          reinterpret_cast<uint32_t*>(output),
+          reinterpret_cast<uint32_t*>(output_scale),
+          reinterpret_cast<uint32_t*>(input_offset_by_experts),
+          reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
+          n_experts);
+    }
+  } else {
+    if (n_experts >= 16) {
+      cvt_fp16_to_fp4<T, false, false><<<grid, block, 0, stream>>>(
+          m_topk, k, reinterpret_cast<T*>(input),
+          reinterpret_cast<float*>(input_global_scale),
+          reinterpret_cast<uint32_t*>(output),
+          reinterpret_cast<uint32_t*>(output_scale),
+          reinterpret_cast<uint32_t*>(input_offset_by_experts),
+          reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
+          n_experts, /* bool low_latency */ true);
+    } else {
+      cvt_fp16_to_fp4<T, false, true><<<grid, block, 0, stream>>>(
+          m_topk, k, reinterpret_cast<T*>(input),
+          reinterpret_cast<float*>(input_global_scale),
+          reinterpret_cast<uint32_t*>(output),
+          reinterpret_cast<uint32_t*>(output_scale),
+          reinterpret_cast<uint32_t*>(input_offset_by_experts),
+          reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
+          n_experts, /* bool low_latency */ true);
+    }
+  }
 }

 /*Quantization entry for fp4 experts quantization*/
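
Note: the reworked launch logic in quant_impl sizes the grid from the total work (rows times packed columns) instead of from m_topk alone, halves the block while doubling the grid whenever the grid would not cover every SM, and uses blockRepeat, i.e. whether each thread must loop more than once, to pick between the low-latency kernel and the shared-memory kernel. A host-side C++ sketch of just that arithmetic, with an invented example shape, not part of the diff:

    #include <algorithm>
    #include <cstdio>

    struct LaunchConfig {
      int grid, block, block_repeat;
    };

    // Mirrors the heuristic above; elts_per_thread corresponds to
    // ELTS_PER_THREAD (each thread converts 8 values).
    LaunchConfig pick_launch(int m_topk, int k, int multi_processor_count,
                             int elts_per_thread = 8) {
      int const work_per_row = k / elts_per_thread;
      int const total_work = m_topk * work_per_row;
      int block = std::min(work_per_row, 512);
      int const blocks_per_sm = 2048 / block;  // assume 2048 resident threads per SM
      int grid = std::min((total_work + block - 1) / block,
                          multi_processor_count * blocks_per_sm);
      // Underfilled grid: trade block size for grid size until the SMs are covered.
      while (grid <= multi_processor_count && block > 64) {
        grid *= 2;
        block = (block + 1) / 2;
      }
      // block_repeat > 1 means each thread iterates its grid-stride loop more
      // than once; the real code then dispatches the shared-memory kernel.
      int const block_repeat = (total_work + block * grid - 1) / (block * grid);
      return {grid, block, block_repeat};
    }

    int main() {
      // Hypothetical shape: 1024 rows of 7168 columns on a 132-SM GPU.
      LaunchConfig cfg = pick_launch(1024, 7168, 132);
      std::printf("grid=%d block=%d repeat=%d\n", cfg.grid, cfg.block, cfg.block_repeat);
      return 0;
    }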
