@@ -30,6 +30,7 @@
     DeepGemmFusedMoE
 from tensorrt_llm._torch.modules.fused_moe.fused_moe_wide_ep import \
     AlltoallMethodType
+from tensorrt_llm._torch.modules.fused_moe.interface import MoEWeightLoadingMode
 from tensorrt_llm._torch.modules.gated_mlp import GatedMLP
 from tensorrt_llm._utils import mpi_rank
 from tensorrt_llm.mapping import Mapping
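For readers without the TensorRT-LLM tree at hand: the newly imported MoEWeightLoadingMode comes from tensorrt_llm._torch.modules.fused_moe.interface. A minimal, hypothetical stand-in showing the shape this test relies on; only the two members used below are grounded in the diff, everything else about the real enum (values, other members) is an assumption:

from enum import Enum


# Hypothetical stand-in for the real MoEWeightLoadingMode in
# tensorrt_llm._torch.modules.fused_moe.interface; only the two members
# this test exercises are grounded in the diff.
class MoEWeightLoadingMode(Enum):
    # Separate per-expert w1/w2/w3 tensors plus *.weight_scale_inv keys.
    VANILLA = 0
    # w3 (up) and w1 (gate) pre-concatenated into one gate_up_proj tensor
    # per expert, stored transposed, with matching fused block scales.
    FUSED_GATE_UP_PROJ = 1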
@@ -561,20 +562,22 @@ def grouped_gemm(a: torch.Tensor, b: torch.Tensor, a_sf: torch.Tensor,
 
 @skip_non_hopper_unittest
 @pytest.mark.parametrize(
-    "dtype, num_experts, seq_len, hidden_size, RoutingMethodCls",
+    "dtype, num_experts, seq_len, hidden_size, RoutingMethodCls, WeightLoadingMode",
     product(
         [torch.bfloat16],
         [72],
         [128, 256, 384, 512, 1024, 2048, 4096, 8192],
         [2560],
         [DefaultMoeRoutingMethod],
+        [MoEWeightLoadingMode.VANILLA, MoEWeightLoadingMode.FUSED_GATE_UP_PROJ],
     ),
 )
 def test_fused_moe_fp8_blockwise(dtype,
                                  num_experts,
                                  seq_len,
                                  hidden_size,
                                  RoutingMethodCls,
+                                 WeightLoadingMode,
                                  mapping=None):
     SEQ_LEN = seq_len
     HIDDEN_SIZE = hidden_size
@@ -600,6 +603,13 @@ def test_fused_moe_fp8_blockwise(dtype,
                               device="cuda")
 
     weights = {}
+
+    if WeightLoadingMode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
+        weights['gate_up_proj'] = {}
+        weights['down_proj'] = {}
+        weights['gate_up_proj_weight_scale'] = {}
+        weights['down_proj_weight_scale'] = {}
+
     for expert_id in range(NUM_EXPERTS):
         w1_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE),
                                 dtype=dtype,
@@ -626,13 +636,26 @@ def test_fused_moe_fp8_blockwise(dtype,
         weights[f"{expert_id}.w1.weight"] = w1_weight_fp8
         weights[f"{expert_id}.w2.weight"] = w2_weight_fp8
         weights[f"{expert_id}.w3.weight"] = w3_weight_fp8
-        weights[f"{expert_id}.w1.weight_scale_inv"] = w1_weight_scale
-        weights[f"{expert_id}.w2.weight_scale_inv"] = w2_weight_scale
-        weights[f"{expert_id}.w3.weight_scale_inv"] = w3_weight_scale
         weights[f"{expert_id}.w1.weight_scale"] = w1_weight_scale
         weights[f"{expert_id}.w2.weight_scale"] = w2_weight_scale
         weights[f"{expert_id}.w3.weight_scale"] = w3_weight_scale
 
+        if WeightLoadingMode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
+            weights['gate_up_proj'][expert_id] = torch.cat(
+                [w3_weight_fp8, w1_weight_fp8],
+                dim=-2).transpose(0, 1).contiguous()
+            weights['down_proj'][expert_id] = w2_weight_fp8.transpose(
+                0, 1).contiguous()
+            weights['gate_up_proj_weight_scale'][expert_id] = torch.cat(
+                [w3_weight_scale, w1_weight_scale],
+                dim=-2).transpose(0, 1).contiguous()
+            weights['down_proj_weight_scale'][
+                expert_id] = w2_weight_scale.transpose(0, 1).contiguous()
+        elif WeightLoadingMode == MoEWeightLoadingMode.VANILLA:
+            weights[f"{expert_id}.w1.weight_scale_inv"] = w1_weight_scale
+            weights[f"{expert_id}.w2.weight_scale_inv"] = w2_weight_scale
+            weights[f"{expert_id}.w3.weight_scale_inv"] = w3_weight_scale
+
     quant_config = QuantConfig(quant_algo=QuantAlgo.FP8_BLOCK_SCALES)
 
     fused_moe = CuteDslFusedMoE(
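The core of the change is the FUSED_GATE_UP_PROJ packing in the hunk above: per expert, w3 (up) and w1 (gate), each of shape [INTERMEDIATE_SIZE, HIDDEN_SIZE], are concatenated along dim -2 and transposed, giving one [HIDDEN_SIZE, 2 * INTERMEDIATE_SIZE] fused tensor; the down projection and the block scales get the same transpose treatment. A self-contained sketch of the shape bookkeeping; the sizes are illustrative, only the cat/transpose order comes from the diff:

import torch

# Illustrative sizes; the test derives its own HIDDEN_SIZE and
# INTERMEDIATE_SIZE elsewhere.
HIDDEN_SIZE, INTERMEDIATE_SIZE = 64, 128

w1 = torch.randn(INTERMEDIATE_SIZE, HIDDEN_SIZE)  # gate projection
w3 = torch.randn(INTERMEDIATE_SIZE, HIDDEN_SIZE)  # up projection

# Same packing as the diff: [w3, w1] stacked along dim -2, then transposed.
gate_up = torch.cat([w3, w1], dim=-2).transpose(0, 1).contiguous()
assert gate_up.shape == (HIDDEN_SIZE, 2 * INTERMEDIATE_SIZE)

# Un-packing inverts the layout: after transposing back, w3 occupies the
# first INTERMEDIATE_SIZE rows and w1 the rest.
w3_back, w1_back = gate_up.transpose(0, 1).split(INTERMEDIATE_SIZE, dim=0)
assert torch.equal(w3_back, w3) and torch.equal(w1_back, w1)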
@@ -643,6 +666,7 @@ def test_fused_moe_fp8_blockwise(dtype,
         dtype=dtype,
         reduce_results=True,
         model_config=ModelConfig(quant_config=quant_config, mapping=mapping),
+        weight_loading_mode=WeightLoadingMode,
     )
     fused_moe.cuda()
     fused_moe.load_weights([weights])
@@ -655,6 +679,7 @@ def test_fused_moe_fp8_blockwise(dtype,
         dtype=dtype,
         reduce_results=True,
         model_config=ModelConfig(quant_config=quant_config, mapping=mapping),
+        weight_loading_mode=WeightLoadingMode,
     )
     fused_moe_origin.cuda()
     fused_moe_origin.load_weights([weights])
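With the new parametrization axis, every (dtype, num_experts, seq_len, hidden_size, routing) combination now runs once per loading mode, and both module instances load the same weights dict. One way to run only the fused-layout cases locally; the test file path is an assumption based on typical TensorRT-LLM layout, so adjust it to your checkout:

import pytest

# -k matches the enum value embedded in the generated test ids.
pytest.main([
    "tests/unittest/_torch/modules/test_fused_moe.py",
    "-k", "test_fused_moe_fp8_blockwise and FUSED_GATE_UP_PROJ",
    "-v",
])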