Skip to content

Support Llama 4 for fused_marlin_moe #20457

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
quant_type_id: int,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
global_scale1: Optional[torch.Tensor] = None,
Expand Down Expand Up @@ -149,7 +150,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
topk_weights,
moe_block_size=block_size_m,
top_k=topk,
mul_topk_weights=False,
mul_topk_weights=apply_router_weight_on_input,
is_ep=expert_map is not None,
b_q_type=quant_type,
size_m=M,
Expand Down Expand Up @@ -182,7 +183,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
topk_weights,
moe_block_size=block_size_m,
top_k=1,
mul_topk_weights=True,
mul_topk_weights=not apply_router_weight_on_input,
is_ep=expert_map is not None,
b_q_type=quant_type,
size_m=M * topk,
Expand All @@ -208,6 +209,7 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
quant_type_id: int,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
global_scale1: Optional[torch.Tensor] = None,
global_scale2: Optional[torch.Tensor] = None,
Expand Down
6 changes: 1 addition & 5 deletions vllm/model_executor/layers/quantization/awq_marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,11 +493,6 @@ def apply(

assert activation == "silu", "Only SiLU activation is supported."

if apply_router_weight_on_input:
raise NotImplementedError(
"Apply router weight on input is not supported for"
"fused Marlin MoE method.")

topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
Expand All @@ -520,6 +515,7 @@ def apply(
topk_weights,
topk_ids,
quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_zeros=layer.w13_qzeros,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ def apply(
global_scale1=layer.w13_weight_scale_2,
global_scale2=layer.w2_weight_scale_2,
quant_type_id=scalar_types.float4_e2m1f.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map)

Expand Down Expand Up @@ -632,8 +633,6 @@ def apply(
if self.use_marlin:
assert activation == "silu", (
f"{activation} not supported for Marlin MoE.")
assert not apply_router_weight_on_input, (
"Apply router weight on input not supported for Marlin MoE.")
return torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_weight,
Expand All @@ -644,6 +643,7 @@ def apply(
topk_weights,
topk_ids,
quant_type_id=scalar_types.float8_e4m3fn.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map)

Expand Down Expand Up @@ -1312,8 +1312,6 @@ def apply(

assert activation == "silu", (
f"{activation} not supported for Marlin MoE.")
assert not apply_router_weight_on_input, (
"Apply router weight on input not supported for Marlin MoE.")

topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
Expand All @@ -1337,6 +1335,7 @@ def apply(
topk_weights,
topk_ids,
quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
g_idx1=layer.w13_weight_g_idx,
Expand Down
3 changes: 1 addition & 2 deletions vllm/model_executor/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,8 +892,6 @@ def apply(
elif self.use_marlin:
assert activation == "silu", (
f"{activation} not supported for Marlin MoE.")
assert not apply_router_weight_on_input, (
"Apply router weight on input not supported for Marlin MoE.")
return torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_weight,
Expand All @@ -904,6 +902,7 @@ def apply(
topk_weights,
topk_ids,
quant_type_id=scalar_types.float8_e4m3fn.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map)
else:
Expand Down
5 changes: 1 addition & 4 deletions vllm/model_executor/layers/quantization/gptq_marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,10 +645,6 @@ def apply(
"EPLB not supported for `GPTQMarlinMoEMethod` yet.")

assert activation == "silu", "Only SiLU activation is supported."
if apply_router_weight_on_input:
raise NotImplementedError(
"Apply router weight on input is not supported for "
"fused Marlin MoE method.")

topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
Expand All @@ -672,6 +668,7 @@ def apply(
topk_weights,
topk_ids,
quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
g_idx1=layer.w13_g_idx,
Expand Down
1 change: 1 addition & 0 deletions vllm/model_executor/layers/quantization/modelopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,7 @@ def apply(
global_scale1=layer.w13_weight_scale_2,
global_scale2=layer.w2_weight_scale_2,
quant_type_id=scalar_types.float4_e2m1f.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map)

Expand Down