
Commit 6da95f2

[None][feat] Add support for fused gate_up_proj scales for FP8 blockwise (#6496)
Signed-off-by: Aurelien Chartier <[email protected]>
1 parent 46df871 commit 6da95f2
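
In broad terms (key names taken from the diff below; the tensors and dict structure here are placeholders for illustration only): in VANILLA mode the loader reads one block-scale tensor per projection and per expert, keyed as "{expert_id}.w{1,2,3}.weight_scale_inv", whereas in the new FUSED_GATE_UP_PROJ mode the gate (w3) and up (w1) projection scales arrive already concatenated in a single per-expert 'gate_up_proj_weight_scale' tensor, with 'down_proj_weight_scale' covering w2. A minimal sketch of the two layouts for one expert:

# Illustrative sketch only -- placeholder tensors, not real checkpoint contents;
# key names mirror the two code paths in the diff below.
import torch

n_blk, k_blk = 12, 20  # hypothetical block-scale grid for one projection

# VANILLA: separate per-projection scale tensors, expert id embedded in the key.
vanilla_weights = {
    "0.w1.weight_scale_inv": torch.rand(n_blk, k_blk),  # up projection
    "0.w3.weight_scale_inv": torch.rand(n_blk, k_blk),  # gate projection
    "0.w2.weight_scale_inv": torch.rand(k_blk, n_blk),  # down projection
}

# FUSED_GATE_UP_PROJ: gate and up scales pre-fused into one tensor per expert,
# stored transposed and indexed by expert id.
fused_weights = {
    "gate_up_proj_weight_scale": {
        0: torch.rand(k_blk, 2 * n_blk)  # transpose of [w3_scale; w1_scale]
    },
    "down_proj_weight_scale": {
        0: torch.rand(n_blk, k_blk)  # transpose of w2_scale
    },
}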

File tree

2 files changed, +67 -29 lines changed

tensorrt_llm/_torch/modules/fused_moe/quantization.py

Lines changed: 38 additions & 25 deletions
@@ -528,31 +528,44 @@ def load_expert_all_weight_scale_fp8_block_scale(
             load_expert_ids: List[int], dst_w3_w1_weight_scale: torch.Tensor,
             dst_w2_weight_scale: torch.Tensor, device):
         for local_slot_id, expert_id in enumerate(load_expert_ids):
-            w3_scale = load_weight_shard(
-                weights[f"{expert_id}.w3.weight_scale_inv"],
-                module.tp_size,
-                module.tp_rank,
-                TensorParallelMode.COLUMN,
-                device=device)
-            dst_w3_w1_weight_scale[local_slot_id][:dst_w3_w1_weight_scale.
-                                                  shape[-2] //
-                                                  2].copy_(w3_scale)
-            w1_scale = load_weight_shard(
-                weights[f"{expert_id}.w1.weight_scale_inv"],
-                module.tp_size,
-                module.tp_rank,
-                TensorParallelMode.COLUMN,
-                device=device)
-            dst_w3_w1_weight_scale[local_slot_id][dst_w3_w1_weight_scale.
-                                                  shape[-2] //
-                                                  2:].copy_(w1_scale)
-            w2_scale = load_weight_shard(
-                weights[f"{expert_id}.w2.weight_scale_inv"],
-                module.tp_size,
-                module.tp_rank,
-                TensorParallelMode.ROW,
-                device=device)
-            dst_w2_weight_scale[local_slot_id].copy_(w2_scale)
+            if module.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
+                w3_scale = weights['gate_up_proj_weight_scale'][
+                    expert_id].transpose(0, 1).contiguous()
+                w1_scale = None
+                w2_scale = weights['down_proj_weight_scale'][
+                    expert_id].transpose(0, 1).contiguous()
+            elif module.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
+                w3_scale = weights[f"{expert_id}.w3.weight_scale_inv"]
+                w1_scale = weights[f"{expert_id}.w1.weight_scale_inv"]
+                w2_scale = weights[f"{expert_id}.w2.weight_scale_inv"]
+            else:
+                raise NotImplementedError(
+                    f"Unknown weight loading mode in MoE: {module.weight_loading_mode}"
+                )
+
+            w3_w1_scale_shard = load_weight_shard(w3_scale,
+                                                  module.tp_size,
+                                                  module.tp_rank,
+                                                  TensorParallelMode.COLUMN,
+                                                  device=device)
+
+            if w1_scale is not None:
+                w1_scale_shard = load_weight_shard(w1_scale,
+                                                   module.tp_size,
+                                                   module.tp_rank,
+                                                   TensorParallelMode.COLUMN,
+                                                   device=device)
+                w3_w1_scale_shard = torch.cat(
+                    [w3_w1_scale_shard, w1_scale_shard], dim=-2)
+
+            dst_w3_w1_weight_scale[local_slot_id].copy_(w3_w1_scale_shard)
+
+            w2_scale_shard = load_weight_shard(w2_scale,
+                                               module.tp_size,
+                                               module.tp_rank,
+                                               TensorParallelMode.ROW,
+                                               device=device)
+            dst_w2_weight_scale[local_slot_id].copy_(w2_scale_shard)
 
     def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
         self.load_expert_all_weight_scale_fp8_block_scale(
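
As a side note on the new loader path above, a self-contained round-trip of the fused scale layout (the block-scale grid here is made up; in the actual checkpoint each scale covers one FP8 weight block):

# Illustrative round-trip of the fused gate_up_proj scale layout (not part of
# the commit). The checkpoint stores transpose(cat([w3_scale, w1_scale], -2));
# the loader only needs one transpose back, which is why w1_scale is None and
# no second torch.cat happens on the FUSED_GATE_UP_PROJ path.
import torch

w3_scale = torch.rand(12, 20)  # gate projection block scales (hypothetical grid)
w1_scale = torch.rand(12, 20)  # up projection block scales

# Producer side (as the updated unit test builds the checkpoint dict):
fused = torch.cat([w3_scale, w1_scale], dim=-2).transpose(0, 1).contiguous()

# Loader side (as load_expert_all_weight_scale_fp8_block_scale consumes it):
recovered = fused.transpose(0, 1).contiguous()

assert torch.equal(recovered, torch.cat([w3_scale, w1_scale], dim=-2))
print(recovered.shape)  # torch.Size([24, 20]) -- gate block stacked above up block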

tests/unittest/_torch/modules/test_fused_moe.py

Lines changed: 29 additions & 4 deletions
@@ -30,6 +30,7 @@
     DeepGemmFusedMoE
 from tensorrt_llm._torch.modules.fused_moe.fused_moe_wide_ep import \
     AlltoallMethodType
+from tensorrt_llm._torch.modules.fused_moe.interface import MoEWeightLoadingMode
 from tensorrt_llm._torch.modules.gated_mlp import GatedMLP
 from tensorrt_llm._utils import mpi_rank
 from tensorrt_llm.mapping import Mapping
@@ -561,20 +562,22 @@ def grouped_gemm(a: torch.Tensor, b: torch.Tensor, a_sf: torch.Tensor,
 
 @skip_non_hopper_unittest
 @pytest.mark.parametrize(
-    "dtype, num_experts, seq_len, hidden_size, RoutingMethodCls",
+    "dtype, num_experts, seq_len, hidden_size, RoutingMethodCls, WeightLoadingMode",
     product(
         [torch.bfloat16],
         [72],
         [128, 256, 384, 512, 1024, 2048, 4096, 8192],
         [2560],
         [DefaultMoeRoutingMethod],
+        [MoEWeightLoadingMode.VANILLA, MoEWeightLoadingMode.FUSED_GATE_UP_PROJ],
     ),
 )
 def test_fused_moe_fp8_blockwise(dtype,
                                  num_experts,
                                  seq_len,
                                  hidden_size,
                                  RoutingMethodCls,
+                                 WeightLoadingMode,
                                  mapping=None):
     SEQ_LEN = seq_len
     HIDDEN_SIZE = hidden_size
@@ -600,6 +603,13 @@ def test_fused_moe_fp8_blockwise(dtype,
                                 device="cuda")
 
     weights = {}
+
+    if WeightLoadingMode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
+        weights['gate_up_proj'] = {}
+        weights['down_proj'] = {}
+        weights['gate_up_proj_weight_scale'] = {}
+        weights['down_proj_weight_scale'] = {}
+
     for expert_id in range(NUM_EXPERTS):
         w1_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE),
                                 dtype=dtype,
@@ -626,13 +636,26 @@ def test_fused_moe_fp8_blockwise(dtype,
         weights[f"{expert_id}.w1.weight"] = w1_weight_fp8
         weights[f"{expert_id}.w2.weight"] = w2_weight_fp8
         weights[f"{expert_id}.w3.weight"] = w3_weight_fp8
-        weights[f"{expert_id}.w1.weight_scale_inv"] = w1_weight_scale
-        weights[f"{expert_id}.w2.weight_scale_inv"] = w2_weight_scale
-        weights[f"{expert_id}.w3.weight_scale_inv"] = w3_weight_scale
         weights[f"{expert_id}.w1.weight_scale"] = w1_weight_scale
         weights[f"{expert_id}.w2.weight_scale"] = w2_weight_scale
         weights[f"{expert_id}.w3.weight_scale"] = w3_weight_scale
 
+        if WeightLoadingMode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
+            weights['gate_up_proj'][expert_id] = torch.cat(
+                [w3_weight_fp8, w1_weight_fp8],
+                dim=-2).transpose(0, 1).contiguous()
+            weights['down_proj'][expert_id] = w2_weight_fp8.transpose(
+                0, 1).contiguous()
+            weights['gate_up_proj_weight_scale'][expert_id] = torch.cat(
+                [w3_weight_scale, w1_weight_scale],
+                dim=-2).transpose(0, 1).contiguous()
+            weights['down_proj_weight_scale'][
+                expert_id] = w2_weight_scale.transpose(0, 1).contiguous()
+        elif WeightLoadingMode == MoEWeightLoadingMode.VANILLA:
+            weights[f"{expert_id}.w1.weight_scale_inv"] = w1_weight_scale
+            weights[f"{expert_id}.w2.weight_scale_inv"] = w2_weight_scale
+            weights[f"{expert_id}.w3.weight_scale_inv"] = w3_weight_scale
+
     quant_config = QuantConfig(quant_algo=QuantAlgo.FP8_BLOCK_SCALES)
 
     fused_moe = CuteDslFusedMoE(
@@ -643,6 +666,7 @@ def test_fused_moe_fp8_blockwise(dtype,
         dtype=dtype,
         reduce_results=True,
         model_config=ModelConfig(quant_config=quant_config, mapping=mapping),
+        weight_loading_mode=WeightLoadingMode,
     )
     fused_moe.cuda()
     fused_moe.load_weights([weights])
@@ -655,6 +679,7 @@ def test_fused_moe_fp8_blockwise(dtype,
         dtype=dtype,
         reduce_results=True,
        model_config=ModelConfig(quant_config=quant_config, mapping=mapping),
+        weight_loading_mode=WeightLoadingMode,
     )
     fused_moe_origin.cuda()
     fused_moe_origin.load_weights([weights])
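
For a sense of the tensor shapes the new FUSED_GATE_UP_PROJ test path produces, a standalone sketch: HIDDEN_SIZE=2560 comes from the parametrization above, while INTERMEDIATE_SIZE=1536 and the 128-wide block granularity are assumptions made only for this example (the real test also casts the weights to FP8; plain bf16 keeps the sketch dependency-light):

# Shape check for one expert's fused tensors (illustrative values only).
import torch

HIDDEN_SIZE = 2560        # from the test parametrization
INTERMEDIATE_SIZE = 1536  # assumed for this sketch
BLOCK = 128               # assumed FP8 blockwise scale granularity

w1 = torch.randn(INTERMEDIATE_SIZE, HIDDEN_SIZE, dtype=torch.bfloat16)  # up
w3 = torch.randn(INTERMEDIATE_SIZE, HIDDEN_SIZE, dtype=torch.bfloat16)  # gate
w1_scale = torch.rand(INTERMEDIATE_SIZE // BLOCK, HIDDEN_SIZE // BLOCK)
w3_scale = torch.rand(INTERMEDIATE_SIZE // BLOCK, HIDDEN_SIZE // BLOCK)

# Fuse gate and up halves the same way the test builds its checkpoint dict.
gate_up = torch.cat([w3, w1], dim=-2).transpose(0, 1).contiguous()
gate_up_scale = torch.cat([w3_scale, w1_scale],
                          dim=-2).transpose(0, 1).contiguous()

print(gate_up.shape)        # torch.Size([2560, 3072])
print(gate_up_scale.shape)  # torch.Size([20, 24])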
