@@ -231,7 +231,7 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
}

// Use UE4M3 by default.
- template <class Type, bool UE8M0_SF = false>
+ template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
__global__ void
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__(512, 4) cvt_fp16_to_fp4(
@@ -240,58 +240,191 @@ cvt_fp16_to_fp4(
#endif
    int32_t numRows, int32_t numCols, Type const* in, float const* SFScale,
    uint32_t* out, uint32_t* SFout, uint32_t* input_offset_by_experts,
-     uint32_t* output_scale_offset_by_experts, int n_experts) {
+     uint32_t* output_scale_offset_by_experts, int n_experts, bool low_latency) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  using PackedVec = PackedVec<Type>;
  static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
      (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
                "Vec size is not matched.");

-   // Input tensor row/col loops.
-   for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
-     for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD;
-          colIdx += blockDim.x) {
-       int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
-       PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
-       // Get the output tensor offset.
-       // Same as inOffset because 8 elements are packed into one uint32_t.
-       int64_t outOffset = inOffset;
-       auto& out_pos = out[outOffset];
-
-       // Find index within the experts.
-       int rowIdx_in_expert = 0;
-       int expert_idx = 0;
+   int tid = blockIdx.x * blockDim.x + threadIdx.x;
+   int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
+
+   // Each global thread processes one element
+   for (int globalIdx = tid; globalIdx < numRows * colsPerRow;
+        globalIdx += gridDim.x * blockDim.x) {
+     // Calculate which row and column this global thread should process
+     int rowIdx = globalIdx / colsPerRow;
+     int colIdx = globalIdx % colsPerRow;
+
+     int64_t inOffset = rowIdx * colsPerRow + colIdx;
+     PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
+     // Get the output tensor offset.
+     // Same as inOffset because 8 elements are packed into one uint32_t.
+     int64_t outOffset = inOffset;
+     auto& out_pos = out[outOffset];
+
+     // Find index within the experts using different strategies based on expert
+     // count
+     int rowIdx_in_expert = 0;
+     int expert_idx = 0;
+
+     if constexpr (SMALL_NUM_EXPERTS) {
      for (int i = 0; i < n_experts; i++) {
-         if (rowIdx >= input_offset_by_experts[i] &&
-             rowIdx < input_offset_by_experts[i + 1]) {
-           rowIdx_in_expert = rowIdx - input_offset_by_experts[i];
+         uint32_t current_offset = __ldca(&input_offset_by_experts[i]);
+         uint32_t next_offset = __ldca(&input_offset_by_experts[i + 1]);
+         if (rowIdx >= current_offset && rowIdx < next_offset) {
+           rowIdx_in_expert = rowIdx - current_offset;
          expert_idx = i;
          break;
        }
      }
+     } else {
+       // Load input offsets into registers first, then do the computation.
+       // Local array size set to 17 because of register limit.
+       uint32_t local_offsets[17];
+       for (int chunk_start = 0; chunk_start < n_experts; chunk_start += 16) {
+         *reinterpret_cast<int4*>(local_offsets) =
+             __ldca(reinterpret_cast<const int4*>(
+                 &input_offset_by_experts[chunk_start]));
+         *reinterpret_cast<int4*>(local_offsets + 4) =
+             __ldca(reinterpret_cast<const int4*>(
+                 &input_offset_by_experts[chunk_start + 4]));
+         *reinterpret_cast<int4*>(local_offsets + 8) =
+             __ldca(reinterpret_cast<const int4*>(
+                 &input_offset_by_experts[chunk_start + 8]));
+         *reinterpret_cast<int4*>(local_offsets + 12) =
+             __ldca(reinterpret_cast<const int4*>(
+                 &input_offset_by_experts[chunk_start + 12]));
+         local_offsets[16] = __ldca(&input_offset_by_experts[chunk_start + 16]);
+
+         // Check against the 16 loaded offsets
+ #pragma unroll
+         for (int i = 0; i < 16; i++) {
+           if (rowIdx >= local_offsets[i] && rowIdx < local_offsets[i + 1]) {
+             rowIdx_in_expert = rowIdx - local_offsets[i];
+             expert_idx = chunk_start + i;
+             break;
+           }
+         }
+       }
+     }
+
+     // Get the global scaling factor, which will be applied to the SF.
+     // Note SFScale is the same as next GEMM's alpha, which is
+     // (448.f / (Alpha_A / 6.f)).
+     float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
+
+     int factor = CVT_FP4_SF_VEC_SIZE * 4;
+     // The actual output_scales dim is computed from the padded numCols.
+     int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
+     int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
+     uint32_t* SFout_in_expert =
+         SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;
+
+     auto sf_out =
+         cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
+                                            CVT_FP4_NUM_THREADS_PER_SF>(
+             rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
+
+     out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
+   }
+ #endif
+ }

-       // Get the global scaling factor, which will be applied to the SF.
-       // Note SFScale is the same as next GEMM's alpha, which is
-       // (448.f / (Alpha_A / 6.f)).
-       float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
-
-       int factor = CVT_FP4_SF_VEC_SIZE * 4;
-       // The actual output_scales dim is computed from the padded numCols.
-       int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
-       int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
-       uint32_t* SFout_in_expert =
-           SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;
-
-       auto sf_out =
-           cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
-                                              CVT_FP4_NUM_THREADS_PER_SF>(
-               rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
-
-       out_pos =
-           cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
+ // Kernel for LARGE_M_TOPK = true (large m_topk optimized version)
+ template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
+ __global__ void
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+ __launch_bounds__(1024, 4) cvt_fp16_to_fp4(
+ #else
+ cvt_fp16_to_fp4(
+ #endif
+     int32_t numRows, int32_t numCols, Type const* in, float const* SFScale,
+     uint32_t* out, uint32_t* SFout, uint32_t* input_offset_by_experts,
+     uint32_t* output_scale_offset_by_experts, int n_experts) {
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+   using PackedVec = PackedVec<Type>;
+   static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
+       (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
+   static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
+                 "Vec size is not matched.");
+   extern __shared__ uint32_t shared_input_offsets[];
+
+   // Load input offsets into shared memory.
+   // If n_experts is larger than 4, use vectorized int4 to save instructions.
+   // If n_experts is smaller than 4, read directly.
+   if constexpr (SMALL_NUM_EXPERTS) {
+     for (int i = threadIdx.x; i < n_experts + 1; i += blockDim.x) {
+       shared_input_offsets[i] = input_offset_by_experts[i];
+     }
+   } else {
+     for (int i = threadIdx.x * 4; i < n_experts; i += blockDim.x * 4) {
+       *reinterpret_cast<int4*>(&shared_input_offsets[i]) =
+           *reinterpret_cast<const int4*>(&input_offset_by_experts[i]);
+     }
+     if (threadIdx.x == 0) {
+       shared_input_offsets[n_experts] = input_offset_by_experts[n_experts];
    }
  }
+
+   __syncthreads();
+
+   int tid = blockIdx.x * blockDim.x + threadIdx.x;
+   int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
+
+   // Each global thread processes one element
+   for (int globalIdx = tid; globalIdx < numRows * colsPerRow;
+        globalIdx += gridDim.x * blockDim.x) {
+     // Calculate which row and column this global thread should process
+     int rowIdx = globalIdx / colsPerRow;
+     int colIdx = globalIdx % colsPerRow;
+
+     int64_t inOffset = rowIdx * colsPerRow + colIdx;
+     PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
+     int64_t outOffset = inOffset;
+     auto& out_pos = out[outOffset];
+
+     // Find expert using binary search for better performance with large m_topk
+     int rowIdx_in_expert = 0;
+     int expert_idx = 0;
+
+     // Binary search through experts using shared memory
+     int left = 0, right = n_experts - 1;
+     while (left <= right) {
+       int mid = (left + right) / 2;
+       // Get offsets: shared_input_offsets[i] corresponds to
+       // input_offset_by_experts[i]
+       uint32_t mid_offset = shared_input_offsets[mid];
+       uint32_t next_offset = shared_input_offsets[mid + 1];
+
+       if (rowIdx >= mid_offset && rowIdx < next_offset) {
+         rowIdx_in_expert = rowIdx - mid_offset;
+         expert_idx = mid;
+         break;
+       } else if (rowIdx < mid_offset) {
+         right = mid - 1;
+       } else {
+         left = mid + 1;
+       }
+     }
+
+     float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
+
+     int factor = CVT_FP4_SF_VEC_SIZE * 4;
+     int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
+     int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
+     uint32_t* SFout_in_expert =
+         SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;
+
+     auto sf_out =
+         cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
+                                            CVT_FP4_NUM_THREADS_PER_SF>(
+             rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
+
+     out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
+   }
#endif
}

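Both kernel variants above reduce to the same per-row bookkeeping: map a global row index to an expert through the prefix-sum offsets, then compute where that expert's scale factors start in the padded SF output. The host-side sketch below is not part of the commit; it assumes input_offset_by_experts holds n_experts + 1 monotonically increasing row offsets, that output_scale_offset_by_experts is indexed the same way, and that CVT_FP4_SF_VEC_SIZE is 16 (the NVFP4 block size). The helper name locate_row is made up for illustration and mirrors the binary search used by the __launch_bounds__(1024, 4) kernel.

// Host-side sketch (not from the commit): recover the expert index, the row
// index within that expert, and the expert's scale-factor output base for a
// given global row. Assumes prefix-sum offsets with n_experts + 1 entries
// and CVT_FP4_SF_VEC_SIZE == 16 (NVFP4 block size).
#include <cstdint>
#include <vector>

struct ExpertSlot {
  int expert_idx;      // which expert owns this row
  int row_in_expert;   // row index relative to the expert's first row
  int64_t sfout_base;  // element offset of the expert's SF output block
};

inline ExpertSlot locate_row(const std::vector<uint32_t>& row_offsets,
                             const std::vector<uint32_t>& sf_offsets,
                             int row_idx, int num_cols) {
  const int n_experts = static_cast<int>(row_offsets.size()) - 1;

  // Binary search over the prefix sums, mirroring the shared-memory search
  // in the large-m_topk kernel above.
  int left = 0, right = n_experts - 1, expert_idx = 0;
  while (left <= right) {
    const int mid = (left + right) / 2;
    if (row_idx >= static_cast<int>(row_offsets[mid]) &&
        row_idx < static_cast<int>(row_offsets[mid + 1])) {
      expert_idx = mid;
      break;
    } else if (row_idx < static_cast<int>(row_offsets[mid])) {
      right = mid - 1;
    } else {
      left = mid + 1;
    }
  }

  // One scale factor per 16 elements, with each SF row padded to a multiple
  // of 16 * 4 columns; this matches numCols_padded and numCols_SFout above.
  const int sf_vec_size = 16;  // assumed value of CVT_FP4_SF_VEC_SIZE
  const int factor = sf_vec_size * 4;
  const int num_cols_padded = (num_cols + factor - 1) / factor * factor;
  const int num_cols_sfout = num_cols_padded / sf_vec_size / 4;

  ExpertSlot slot;
  slot.expert_idx = expert_idx;
  slot.row_in_expert = row_idx - static_cast<int>(row_offsets[expert_idx]);
  slot.sfout_base =
      static_cast<int64_t>(sf_offsets[expert_idx]) * num_cols_sfout;
  return slot;
}

// Example: offsets {0, 4, 6, 12} describe three experts with 4, 2, and 6 rows.
// Row 5 lands in expert 1 as its second row, and for num_cols = 7168 each SF
// row holds 7168 / 16 / 4 = 112 packed uint32_t entries.

The grid-stride kernel arrives at the same result with a linear scan over __ldca loads (or register-cached chunks of 16 offsets) instead of the binary search; only the lookup strategy differs.
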
@@ -309,18 +442,63 @@ void quant_impl(void* output, void* output_scale, void* input,

  // Grid, Block size.
  // Each thread converts 8 values.
-   dim3 block(std::min(int(k / ELTS_PER_THREAD), 512));
+   int const workSizePerRow = k / ELTS_PER_THREAD;
+   int const totalWorkSize = m_topk * workSizePerRow;
+   dim3 block(std::min(workSizePerRow, 512));
  // Get number of blocks per SM (assume we can fully utilize the SM).
  int const numBlocksPerSM = 2048 / block.x;
-   dim3 grid(std::min(int(m_topk), multiProcessorCount * numBlocksPerSM));
-
-   cvt_fp16_to_fp4<T, false><<<grid, block, 0, stream>>>(
-       m_topk, k, reinterpret_cast<T*>(input),
-       reinterpret_cast<float*>(input_global_scale),
-       reinterpret_cast<uint32_t*>(output),
-       reinterpret_cast<uint32_t*>(output_scale),
-       reinterpret_cast<uint32_t*>(input_offset_by_experts),
-       reinterpret_cast<uint32_t*>(output_scale_offset_by_experts), n_experts);
+   dim3 grid(std::min(static_cast<int>((totalWorkSize + block.x - 1) / block.x),
+                      multiProcessorCount * numBlocksPerSM));
+   while (grid.x <= multiProcessorCount && block.x > 64) {
+     grid.x *= 2;
+     block.x = (block.x + 1) / 2;
+   }
+
+   int const blockRepeat =
+       (totalWorkSize + block.x * grid.x - 1) / (block.x * grid.x);
+   if (blockRepeat > 1) {
+     size_t shared_mem_size = (n_experts + 1) * sizeof(uint32_t);
+     if (n_experts >= 4) {
+       cvt_fp16_to_fp4<T, false, false>
+           <<<grid, block, shared_mem_size, stream>>>(
+               m_topk, k, reinterpret_cast<T*>(input),
+               reinterpret_cast<float*>(input_global_scale),
+               reinterpret_cast<uint32_t*>(output),
+               reinterpret_cast<uint32_t*>(output_scale),
+               reinterpret_cast<uint32_t*>(input_offset_by_experts),
+               reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
+               n_experts);
+     } else {
+       cvt_fp16_to_fp4<T, false, true><<<grid, block, shared_mem_size, stream>>>(
+           m_topk, k, reinterpret_cast<T*>(input),
+           reinterpret_cast<float*>(input_global_scale),
+           reinterpret_cast<uint32_t*>(output),
+           reinterpret_cast<uint32_t*>(output_scale),
+           reinterpret_cast<uint32_t*>(input_offset_by_experts),
+           reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
+           n_experts);
+     }
+   } else {
+     if (n_experts >= 16) {
+       cvt_fp16_to_fp4<T, false, false><<<grid, block, 0, stream>>>(
+           m_topk, k, reinterpret_cast<T*>(input),
+           reinterpret_cast<float*>(input_global_scale),
+           reinterpret_cast<uint32_t*>(output),
+           reinterpret_cast<uint32_t*>(output_scale),
+           reinterpret_cast<uint32_t*>(input_offset_by_experts),
+           reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
+           n_experts, /* bool low_latency */ true);
+     } else {
+       cvt_fp16_to_fp4<T, false, true><<<grid, block, 0, stream>>>(
+           m_topk, k, reinterpret_cast<T*>(input),
+           reinterpret_cast<float*>(input_global_scale),
+           reinterpret_cast<uint32_t*>(output),
+           reinterpret_cast<uint32_t*>(output_scale),
+           reinterpret_cast<uint32_t*>(input_offset_by_experts),
+           reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
+           n_experts, /* bool low_latency */ true);
+     }
+   }
}

/* Quantization entry for fp4 experts quantization*/
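
For context on how the two kernels get selected, the launch heuristic in quant_impl can be exercised in isolation. The sketch below is hand-written rather than the commit's code verbatim; it assumes ELTS_PER_THREAD is 8, keeps the 2048-resident-threads-per-SM assumption and the 64-thread lower bound on block size from the diff above, and pick_launch is a made-up name for illustration.

// Standalone sketch of the launch sizing used by quant_impl (assumptions
// noted above). blockRepeat > 1 selects the shared-memory kernel; otherwise
// the low-latency kernel with register/__ldca offset lookups is launched.
#include <algorithm>
#include <cstdio>

struct LaunchConfig {
  int grid;
  int block;
  int block_repeat;  // grid-stride iterations each thread performs
};

LaunchConfig pick_launch(int m_topk, int k, int multi_processor_count) {
  const int elts_per_thread = 8;  // assumed ELTS_PER_THREAD
  const int work_per_row = k / elts_per_thread;
  const int total_work = m_topk * work_per_row;

  int block = std::min(work_per_row, 512);
  const int num_blocks_per_sm = 2048 / block;
  int grid = std::min((total_work + block - 1) / block,
                      multi_processor_count * num_blocks_per_sm);

  // If the grid cannot cover every SM, trade threads per block for more
  // blocks until the block would drop below 64 threads.
  while (grid <= multi_processor_count && block > 64) {
    grid *= 2;
    block = (block + 1) / 2;
  }

  const int block_repeat = (total_work + block * grid - 1) / (block * grid);
  return {grid, block, block_repeat};
}

int main() {
  // m_topk = 256 rows of k = 7168 columns on a 132-SM device gives
  // block = 512, grid = 448, block_repeat = 1, i.e. the low-latency path.
  const LaunchConfig cfg = pick_launch(256, 7168, 132);
  std::printf("grid=%d block=%d repeat=%d\n", cfg.grid, cfg.block,
              cfg.block_repeat);
  return 0;
}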