
Commit 42d2563

cyrusd98 authored and facebook-github-bot committed
clean up activation hardcoding to silu

Summary: GPT OSS has a different activation function from Llama4. This diff removes the hardcoded activation function to allow general fused MoE support.

Test Plan: CI passes

Differential Revision: D86720793
1 parent 40e6f1f commit 42d2563
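
For context, the test helper now resolves the activation by name instead of constructing SiluAndMul directly. Below is a minimal sketch of that lookup pattern, assuming "silu" is a registered name for get_act_and_mul_fn and that the returned module halves the last dimension the way SiluAndMul does; the tensor shape is made up for illustration.

import torch

from vllm.model_executor.layers.activation import get_act_and_mul_fn

# Hypothetical gate/up projection output: 4 tokens, 2 * 128 features.
# Gated "act-and-mul" modules such as SiluAndMul split the last dim in half,
# apply the activation to the first half, and multiply by the second half.
tmp1 = torch.randn(4, 256)

act_fn = get_act_and_mul_fn("silu")  # assumed registered name; returns a callable module
tmp2 = act_fn(tmp1)                  # expected shape: (4, 128)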

File tree

1 file changed: +12 -4 lines


tests/kernels/utils.py

Lines changed: 12 additions & 4 deletions
@@ -15,7 +15,7 @@
 
 from tests.kernels.quant_utils import native_w8a8_block_matmul
 from vllm.attention import AttentionType
-from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.activation import get_act_and_mul_fn
 from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
 from vllm.utils import (
     STR_BACKEND_ENV_VAR,
@@ -903,6 +903,7 @@ def torch_experts(
     w2: torch.Tensor,
     topk_weight: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str,
     global_num_experts: int = -1,
     w1_bias: torch.Tensor | None = None,
     w2_bias: torch.Tensor | None = None,
@@ -947,14 +948,19 @@ def torch_experts(
 
     f32 = torch.float32
 
+    def apply_moe_activation(act_str, act_input):
+        act_fn = get_act_and_mul_fn(act_str)
+        return act_fn(act_input)
+
     for i in range(num_experts):
         mask = topk_ids == i
         if mask.sum():
             if quant_dtype is None:
                 tmp1 = a[mask] @ w1[i].transpose(0, 1)
                 if w1_bias is not None:
                     tmp1 = tmp1 + w1_bias[i].view(1, -1).to(tmp1.dtype)
-                tmp2 = SiluAndMul()(tmp1)
+
+                tmp2 = apply_moe_activation(activation, tmp1)
                 out[mask] = tmp2 @ w2[i].transpose(0, 1)
                 if w2_bias is not None:
                     out[mask] = out[mask] + w2_bias[i].view(1, -1).to(tmp1.dtype)
@@ -970,7 +976,9 @@ def torch_experts(
                 )
                 if w1_bias is not None:
                     tmp1 = tmp1 + w1_bias[i].view(1, -1).to(tmp1.dtype)
-                tmp2 = SiluAndMul()(tmp1)
+
+                tmp2 = apply_moe_activation(activation, tmp1)
+
                 tmp2, b_scale = moe_kernel_quantize_input(
                     tmp2, a2_scale, quant_dtype, per_act_token_quant, block_shape
                 )
@@ -994,7 +1002,7 @@ def torch_experts(
                 if w1_bias is not None:
                     tmp1 = tmp1 + w1_bias[i].view(1, -1).to(out.dtype)
 
-                tmp2 = SiluAndMul()(tmp1).to(out.dtype)
+                tmp2 = apply_moe_activation(activation, tmp1)
 
                 tmp2, b_scale = moe_kernel_quantize_input(
                     tmp2, a2_scale, quant_dtype, per_act_token_quant, block_shape
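
Because activation is now a required positional parameter, call sites of torch_experts in the kernel tests have to pass a name explicitly. A hedged sketch of an updated call follows; the leading arguments (the token activations and first expert weights, not visible in these hunks) and the "silu" name reproducing the old SiluAndMul behavior are assumptions.

# Sketch only: the parameters before w2 are inferred from the function body
# (a[mask], w1[i]); "silu" is assumed to match the previously hardcoded path.
out = torch_experts(
    a,            # (num_tokens, hidden_dim) input activations
    w1,           # per-expert gate/up projection weights
    w2,           # per-expert down projection weights
    topk_weight,
    topk_ids,
    "silu",       # activation name, formerly hardcoded via SiluAndMul()
)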
