Skip to content

Commit 36017aa

Browse files
WoosukKwonjvmncs
authored andcommitted
Add fused top-K softmax kernel for MoE (vllm-project#2769)
1 parent 419b31d commit 36017aa

File tree

9 files changed

+591
-50
lines changed

9 files changed

+591
-50
lines changed

csrc/moe/moe_ops.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#include "moe_ops.h"
2+
3+
#include <torch/extension.h>
4+
5+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
6+
m.def("topk_softmax", &topk_softmax, "Apply topk softmax to the gating outputs.");
7+
}

csrc/moe/moe_ops.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#pragma once
2+
3+
#include <torch/extension.h>
4+
5+
void topk_softmax(
6+
torch::Tensor& topk_weights,
7+
torch::Tensor& topk_indices,
8+
torch::Tensor& token_expert_indices,
9+
torch::Tensor& gating_output);

csrc/moe/topk_softmax_kernels.cu

Lines changed: 499 additions & 0 deletions
Large diffs are not rendered by default.

csrc/pybind.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
4848
&rotary_embedding,
4949
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
5050

51-
#ifndef USE_ROCM
5251
// Quantization ops
52+
#ifndef USE_ROCM
5353
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
5454
ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
5555
#endif

setup.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,17 @@ def get_torch_arch_list() -> Set[str]:
339339
vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu")
340340
vllm_extension_sources.append("csrc/custom_all_reduce.cu")
341341

