
Commit a091e2d

ElizaWszola and dsikka authored
[Kernel] Enable 8-bit weights in Fused Marlin MoE (#8032)
Co-authored-by: Dipika <[email protected]>
1 parent fc990f9 commit a091e2d
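
At a high level, this commit threads a num_bits argument (4 or 8) through the fused Marlin MoE path and maps it to a vLLM scalar type before invoking the CUDA kernel. A minimal sketch of that mapping, based on the pattern repeated in the diffs below (the wrapper function is illustrative only, not part of the commit):

from vllm.scalar_type import scalar_types


def select_marlin_scalar_type(num_bits: int):
    # Pick the Marlin quantization type for 4- or 8-bit expert weights,
    # mirroring the selection added in fused_marlin_moe.py below.
    assert num_bits in (4, 8)
    return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128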

File tree

12 files changed: +453 additions, -185 deletions


csrc/moe/marlin_moe_ops.cu

Lines changed: 389 additions & 148 deletions
Large diffs are not rendered by default.

csrc/moe/marlin_moe_ops.h

Lines changed: 5 additions & 2 deletions
@@ -2,11 +2,14 @@

 #include <torch/all.h>

+#include "core/scalar_type.hpp"
+
 torch::Tensor marlin_gemm_moe(
     const torch::Tensor& a, const torch::Tensor& b_q_weights,
     const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
     const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
     const torch::Tensor& g_idx, const torch::Tensor& perm,
-    torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k,
-    bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size,
+    torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type,
+    int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full,
+    int64_t num_experts, int64_t topk, int64_t moe_block_size,
     bool replicate_input, bool apply_weights);

csrc/moe/torch_bindings.cpp

Lines changed: 5 additions & 3 deletions
@@ -13,9 +13,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
   m.def(
       "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
       "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
-      "g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int "
-      "size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, "
-      "bool replicate_input, bool apply_weights) -> Tensor");
+      "g_idx, Tensor! perm, Tensor! workspace, "
+      "__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, "
+      "int size_n, int size_k, bool is_k_full, int num_experts, int topk, "
+      "int moe_block_size, bool replicate_input, bool apply_weights)"
+      " -> Tensor");
   m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe);
 #endif
 }

tests/kernels/test_moe.py

Lines changed: 12 additions & 6 deletions
@@ -140,6 +140,7 @@ def compute_max_diff(output, output_ref):
 @pytest.mark.parametrize("topk", [2, 6])
 @pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
 @pytest.mark.parametrize("act_order", [True, False])
+@pytest.mark.parametrize("num_bits", [4, 8])
 def test_fused_marlin_moe(
     m: int,
     n: int,
@@ -148,6 +149,7 @@ def test_fused_marlin_moe(
     topk: int,
     group_size: int,
     act_order: bool,
+    num_bits: int,
 ):
     torch.manual_seed(7)

@@ -161,13 +163,12 @@ def test_fused_marlin_moe(
     if group_size in (k, n):
         return

-    quant_type = scalar_types.uint4b8
+    quant_type = (scalar_types.uint4b8
+                  if num_bits == 4 else scalar_types.uint8b128)
     dtype = torch.float16
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
     w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
-    for i in range(w2.shape[0]):
-        w2[0] = torch.eye(k, n, device="cuda", dtype=dtype)

     w_ref1_l = []
     qweight1_l = []
@@ -240,6 +241,7 @@ def test_fused_marlin_moe(
         topk_ids,
         w1_scale=scales1,
         w2_scale=scales2,
+        num_bits=num_bits,
     )

     assert compute_max_diff(marlin_output, triton_output) < 4e-2
@@ -254,14 +256,16 @@
 @pytest.mark.parametrize("topk", [2, 6])
 @pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
 @pytest.mark.parametrize("act_order", [True, False])
-def test_marlin_moe_mmm(
+@pytest.mark.parametrize("num_bits", [4, 8])
+def test_single_marlin_moe_multiply(
     m: int,
     n: int,
     k: int,
     e: int,
     topk: int,
     group_size: int,
     act_order: bool,
+    num_bits: int,
 ):
     if topk > e:
         return
@@ -273,7 +277,8 @@ def test_marlin_moe_mmm(
     if group_size == k:
         return

-    quant_type = scalar_types.uint4b8
+    quant_type = (scalar_types.uint4b8
+                  if num_bits == 4 else scalar_types.uint8b128)
     dtype = torch.float16
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
@@ -308,7 +313,8 @@ def test_marlin_moe_mmm(
                                       g_idx,
                                       sort_indices,
                                       topk,
-                                      renormalize=False)
+                                      renormalize=False,
+                                      num_bits=num_bits)
     torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)

     assert compute_max_diff(marlin_output, torch_output) < 1e-2
Lines changed: 2 additions & 1 deletion
@@ -1,3 +1,4 @@
 compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
 compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
-gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
+compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main
+gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main

tests/weight_loading/run_model_weight_loading_test.sh

File mode changed: 100644 -> 100755.

vllm/_custom_ops.py

Lines changed: 1 addition & 1 deletion
@@ -559,7 +559,7 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                            num_bits: int) -> torch.Tensor:
     num_experts = b_q_weight.shape[0]
     assert size_k % 16 == 0
-    output = torch.empty((num_experts, size_k // 16, size_n * 2),
+    output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)),
                          device=b_q_weight.device,
                          dtype=b_q_weight.dtype)
     for e in range(num_experts):
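
The new last dimension follows from packing num_bits-bit values into 32-bit words: Marlin stores 32 // num_bits quantized weights per int32 word, so the repacked width grows from size_n * 2 for 4-bit weights to size_n * 4 for 8-bit weights. A small sanity-check sketch of that arithmetic (plain Python for illustration, not vLLM code):

def packed_moe_shape(num_experts: int, size_k: int, size_n: int,
                     num_bits: int) -> tuple:
    # Shape used by gptq_marlin_moe_repack for the packed int32 output;
    # size_k must be divisible by 16, as the kernel asserts.
    assert num_bits in (4, 8) and size_k % 16 == 0
    shape = (num_experts, size_k // 16, size_n * (num_bits // 2))
    # Total int32 words must equal total weights * num_bits / 32.
    expected_words = num_experts * size_k * size_n * num_bits // 32
    assert shape[0] * shape[1] * shape[2] == expected_words
    return shape


print(packed_moe_shape(8, 4096, 14336, 4))  # (8, 256, 28672)
print(packed_moe_shape(8, 4096, 14336, 8))  # (8, 256, 57344)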

vllm/model_executor/layers/fused_moe/fused_marlin_moe.py

Lines changed: 30 additions & 14 deletions
@@ -7,18 +7,21 @@
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     fused_topk, moe_align_block_size, try_get_optimal_moe_config)
+from vllm.scalar_type import scalar_types


 def single_marlin_moe(
-        hidden_states: torch.Tensor,
-        w: torch.Tensor,
-        scales: torch.Tensor,
-        gating_output: torch.Tensor,
-        g_idx: torch.Tensor,
-        perm: torch.Tensor,
-        topk: int,
-        renormalize: bool,
-        override_config: Optional[Dict[str, Any]] = None) -> torch.Tensor:
+    hidden_states: torch.Tensor,
+    w: torch.Tensor,
+    scales: torch.Tensor,
+    gating_output: torch.Tensor,
+    g_idx: torch.Tensor,
+    perm: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    override_config: Optional[Dict[str, Any]] = None,
+    num_bits: int = 8,
+) -> torch.Tensor:
     """
     This function computes the multiplication of hidden_states with expert
     weights used in Marlin MoE, using weights w and top-k gating mechanism.
@@ -36,6 +39,7 @@ def single_marlin_moe(
     - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
     - override_config (Optional[Dict[str, Any]]): Optional override
         for the kernel configuration.
+    - num_bits (bool): The number of bits in expert weights quantization.

     Returns:
     - torch.Tensor: The output tensor after applying the MoE layer.
@@ -48,10 +52,11 @@
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w.is_contiguous(), "Expert weights must be contiguous"
     assert hidden_states.dtype == torch.float16
+    assert num_bits in [4, 8]

     M, K = hidden_states.shape
     E = w.shape[0]
-    N = w.shape[2] // 2
+    N = w.shape[2] // (num_bits // 2)

     topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
                                         renormalize)
@@ -76,10 +81,13 @@
                             device="cuda",
                             requires_grad=False)

+    scalar_type = (scalar_types.uint4b8
+                   if num_bits == 4 else scalar_types.uint8b128)
+
     intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
         hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
-        g_idx, perm, workspace, M, N, K, True, E, topk, block_size_m, True,
-        False)
+        g_idx, perm, workspace, scalar_type, M, N, K, True, E, topk,
+        block_size_m, True, False)

     return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)

@@ -98,6 +106,7 @@ def fused_marlin_moe(
     override_config: Optional[Dict[str, Any]] = None,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
+    num_bits: int = 8,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -122,6 +131,7 @@ def fused_marlin_moe(
         w1.
     - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
         w2.
+    - num_bits (bool): The number of bits in expert weights quantization.

     Returns:
     - torch.Tensor: The output tensor after applying the MoE layer.
@@ -131,13 +141,14 @@ def fused_marlin_moe(
         0], "Number of tokens mismatch"
     assert hidden_states.shape[
         1] == w1.shape[1] * 16, "Hidden size mismatch w1"
-    assert hidden_states.shape[
-        1] == w2.shape[2] // 2, "Hidden size mismatch w2"
+    assert hidden_states.shape[1] == w2.shape[2] // (
+        num_bits // 2), "Hidden size mismatch w2"
     assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
     assert w2.is_contiguous(), "Expert weights2 must be contiguous"
     assert hidden_states.dtype == torch.float16
+    assert num_bits in [4, 8]

     M, K = hidden_states.shape
     E = w1.shape[0]
@@ -165,6 +176,9 @@
                             device="cuda",
                             requires_grad=False)

+    scalar_type = (scalar_types.uint4b8
+                   if num_bits == 4 else scalar_types.uint8b128)
+
     intermediate_cache2 = torch.empty(
         (M * topk_ids.shape[1], N),
         device=hidden_states.device,
@@ -181,6 +195,7 @@
         g_idx1,
         perm1,
         workspace,
+        scalar_type,
         M,
         2 * N,
         K,
@@ -204,6 +219,7 @@
         g_idx2,
         perm2,
         workspace,
+        scalar_type,
         M,
         K,
         N,
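
From the caller's side, the only visible change in this module is the new num_bits keyword (default 8). A hedged wrapper sketch showing the call shape of single_marlin_moe with 8-bit expert weights; the inputs are assumed to already be prepared (weights repacked with ops.gptq_marlin_moe_repack(..., num_bits=8), plus scales, g_idx, and perm), as done in tests/kernels/test_moe.py:

import torch
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
    single_marlin_moe)


def run_single_marlin_moe_8bit(hidden_states: torch.Tensor,
                               w_packed: torch.Tensor,
                               scales: torch.Tensor,
                               gating_output: torch.Tensor,
                               g_idx: torch.Tensor,
                               perm: torch.Tensor,
                               topk: int = 2) -> torch.Tensor:
    # hidden_states: (M, K) float16; w_packed: (E, K // 16, N * 4) int32,
    # i.e. Marlin-repacked 8-bit expert weights.
    return single_marlin_moe(
        hidden_states,
        w_packed,
        scales,
        gating_output,
        g_idx,
        perm,
        topk,
        renormalize=True,
        num_bits=8,  # new in this commit: 4 or 8
    )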

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 1 addition & 1 deletion
@@ -445,7 +445,7 @@ def grouped_topk(hidden_states: torch.Tensor,
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

-    return topk_weights, topk_ids.to(torch.int32)
+    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)


 def get_config_dtype_str(dtype: torch.dtype,
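
The one-line change pins grouped_topk's return dtypes: routing weights now always come back as float32, alongside the existing int32 cast on the expert ids. A tiny illustration of the resulting contract with toy values (not vLLM code):

import torch

# Toy router scores in float16 (as when hidden states are float16).
scores = torch.randn(4, 8, dtype=torch.float16)
topk_weights, topk_ids = scores.topk(2, dim=-1)

# After this commit, callers of grouped_topk always receive
# float32 weights and int32 ids, regardless of the gating dtype.
topk_weights = topk_weights.to(torch.float32)
topk_ids = topk_ids.to(torch.int32)
assert topk_weights.dtype == torch.float32
assert topk_ids.dtype == torch.int32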

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 6 additions & 2 deletions
@@ -6,6 +6,8 @@

 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
+    WNA16_SUPPORTED_BITS)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     CompressionFormat)
 from vllm.model_executor.utils import set_weight_attrs
@@ -38,10 +40,11 @@ def __init__(

         if not (self.quant_config.quant_format
                 == CompressionFormat.pack_quantized.value
-                and self.num_bits == 4):
+                and self.num_bits in WNA16_SUPPORTED_BITS):
             raise ValueError("For Fused MoE layers, only ",
                              f"{CompressionFormat.pack_quantized.value} ",
-                             "is supported for 4 bits")
+                             "is supported for the following bits: ",
+                             f"{WNA16_SUPPORTED_BITS}")

     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size: int,
@@ -292,4 +295,5 @@ def apply(
             topk_ids,
             w1_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
+            num_bits=self.num_bits,
         )
