@@ -15,7 +15,7 @@
 
 from tests.kernels.quant_utils import native_w8a8_block_matmul
 from vllm.attention import AttentionType
-from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.activation import get_act_and_mul_fn
 from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
 from vllm.utils import (
     STR_BACKEND_ENV_VAR,
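
Context for the import swap above (not part of the diff): `get_act_and_mul_fn(name)` resolves a gated activation-and-multiply module by name, and `"silu"` resolves to the same `SiluAndMul` op that was previously hard-coded, so behavior is preserved for the existing path. A minimal pure-torch sketch of what that op computes, assuming the usual fused gate/up projection layout:

```python
import torch
import torch.nn.functional as F

def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # The fused first projection emits [gate | up] concatenated on the
    # last dim; the gated activation computes SiLU(gate) * up, halving
    # the last dim. This mirrors vllm's SiluAndMul.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

x = torch.randn(4, 2 * 128)  # (tokens, 2 * intermediate_size)
assert silu_and_mul_reference(x).shape == (4, 128)
```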
@@ -903,6 +903,7 @@ def torch_experts(
     w2: torch.Tensor,
     topk_weight: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str,
     global_num_experts: int = -1,
     w1_bias: torch.Tensor | None = None,
     w2_bias: torch.Tensor | None = None,
@@ -947,14 +948,19 @@ def torch_experts(
 
     f32 = torch.float32
 
+    def apply_moe_activation(act_str, act_input):
+        act_fn = get_act_and_mul_fn(act_str)
+        return act_fn(act_input)
+
     for i in range(num_experts):
         mask = topk_ids == i
         if mask.sum():
             if quant_dtype is None:
                 tmp1 = a[mask] @ w1[i].transpose(0, 1)
                 if w1_bias is not None:
                     tmp1 = tmp1 + w1_bias[i].view(1, -1).to(tmp1.dtype)
-                tmp2 = SiluAndMul()(tmp1)
+
+                tmp2 = apply_moe_activation(activation, tmp1)
                 out[mask] = tmp2 @ w2[i].transpose(0, 1)
                 if w2_bias is not None:
                     out[mask] = out[mask] + w2_bias[i].view(1, -1).to(tmp1.dtype)
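
The unquantized branch above gathers each expert's routed tokens with a boolean mask, runs the expert MLP on just those rows, and scatters the results back. A simplified toy of that routing pattern (top-1 routing and made-up shapes; the real helper also expands tokens for top-k and applies `topk_weight` downstream):

```python
import torch

num_experts, M, K, N = 4, 8, 16, 32
a = torch.randn(M, K)
w1 = torch.randn(num_experts, 2 * N, K)  # fused gate/up weight per expert
topk_ids = torch.randint(num_experts, (M,))
out = torch.zeros(M, 2 * N)

for i in range(num_experts):
    mask = topk_ids == i  # tokens routed to expert i
    if mask.sum():        # skip experts that received no tokens
        out[mask] = a[mask] @ w1[i].transpose(0, 1)
```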
@@ -970,7 +976,9 @@ def torch_experts(
                 )
                 if w1_bias is not None:
                     tmp1 = tmp1 + w1_bias[i].view(1, -1).to(tmp1.dtype)
-                tmp2 = SiluAndMul()(tmp1)
+
+                tmp2 = apply_moe_activation(activation, tmp1)
+
                 tmp2, b_scale = moe_kernel_quantize_input(
                     tmp2, a2_scale, quant_dtype, per_act_token_quant, block_shape
                 )
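
In the quantized branch, the activations are re-quantized between the two GEMMs via `moe_kernel_quantize_input`. As a conceptual illustration only (not vllm's implementation), per-token quantization scales each row by its own dynamic range:

```python
import torch

def quantize_per_token(x: torch.Tensor, qdtype=torch.float8_e4m3fn):
    # One scale per token (row), so an outlier row does not clip the rest.
    finfo = torch.finfo(qdtype)
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / finfo.max
    q = (x / scale).clamp(finfo.min, finfo.max).to(qdtype)
    return q, scale

x = torch.randn(4, 128)
q, s = quantize_per_token(x)
x_hat = q.to(torch.float32) * s  # approximate dequantized reconstruction
```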
@@ -994,7 +1002,7 @@ def torch_experts(
                 if w1_bias is not None:
                     tmp1 = tmp1 + w1_bias[i].view(1, -1).to(out.dtype)
 
-                tmp2 = SiluAndMul()(tmp1).to(out.dtype)
+                tmp2 = apply_moe_activation(activation, tmp1)
 
                 tmp2, b_scale = moe_kernel_quantize_input(
                     tmp2, a2_scale, quant_dtype, per_act_token_quant, block_shape
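
Because `activation` is added without a default, callers of `torch_experts` must now pass an activation name. A small sketch of the name-based dispatch the new helper relies on, assuming vllm is importable and that the registry behind `get_act_and_mul_fn` includes the `"silu"` and `"gelu"` keys:

```python
import torch
from vllm.model_executor.layers.activation import get_act_and_mul_fn

def apply_moe_activation(act_str: str, act_input: torch.Tensor) -> torch.Tensor:
    # Same shape contract as the gated ops: (M, 2*N) -> (M, N).
    return get_act_and_mul_fn(act_str)(act_input)

x = torch.randn(4, 256)
y_silu = apply_moe_activation("silu", x)  # matches the old SiluAndMul path
y_gelu = apply_moe_activation("gelu", x)  # gated GELU variant
assert y_silu.shape == y_gelu.shape == (4, 128)
```

Call sites would then pass e.g. `activation="silu"` to reproduce the previously hard-coded behavior.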