# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
+from math import prod
from typing import Optional

import pytest
import torch

from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
-from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
+from vllm.model_executor.layers.fused_moe.cutlass_moe import (
+    cutlass_moe_fp8, run_cutlass_moe_fp8)
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts,
                                                            fused_topk)
+from vllm.model_executor.layers.fused_moe.utils import (
+    moe_kernel_quantize_input)
from vllm.platforms import current_platform

NUM_EXPERTS = [40, 64]
@@ -236,6 +240,7 @@ def test_cutlass_moe_8_bit_no_graph(
    per_act_token: bool,
    per_out_ch: bool,
    monkeypatch,
+    ep_size: Optional[int] = None,
):
    current_platform.seed_everything(7)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
@@ -254,7 +259,13 @@ def test_cutlass_moe_8_bit_no_graph(
        triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
                                      topk_ids)

-        cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token)
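+        # ep_size mimics expert parallelism on a single GPU: each rank would
+        # own e // ep_size experts, so run_8_bit only sees the local count.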
+        if ep_size is not None:
+            assert e % ep_size == 0, "Cannot distribute experts evenly"
+            number_local_experts = e // ep_size
+        else:
+            number_local_experts = None
+        cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token,
+                                   number_local_experts)

        # Note 5.5 only needed for larger problem sizes, 5 works ok for
        # the rest.
@@ -340,9 +351,62 @@ def test_cutlass_moe_8_bit_EP(
    per_out_channel: bool,
    ep_size: int,
    monkeypatch,
+):
+    test_cutlass_moe_8_bit_no_graph(m, n, k, e, topk, per_act_token,
+                                    per_out_channel, monkeypatch, ep_size)
+
+
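+# Note: m values above VLLM_FUSED_MOE_CHUNK_SIZE (8192, set by the delegated
+# test) should also exercise the chunked execution path.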
+LARGE_MNK_FACTORS = [
+    (1, 8192, 5120, 31),
+    (32768, 1024, 1024, 16),
+    (65536, 512, 1024, 16),
+]
+
+
+@pytest.mark.parametrize("m,n,k,topk", LARGE_MNK_FACTORS)
+@pytest.mark.parametrize("e", [128])
+@pytest.mark.parametrize("per_act_token", [False])
+@pytest.mark.parametrize("per_out_channel", [True])
+@pytest.mark.parametrize("ep_size", [8])
+@pytest.mark.skipif(
+    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
+        current_platform.get_device_capability()),
+    reason="Grouped gemm is not supported on this GPU type.")
+def test_cutlass_moe_8_bit_EP_large(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    per_act_token: bool,
+    per_out_channel: bool,
+    ep_size: int,
+    monkeypatch,
+):
+    test_cutlass_moe_8_bit_no_graph(m, n, k, e, topk, per_act_token,
+                                    per_out_channel, monkeypatch, ep_size)
+
+
+@pytest.mark.parametrize("m,n,k,topk", [(1, 8192, 5120, 31)])
+@pytest.mark.parametrize("e", [128])
+@pytest.mark.parametrize("per_act_token", [False])
+@pytest.mark.parametrize("per_out_channel", [True])
+@pytest.mark.parametrize("ep_size", [8])
+@pytest.mark.skipif(
+    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
+        current_platform.get_device_capability()),
+    reason="Grouped gemm is not supported on this GPU type.")
+def test_run_cutlass_moe_fp8(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    per_act_token: bool,
+    per_out_channel: bool,
+    ep_size: int,
):
    current_platform.seed_everything(7)
-    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
    with set_current_vllm_config(vllm_config):
        mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
                                                  per_out_channel)
@@ -352,20 +416,53 @@ def test_cutlass_moe_8_bit_EP(
                                            score,
                                            topk,
                                            renormalize=False)
-
-        # Note that we are using the dequantized versions of the tensors.
-        # Using a, w1 and w2 directly results in minor output differences.
-        triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
-                                      topk_ids)
-
-        assert e % ep_size == 0, "Cannot distribute experts evenly"
-        cutlass_output = run_8_bit(mt,
-                                   topk_weights,
-                                   topk_ids,
-                                   per_act_token,
-                                   num_local_experts=e // ep_size)
-
-        torch.testing.assert_close(triton_output,
-                                   cutlass_output,
-                                   atol=5e-2,
-                                   rtol=1e-2)
+        # Make sure at least one token is routed to an expert on this shard
+        # and at least one token is routed to an expert NOT on this shard.
+        topk_ids[0][0] = -1
+        topk_ids[0][1] = 1
+
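+        # Workspace sizing: the first grouped gemm writes (m * topk, 2 * n)
+        # (gate and up projections), and workspace13 is reused for the k-wide
+        # second-gemm output, hence max(2 * n, k) columns. workspace2 holds
+        # the activated (m * topk, n) intermediate.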
+        workspace13_shape = (m * topk, max(2 * n, k))
+        workspace2_shape = (m * topk, n)
+        output_shape = (m * topk, k)
+
+        workspace13 = torch.empty(prod(workspace13_shape),
+                                  device="cuda",
+                                  dtype=mt.a.dtype)
+        workspace2 = torch.empty(prod(workspace2_shape),
+                                 device="cuda",
+                                 dtype=mt.a.dtype)
+
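+        # expert_map maps global expert ids to local ids on this rank:
+        # experts [0, num_local_experts) are local, -1 marks experts owned
+        # by other (here simulated) ranks.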
+        num_local_experts = e // ep_size
+        start, end = 0, num_local_experts
+        expert_map = [-1] * e
+        expert_map[start:end] = list(range(num_local_experts))
+        expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")
441
+
442
+ activation = lambda o , i : torch .ops ._C .silu_and_mul (o , i )
443
+ a1q , a1q_scale = moe_kernel_quantize_input (mt .a , mt .a_scale ,
444
+ torch .float8_e4m3fn ,
445
+ per_act_token )
446
+ global_num_experts = - 1 if mt .w1_q is None else mt .w1_q .size (0 )
447
+ func = lambda output : run_cutlass_moe_fp8 (
448
+ output , a1q , mt .w1_q , mt .w2_q , topk_ids , activation ,
449
+ global_num_experts , expert_map , mt .w1_scale , mt .w2_scale ,
450
+ a1q_scale , None , workspace13 , workspace2 , None , mt .a .dtype ,
451
+ per_act_token , per_out_channel , False )
452
+
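+        # Run the kernel twice: once with random garbage in the workspace and
+        # once with an all-zero workspace. Output rows for tokens that this
+        # shard does not process must not pick up stale workspace contents,
+        # so both outputs have to match.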
+        workspace13.random_()
+        output_random_workspace = torch.empty(output_shape,
+                                              device="cuda",
+                                              dtype=mt.a.dtype)
+        func(output_random_workspace)
+
+        workspace13.fill_(0)
+        output_zero_workspace = torch.zeros(output_shape,
+                                            device="cuda",
+                                            dtype=mt.a.dtype)
+        func(output_zero_workspace)
+
+        torch.testing.assert_close(output_random_workspace,
+                                   output_zero_workspace,
+                                   atol=5e-3,
+                                   rtol=1e-3)