342+
# Add MoE kernels.
343+
ext_modules.append(
344+
CUDAExtension(
345+
name="vllm._moe_C",
346+
sources=glob("csrc/moe/*.cu") + glob("csrc/moe/*.cpp"),
347+
extra_compile_args={
348+
"cxx": CXX_FLAGS,
349+
"nvcc": NVCC_FLAGS,
350+
},
351+
))
352+
342353
if not _is_neuron():
343354
vllm_extension = CUDAExtension(
344355
name="vllm._C",

tests/kernels/test_moe.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,8 @@
22
33
Run `pytest tests/kernels/test_moe.py`.
44
"""
5-
65
import pytest
76
import torch
8-
97
from transformers import MixtralConfig
108
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
119

@@ -14,22 +12,21 @@
1412
from vllm.model_executor.models.mixtral import MixtralMoE
1513

1614

17-
def torch_moe(a, w1, w2, topk_weight, topk_ids):
15+
def torch_moe(a, w1, w2, score, topk):
1816
B, D = a.shape
19-
a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D)
20-
out = torch.zeros(B * topk_ids.shape[1],
21-
w2.shape[1],
22-
dtype=a.dtype,
23-
device=a.device)
24-
topk_ids = topk_ids.view(-1)
17+
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
18+
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
19+
score = torch.softmax(score, dim=-1, dtype=torch.float32)
20+
topk_weight, topk_ids = torch.topk(score, topk)
2521
topk_weight = topk_weight.view(-1)
22+
topk_ids = topk_ids.view(-1)
2623
for i in range(w1.shape[0]):
2724
mask = topk_ids == i
2825
if mask.sum():
2926
out[mask] = SiluAndMul()(
3027
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
3128
return (out.view(B, -1, w2.shape[1]) *
32-
topk_weight.view(B, -1, 1)).sum(dim=1)
29+
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
3330

3431

3532
@pytest.mark.parametrize("m", [512, 222, 33, 1])
@@ -51,11 +48,8 @@ def test_fused_moe(
5148
w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10
5249

5350
score = torch.randn((m, e), device='cuda', dtype=dtype)
54-
score = torch.softmax(score, dim=-1)
55-
topk_weight, topk_ids = torch.topk(score, topk)
56-
57-
triton_output = fused_moe(a, w1, w2, topk_weight, topk_ids, False)
58-
torch_output = torch_moe(a, w1, w2, topk_weight, topk_ids)
51+
triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
52+
torch_output = torch_moe(a, w1, w2, score, topk)
5953
assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0)
6054

6155

@@ -75,7 +69,7 @@ def test_mixtral_moe(dtype: torch.dtype):
7569
intermediate_size=config.intermediate_size,
7670
params_dtype=dtype,
7771
tp_size=1,
78-
)
72+
).cuda()
7973

8074
# Load the weights
8175
vllm_moe.gate.linear_weights["weight"][:] = hf_moe.gate.weight.data

vllm/model_executor/layers/fused_moe.py

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import triton.language as tl
55

66
from vllm._C import ops
7+
from vllm.utils import is_hip
78

89

910
@triton.jit
@@ -177,7 +178,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
177178
expert_ids: torch.Tensor,
178179
num_tokens_post_padded: torch.Tensor,
179180
mul_routed_weight: bool, top_k: int, config: dict):
180-
181181
assert topk_weights.stride(1) == 1
182182
assert sorted_token_ids.stride(0) == 1
183183

@@ -210,28 +210,35 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
210210
)
211211

212212

213-
def fused_moe(hidden_states: torch.Tensor,
214-
w1: torch.Tensor,
215-
w2: torch.Tensor,
216-
topk_weights: torch.Tensor,
217-
topk_ids: torch.Tensor,
218-
inplace=False):
213+
def fused_moe(
214+
hidden_states: torch.Tensor,
215+
w1: torch.Tensor,
216+
w2: torch.Tensor,
217+
gating_output: torch.Tensor,
218+
topk: int,
219+
renormalize: bool,
220+
inplace: bool = False,
221+
) -> torch.Tensor:
219222
"""
220223
This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism.
221224
222225
Parameters:
223226
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
224227
- w1 (torch.Tensor): The first set of expert weights.
225228
- w2 (torch.Tensor): The second set of expert weights.
226-
- topk_weights (torch.Tensor): The weights for the top-k selected experts.
227-
- topk_ids (torch.Tensor): The indices of the top-k selected experts.
229+
- gating_output (torch.Tensor): The output of the gating operation (before softmax).
230+
- topk (int): The number of top-k experts to select.
231+
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
228232
- inplace (bool): If True, perform the operation in-place. Defaults to False.
229233
230234
Returns:
231235
- torch.Tensor: The output tensor after applying the MoE layer.
232236
"""
233237
# Check constraints.
234-
assert hidden_states.shape[1] == w1.shape[2], "Incompatible dimensions"
238+
assert hidden_states.shape[0] == gating_output.shape[0], (
239+
"Number of tokens mismatch")
240+
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
241+
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
235242
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
236243
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
237244
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
@@ -241,6 +248,37 @@ def fused_moe(hidden_states: torch.Tensor,
241248
M, _ = hidden_states.shape
242249
E, N, _ = w1.shape
243250

251+
if is_hip():
252+
# The MoE kernels are not yet supported on ROCm.
253+
routing_weights = torch.softmax(gating_output,
254+
dim=-1,
255+
dtype=torch.float32)
256+
topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
257+
else:
258+
import vllm._moe_C as moe_kernels
259+
260+
topk_weights = torch.empty(M,
261+
topk,
262+
dtype=torch.float32,
263+
device=hidden_states.device)
264+
topk_ids = torch.empty(M,
265+
topk,
266+
dtype=torch.int32,
267+
device=hidden_states.device)
268+
token_expert_indicies = torch.empty(M,
269+
topk,
270+
dtype=torch.int32,
271+
device=hidden_states.device)
272+
moe_kernels.topk_softmax(
273+
topk_weights,
274+
topk_ids,
275+
token_expert_indicies,
276+
gating_output.float(), # TODO(woosuk): Optimize this.
277+
)
278+
del token_expert_indicies # Not used. Will be used in the future.
279+
if renormalize:
280+
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
281+
244282
config = {
245283
'BLOCK_SIZE_M': 64,
246284
'BLOCK_SIZE_N': 64,

vllm/model_executor/models/deepseek.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626
import torch
2727
from torch import nn
28-
import torch.nn.functional as F
2928
from transformers import PretrainedConfig
3029

3130
from vllm.model_executor.input_metadata import InputMetadata
@@ -155,20 +154,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
155154
shared_output = self.shared_experts(hidden_states)
156155
# router_logits: (batch * sequence_length, n_experts)
157156
router_logits, _ = self.gate(hidden_states)
158-
159-
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
160-
routing_weights, selected_experts = torch.topk(routing_weights,
161-
self.top_k,
162-
dim=-1)
163-
164-
if self.config.norm_topk_prob:
165-
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
166-
167157
final_hidden_states = fused_moe(hidden_states,
168158
self.w1,
169159
self.w2,
170-
routing_weights,
171-
selected_experts,
160+
router_logits,
161+
self.top_k,
162+
renormalize=self.config.norm_topk_prob,
172163
inplace=True)
173164

174165
if self.config.n_shared_experts is not None:

vllm/model_executor/models/mixtral.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@
2424
from typing import List, Optional, Tuple
2525

2626
import torch
27-
import torch.nn.functional as F
28-
2927
from torch import nn
3028
from transformers import MixtralConfig
3129

@@ -128,18 +126,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
128126
hidden_states = hidden_states.view(-1, self.hidden_size)
129127
# router_logits: (batch * sequence_length, n_experts)
130128
router_logits, _ = self.gate(hidden_states)
131-
132-
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
133-
routing_weights, selected_experts = torch.topk(routing_weights,
134-
self.top_k,
135-
dim=-1)
136-
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
137-
138129
final_hidden_states = fused_moe(hidden_states,
139130
self.ws,
140131
self.w2s,
141-
routing_weights,
142-
selected_experts,
132+
router_logits,
133+
self.top_k,
134+
renormalize=True,
143135
inplace=True)
144136

145137
if self.tp_size > 1:

0 commit comments

Comments
 (0)