@@ -1,16 +1,20 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import dataclasses
+from math import prod
 from typing import Optional

 import pytest
 import torch

 from vllm import _custom_ops as ops
 from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
-from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
+from vllm.model_executor.layers.fused_moe.cutlass_moe import (
+    cutlass_moe_fp8, run_cutlass_moe_fp8)
 from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts,
                                                             fused_topk)
+from vllm.model_executor.layers.fused_moe.utils import (
+    moe_kernel_quantize_input)
 from vllm.platforms import current_platform

 NUM_EXPERTS = [40, 64]
@@ -365,3 +369,129 @@ def test_cutlass_moe_8_bit_EP(
                                    cutlass_output,
                                    atol=5e-2,
                                    rtol=1e-2)
+
+
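+# Mirrors test_cutlass_moe_8_bit_EP above, but with a large model shape
+# (n=8192, k=5120, e=128, topk=31) split across ep_size=8 expert shards.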
+@pytest.mark.parametrize("m,n,k,topk", [(1, 8192, 5120, 31)])
+@pytest.mark.parametrize("e", [128])
+@pytest.mark.parametrize("per_act_token", [False])
+@pytest.mark.parametrize("per_out_channel", [True])
+@pytest.mark.parametrize("ep_size", [8])
+@pytest.mark.skipif(
+    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
+        current_platform.get_device_capability()),
+    reason="Grouped gemm is not supported on this GPU type.")
+def test_cutlass_moe_8_bit_EP_large(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    per_act_token: bool,
+    per_out_channel: bool,
+    ep_size: int,
+):
+    current_platform.seed_everything(7)
+    with set_current_vllm_config(vllm_config):
+        mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
+                                                  per_out_channel)
+
+        score = torch.randn((m, e), device="cuda", dtype=torch.half)
+        topk_weights, topk_ids, _ = fused_topk(mt.a,
+                                               score,
+                                               topk,
+                                               renormalize=False)
+
+        # Note that we are using the dequantized versions of the tensors.
+        # Using a, w1 and w2 directly results in minor output differences.
+        triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
+                                      topk_ids)
+
+        assert e % ep_size == 0, "Cannot distribute experts evenly"
+        cutlass_output = run_8_bit(mt,
+                                   topk_weights,
+                                   topk_ids,
+                                   num_local_experts=e // ep_size)
+
+        torch.testing.assert_close(triton_output,
+                                   cutlass_output,
+                                   atol=5e-2,
+                                   rtol=1e-2)
+
+
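+# Exercises run_cutlass_moe_fp8 directly, checking that the kernel's output
+# does not depend on the prior contents of its workspace buffers.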
+@pytest.mark.parametrize("m,n,k,topk", [(1, 8192, 5120, 31)])
+@pytest.mark.parametrize("e", [128])
+@pytest.mark.parametrize("per_act_token", [False])
+@pytest.mark.parametrize("per_out_channel", [True])
+@pytest.mark.parametrize("ep_size", [8])
+@pytest.mark.skipif(
+    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
+        current_platform.get_device_capability()),
+    reason="Grouped gemm is not supported on this GPU type.")
+def test_run_cutlass_moe_fp8(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    per_act_token: bool,
+    per_out_channel: bool,
+    ep_size: int,
+):
+    current_platform.seed_everything(7)
+    with set_current_vllm_config(vllm_config):
+        mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
+                                                  per_out_channel)
+
+        score = torch.randn((m, e), device="cuda", dtype=torch.half)
+        topk_weights, topk_ids, _ = fused_topk(mt.a,
+                                               score,
+                                               topk,
+                                               renormalize=False)
+        # Make sure that at least one token is routed to an expert on this
+        # shard and at least one token is routed to an expert outside it.
+        topk_ids[0][0] = -1
+        topk_ids[0][1] = 1
+
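+        # workspace13 is shared by the two widest intermediates: the first
+        # grouped gemm produces (m * topk, 2 * n) and the second produces
+        # (m * topk, k), hence max(2 * n, k) columns. workspace2 holds the
+        # (m * topk, n) activation output between them.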
+        workspace13_shape = (m * topk, max(2 * n, k))
+        workspace2_shape = (m * topk, n)
+        output_shape = (m * topk, k)
+
+        workspace13 = torch.empty(prod(workspace13_shape),
+                                  device="cuda",
+                                  dtype=mt.a.dtype)
+        workspace2 = torch.empty(prod(workspace2_shape),
+                                 device="cuda",
+                                 dtype=mt.a.dtype)
+
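+        # Emulate EP rank 0: global experts [0, num_local_experts) map to
+        # local ids 0..num_local_experts-1, while every other global expert
+        # maps to -1, i.e. it lives on another shard.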
+        num_local_experts = e // ep_size
+        start, end = 0, num_local_experts
+        expert_map = [-1] * e
+        expert_map[start:end] = list(range(num_local_experts))
+        expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")
+
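+        # silu_and_mul applies SiLU to the first half of its input and
+        # multiplies it elementwise by the second half.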
+        activation = lambda o, i: torch.ops._C.silu_and_mul(o, i)
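+        # run_cutlass_moe_fp8 consumes pre-quantized activations, so quantize
+        # mt.a once here; both invocations below then see identical inputs.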
+        a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale,
+                                                   torch.float8_e4m3fn,
+                                                   per_act_token)
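+        # Bind every argument except the output buffer so the same call can
+        # be replayed against different workspace contents.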
+        func = lambda output: run_cutlass_moe_fp8(
+            output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation,
+            mt.w1_q.size(0), expert_map, mt.w1_scale, mt.w2_scale, a1q_scale,
+            None, workspace13, workspace2, None, mt.a.dtype, per_act_token,
+            per_out_channel, False)
+
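+        # Run the kernel twice, once with a randomized workspace and once
+        # with a zeroed one. If any output elements depend on stale workspace
+        # contents (e.g. rows for tokens not routed to a local expert), the
+        # two results will differ.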
+        workspace13.random_()
+        output_random_workspace = torch.empty(output_shape,
+                                              device="cuda",
+                                              dtype=mt.a.dtype)
+        func(output_random_workspace)
+
+        workspace13.fill_(0)
+        output_zero_workspace = torch.zeros(output_shape,
+                                            device="cuda",
+                                            dtype=mt.a.dtype)
+        func(output_zero_workspace)
+
+        torch.testing.assert_close(output_random_workspace,
+                                   output_zero_workspace,
+                                   atol=5e-3,
+                                   rtol=1e-3)