Commit 001ac2e

Author: Varun Sundar Rabindranath (committed)

refactor
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
1 parent 9296688 commit 001ac2e

9 files changed: +43, -49 lines changed

vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py

Lines changed: 3 additions & 1 deletion
@@ -47,10 +47,12 @@ def workspace_shapes(
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         assert a.dim() == 2
         num_dp = self.dp_size
+        num_experts = local_num_experts
         max_num_tokens = a.size(
             0) if self.max_num_tokens is None else self.max_num_tokens
         workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))
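
To make the sizing concrete, here is a minimal standalone sketch of the batched workspace computation above. The helper name and the expert/token counts are illustrative only, not part of the commit:

    def batched_workspace13(local_num_experts: int, max_num_tokens: int,
                            num_dp: int, K: int, N: int) -> tuple[int, ...]:
        # Batched layouts materialize one token slab per *local* expert,
        # so the leading dim only needs local_num_experts rows.
        return (local_num_experts, max_num_tokens * num_dp, max(K, N))

    # e.g. 8 local experts (of 64 global), 256 tokens per rank, dp_size 2:
    assert batched_workspace13(8, 256, 2, K=4096, N=14336) == (8, 512, 14336)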

vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py

Lines changed: 4 additions & 3 deletions
@@ -81,18 +81,19 @@ def workspace_shapes(
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         # Note: the deep gemm workspaces are strictly larger than the triton
         # workspaces so we can be pessimistic here and allocate for DeepGemm
         # even if we fall back to triton later, e.g. if expert maps are set.
         if self.allow_deep_gemm and self.batched_deep_gemm_experts is not None:
             return self.batched_deep_gemm_experts.workspace_shapes(
-                a, aq, M, N, K, topk, num_experts)
+                a, aq, M, N, K, topk, global_num_experts, local_num_experts)
         else:
             assert self.batched_triton_experts is not None
             return self.batched_triton_experts.workspace_shapes(
-                a, aq, M, N, K, topk, num_experts)
+                a, aq, M, N, K, topk, global_num_experts, local_num_experts)

     def apply(
         self,
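
The "pessimistic" note above rests on one invariant: the DeepGemm buffers must hold at least as many elements as the Triton fallback would ever request. A tiny hedged check of that property (the shapes are made up for illustration):

    import math

    def covers(pessimistic: tuple[int, ...], actual: tuple[int, ...]) -> bool:
        # Allocating the larger workspace is safe whenever its element
        # count dominates what the fallback path would need.
        return math.prod(pessimistic) >= math.prod(actual)

    assert covers((8, 512, 14336), (8, 512, 4096))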

vllm/model_executor/layers/fused_moe/cutlass_moe.py

Lines changed: 2 additions & 1 deletion
@@ -230,7 +230,8 @@ def workspace_shapes(
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         workspace1: tuple[int, ...] = ()
         workspace2: tuple[int, ...] = ()

vllm/model_executor/layers/fused_moe/deep_gemm_moe.py

Lines changed: 5 additions & 8 deletions
@@ -74,15 +74,12 @@ def supports_chunking(self) -> bool:
         return True

     def workspace_shapes(
-        self,
-        a: torch.Tensor,
-        aq: torch.Tensor,
-        M: int,
-        N: int,
-        K: int,
-        topk: int,
-        num_experts: int,
+        self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
+        topk: int, global_num_experts: int, local_num_experts: int
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
+        # We use global_num_experts due to how moe_align_block_size handles
+        # expert_maps.
+        num_experts = global_num_experts
         block_m = self.block_shape[0]
         M_sum = (M * topk) + num_experts * (block_m - 1)
         M_sum = round_up(M_sum, block_m)
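
The cost of keying this off global_num_experts is easy to see with a quick worked example (numbers chosen purely for illustration):

    def round_up(x: int, m: int) -> int:
        # Round x up to the nearest multiple of m.
        return ((x + m - 1) // m) * m

    M, topk, block_m = 128, 6, 128
    for num_experts in (16, 256):   # a local vs. a global expert count
        M_sum = (M * topk) + num_experts * (block_m - 1)
        print(num_experts, round_up(M_sum, block_m))
    # 16  -> 2816   rows, if the local count could be used
    # 256 -> 33280  rows, what is actually reserved with global_num_experts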

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 6 additions & 2 deletions
@@ -521,10 +521,12 @@ def workspace_shapes(
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         assert a.dim() == 2
         num_dp = self.dp_size
+        num_experts = local_num_experts
         workspace13 = (num_experts, self.max_num_tokens * num_dp, K)
         workspace2 = (self.max_num_tokens * num_dp, N)
         return (workspace13, workspace2, workspace13, a.dtype)
@@ -624,10 +626,12 @@ def workspace_shapes(
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         assert a.dim() == 2
         num_dp = self.world_size // self.dp_size
+        num_experts = local_num_experts
         max_num_tokens = a.size(
             0) if self.max_num_tokens is None else self.max_num_tokens
         workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 2 additions & 1 deletion
@@ -1553,7 +1553,8 @@ def workspace_shapes(
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         workspace1 = (M, topk, max(N * 2, K))
         workspace2 = (M, topk, N)

vllm/model_executor/layers/fused_moe/modular_kernel.py

Lines changed: 10 additions & 30 deletions
@@ -194,7 +194,8 @@ def workspace_shapes(
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         """
         Compute the shapes for the temporary and final outputs of the two gemms
@@ -372,8 +373,9 @@ def forward(
         a1 = hidden_states
         output = a1 if inplace else torch.zeros_like(a1)

+        local_num_experts = w1.size(0)
         if global_num_experts == -1:
-            global_num_experts = w1.size(0)
+            global_num_experts = local_num_experts

         (a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
          _expert_topk_weights) = self.prepare_finalize.prepare(
@@ -405,44 +407,22 @@ def forward(
         CHUNK_SIZE = M
         num_chunks = 1

-        # Batched experts don't support chunking at this level as the
-        # chunking had already happened at an higher level - in
-        # fused_moe/layer.py
-        is_batched_fused_experts = not self.fused_experts.supports_chunking(
-        )
-
-        # TODO (varun): In the case of a non-batched fused_experts
-        # implementation the input tokens are usually aligned to a
-        # "block-size" by moe_align_block_size. In the case of
-        # expert_parallel, moe_align_block_size initially considers all
-        # experts as valid and aligns all tokens appropriately. Before
-        # moe_align_block_size returns it marks the experts_ids that are
-        # not in the current GPU rank as -1 so the MoE matmuls could skip
-        # those blocks. This is sub-optimal.
-        # Due to how moe_align_block_size is implemented at the
-        # moment, it is required that we use `global_num_experts` in the
-        # workspace calculations. However for the batched case, we don't
-        # use `moe_align_block_size`, as the input is already aligned
-        # (batched). This lets us use `local_num_experts`, which is
-        # much lesser than global_num_experts, in the workspace
-        # calculation.
-        num_experts_workspace = w1.size(
-            0) if is_batched_fused_experts else global_num_experts
-
         if num_chunks == 1:
             (workspace13_shape, workspace2_shape, fused_out_shape,
              workspace_dtype) = self.fused_experts.workspace_shapes(
-                a1, a1q, M, N, K, top_k, num_experts_workspace)
+                a1, a1q, M, N, K, top_k, global_num_experts,
+                local_num_experts)
         else:
             # Use the full M to get the final output shape.
             _, _, fused_out_shape, _ = (
                 self.fused_experts.workspace_shapes(
-                    a1, a1q, M, N, K, top_k, num_experts_workspace))
+                    a1, a1q, M, N, K, top_k, global_num_experts,
+                    local_num_experts))
             # Use the CHUNK_SIZE to get the workspace shapes.
             workspace13_shape, workspace2_shape, _, workspace_dtype = (
                 self.fused_experts.workspace_shapes(
-                    a1, a1q, CHUNK_SIZE, N, K, top_k,
-                    num_experts_workspace))
+                    a1, a1q, CHUNK_SIZE, N, K, top_k, global_num_experts,
+                    local_num_experts))

         # We can reuse the memory between cache1 and cache3 because by the
         # time we need cache3, we're done with cache1.
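
A minimal sketch of the counting logic the second hunk above introduces; the helper function and the tensor sizes are illustrative, not vLLM code:

    import torch

    def resolve_expert_counts(w1: torch.Tensor,
                              global_num_experts: int) -> tuple[int, int]:
        # w1 holds only this rank's expert weights, so its leading dim is
        # the local expert count; -1 signals "no expert parallelism", in
        # which case the global count collapses to the local one.
        local_num_experts = w1.size(0)
        if global_num_experts == -1:
            global_num_experts = local_num_experts
        return global_num_experts, local_num_experts

    w1 = torch.empty(8, 32, 16)           # 8 experts resident on this rank
    print(resolve_expert_counts(w1, 64))  # EP enabled:  (64, 8)
    print(resolve_expert_counts(w1, -1))  # EP disabled: (8, 8)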

vllm/model_executor/layers/fused_moe/moe_align_block_size.py

Lines changed: 6 additions & 0 deletions
@@ -159,6 +159,12 @@ def moe_align_block_size(
     Aligns the token distribution across experts to be compatible with block
     size for matrix multiplication.

+    Note: In the case of expert_parallel, moe_align_block_size initially
+    considers all experts as valid and aligns all tokens appropriately.
+    Before the function returns, it marks the expert_ids that are not on
+    the current GPU rank as -1 so the MoE matmuls can skip those blocks.
+    This requires the num_experts arg to be the global number of experts.
+
     Parameters:
     - topk_ids: A tensor of shape [total_tokens, top_k] representing the
       top-k expert indices for each token.
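
The masking described in the new docstring note can be sketched in a few lines. This is a hypothetical post-processing step written for illustration, not the actual kernel; it assumes an expert_map that maps a global expert id to its local id, with -1 for experts not resident on this rank:

    import torch

    def mask_remote_expert_blocks(expert_ids: torch.Tensor,
                                  expert_map: torch.Tensor) -> torch.Tensor:
        # Blocks assigned to experts that live on other ranks are marked
        # -1 so the MoE matmuls skip them.
        return torch.where(expert_map[expert_ids] == -1,
                           torch.full_like(expert_ids, -1),
                           expert_ids)

    expert_map = torch.tensor([-1, -1, 0, 1])  # rank owns global experts 2, 3
    expert_ids = torch.tensor([0, 2, 3, 1])    # one expert id per aligned block
    print(mask_remote_expert_blocks(expert_ids, expert_map))
    # tensor([-1,  2,  3, -1])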

vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py

Lines changed: 5 additions & 3 deletions
@@ -48,18 +48,20 @@ def workspace_shapes(
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         # Note: the deep gemm workspaces are strictly larger than the triton
         # workspaces so we can be pessimistic here and allocate for DeepGemm
         # even if we fall back to triton later, e.g. if expert maps are set.
         if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K):
             assert self.deep_gemm_expert is not None
             return self.deep_gemm_expert.workspace_shapes(
-                a, aq, M, N, K, topk, num_experts)
+                a, aq, M, N, K, topk, global_num_experts, local_num_experts)
         else:
             return self.triton_expert.workspace_shapes(a, aq, M, N, K, topk,
-                                                       num_experts)
+                                                       global_num_experts,
+                                                       local_num_experts)

     def apply(
         self,
