
Commit 66c457b

Add unit test case that would fail without filling zeros to c1
Signed-off-by: Ming Yang <[email protected]>
1 parent 890ed4f commit 66c457b

2 files changed, 132 insertions(+), 2 deletions(-)

tests/kernels/moe/test_cutlass_moe.py

Lines changed: 131 additions & 1 deletion
@@ -1,16 +1,20 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import dataclasses
+from math import prod
 from typing import Optional
 
 import pytest
 import torch
 
 from vllm import _custom_ops as ops
 from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
-from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
+from vllm.model_executor.layers.fused_moe.cutlass_moe import (
+    cutlass_moe_fp8, run_cutlass_moe_fp8)
 from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts,
                                                             fused_topk)
+from vllm.model_executor.layers.fused_moe.utils import (
+    moe_kernel_quantize_input)
 from vllm.platforms import current_platform
 
 NUM_EXPERTS = [40, 64]
@@ -365,3 +369,129 @@ def test_cutlass_moe_8_bit_EP(
                                    cutlass_output,
                                    atol=5e-2,
                                    rtol=1e-2)
+
+
+@pytest.mark.parametrize("m,n,k,topk", [(1, 8192, 5120, 31)])
+@pytest.mark.parametrize("e", [128])
+@pytest.mark.parametrize("per_act_token", [False])
+@pytest.mark.parametrize("per_out_channel", [True])
+@pytest.mark.parametrize("ep_size", [8])
+@pytest.mark.skipif(
+    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
+        current_platform.get_device_capability()),
+    reason="Grouped gemm is not supported on this GPU type.")
+def test_cutlass_moe_8_bit_EP_large(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    per_act_token: bool,
+    per_out_channel: bool,
+    ep_size: int,
+):
+    current_platform.seed_everything(7)
+    with set_current_vllm_config(vllm_config):
+        mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
+                                                  per_out_channel)
+
+        score = torch.randn((m, e), device="cuda", dtype=torch.half)
+        topk_weights, topk_ids, _ = fused_topk(mt.a,
+                                               score,
+                                               topk,
+                                               renormalize=False)
+
+        # Note that we are using the dequantized versions of the tensors.
+        # Using a, w1 and w2 directly results in minor output differences.
+        triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
+                                      topk_ids)
+
+        assert e % ep_size == 0, "Cannot distribute experts evenly"
+        cutlass_output = run_8_bit(mt,
+                                   topk_weights,
+                                   topk_ids,
+                                   num_local_experts=e // ep_size)
+
+        torch.testing.assert_close(triton_output,
+                                   cutlass_output,
+                                   atol=5e-2,
+                                   rtol=1e-2)
+
+
+@pytest.mark.parametrize("m,n,k,topk", [(1, 8192, 5120, 31)])
+@pytest.mark.parametrize("e", [128])
+@pytest.mark.parametrize("per_act_token", [False])
+@pytest.mark.parametrize("per_out_channel", [True])
+@pytest.mark.parametrize("ep_size", [8])
+@pytest.mark.skipif(
+    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
+        current_platform.get_device_capability()),
+    reason="Grouped gemm is not supported on this GPU type.")
+def test_run_cutlass_moe_fp8(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    per_act_token: bool,
+    per_out_channel: bool,
+    ep_size: int,
+):
+    current_platform.seed_everything(7)
+    with set_current_vllm_config(vllm_config):
+        mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
+                                                  per_out_channel)
+
+        score = torch.randn((m, e), device="cuda", dtype=torch.half)
+        topk_weights, topk_ids, _ = fused_topk(mt.a,
+                                               score,
+                                               topk,
+                                               renormalize=False)
+        # We want to make sure there is at least one token that's generated
+        # in this expert shard and at least one token that's NOT generated
+        # in this expert shard.
+        topk_ids[0][0] = -1
+        topk_ids[0][1] = 1
+
+        workspace13_shape = (m * topk, max(2 * n, k))
+        workspace2_shape = (m * topk, n)
+        output_shape = (m * topk, k)
+
+        workspace13 = torch.empty(prod(workspace13_shape),
+                                  device="cuda",
+                                  dtype=mt.a.dtype)
+        workspace2 = torch.empty(prod(workspace2_shape),
+                                 device="cuda",
+                                 dtype=mt.a.dtype)
+
+        num_local_experts = e // ep_size
+        start, end = 0, num_local_experts
+        expert_map = [-1] * e
+        expert_map[start:end] = list(range(num_local_experts))
+        expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")
+
+        activation = lambda o, i: torch.ops._C.silu_and_mul(o, i)
+        a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale,
+                                                   torch.float8_e4m3fn,
+                                                   per_act_token)
+        func = lambda output: run_cutlass_moe_fp8(
+            output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation, mt.w1_q.size(
+                0), expert_map, mt.w1_scale, mt.w2_scale, a1q_scale, None,
+            workspace13, workspace2, None, mt.a.dtype, per_act_token,
+            per_out_channel, False)
+
+        workspace13.random_()
+        output_random_workspace = torch.empty(output_shape,
+                                              device="cuda",
+                                              dtype=mt.a.dtype)
+        func(output_random_workspace)
+
+        workspace13.fill_(0)
+        output_zero_workspace = torch.zeros(output_shape,
+                                            device="cuda",
+                                            dtype=mt.a.dtype)
+        func(output_zero_workspace)
+
+        torch.testing.assert_close(output_random_workspace,
+                                   output_zero_workspace,
+                                   atol=5e-3,
+                                   rtol=1e-3)
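
Aside (editorial illustration, not part of the diff): the new test builds an expert_map that marks which global experts live on the local expert-parallel shard, and it forces topk_ids to contain one non-local and one local expert so both code paths are exercised. A minimal standalone sketch of that expert_map construction, with illustrative values only, assuming plain PyTorch:

# Sketch only: experts owned by this EP shard map to local ids
# [0, num_local_experts); every other global expert id maps to -1
# and is skipped by the local kernel.
import torch

e, ep_size = 128, 8
num_local_experts = e // ep_size              # 16 experts on this shard
expert_map = [-1] * e
expert_map[:num_local_experts] = list(range(num_local_experts))
expert_map = torch.tensor(expert_map, dtype=torch.int32)

# Global expert 1 is handled on this shard; global expert 100 is not.
assert expert_map[1].item() == 1
assert expert_map[100].item() == -1

The test then runs the same MoE call twice, once with a randomized workspace and once with a zeroed one, and asserts the outputs match; per the commit message, this comparison fails unless c1 is zero-filled inside the kernel wrapper.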

vllm/model_executor/layers/fused_moe/cutlass_moe.py

Lines changed: 1 addition & 1 deletion
@@ -270,7 +270,7 @@ def apply(
     ):
         assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
         assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
-        activation_callable = lambda i, o: self.activation(activation, i, o)
+        activation_callable = lambda o, i: self.activation(activation, o, i)
         run_cutlass_moe_fp8(output, hidden_states, w1, w2, topk_ids,
                             activation_callable, global_num_experts,
                             expert_map, w1_scale, w2_scale, a1q_scale,
