Commit b294e28

HDCharles and Isotr0py authored
[refactor] CTMoEMethods to use QuantizationArgs (vllm-project#28871)
Signed-off-by: HDCharles <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
1 parent 787b84a commit b294e28

File tree: 2 files changed, +86 −75 lines

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 4 additions & 2 deletions

@@ -767,8 +767,10 @@ def get_scheme_dict(
                 targets=self.target_scheme_map.keys(),
                 fused_mapping=self.packed_modules_mapping,
             )
-
-            return self.target_scheme_map[matched_target]
+            scheme_dict = self.target_scheme_map[matched_target]
+            if scheme_dict.get("format") is None:
+                scheme_dict["format"] = self.quant_format
+            return scheme_dict

         return None
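Net effect of this change: the scheme dict returned for a matched target now always carries a "format" entry, defaulting to the config-level quant_format when the per-target scheme leaves it unset. A standalone sketch of that fallback pattern (the dict contents and format string below are illustrative placeholders, not vLLM's real config objects):

def resolve_scheme(target_scheme_map: dict, matched_target: str, quant_format: str) -> dict:
    # Same fallback as the diff: prefer the per-target "format", else the global format.
    scheme_dict = target_scheme_map[matched_target]
    if scheme_dict.get("format") is None:
        scheme_dict["format"] = quant_format
    return scheme_dict

# Illustrative values only.
schemes = {"Linear": {"weights": "int4-args", "input_activations": None}}
assert resolve_scheme(schemes, "Linear", "pack-quantized")["format"] == "pack-quantized"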

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 82 additions & 73 deletions

@@ -7,7 +7,11 @@
 
 import torch
 from compressed_tensors import CompressionFormat
-from compressed_tensors.quantization import ActivationOrdering, QuantizationStrategy
+from compressed_tensors.quantization import (
+    ActivationOrdering,
+    QuantizationArgs,
+    QuantizationStrategy,
+)
 from torch.nn.parameter import Parameter
 
 import vllm.envs as envs
@@ -142,10 +146,26 @@ def get_moe_method(
         # are supported + check if the layer is being ignored.
         weight_quant = scheme_dict.get("weights")
         input_quant = scheme_dict.get("input_activations")
+        format = scheme_dict.get("format")
 
         if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
             # group_size=None means channelwise
             group_size = weight_quant.group_size or -1
+
+            valid_format_and_bits = (
+                weight_quant.num_bits in WNA16_SUPPORTED_BITS
+                and format == CompressionFormat.pack_quantized.value
+            )
+
+            if not valid_format_and_bits:
+                raise ValueError(
+                    "For Fused MoE layers, only format: ",
+                    f"{CompressionFormat.pack_quantized.value} ",
+                    f" and bits: {WNA16_SUPPORTED_BITS} is supported ",
+                    f"but got format: {CompressionFormat.pack_quantized.value} "
+                    f" and bits: {weight_quant.num_bits}",
+                )
+
             # Prefer to use the MarlinMoE kernel when it is supported.
             if (
                 not check_moe_marlin_supports_layer(layer, group_size)
@@ -161,12 +181,12 @@ def get_moe_method(
                 )
                 logger.info_once("Using CompressedTensorsWNA16MoEMethod")
                 return CompressedTensorsWNA16MoEMethod(
-                    quant_config, layer.moe_config, layer_name
+                    weight_quant, input_quant, layer.moe_config
                 )
             else:
                 logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
                 return CompressedTensorsWNA16MarlinMoEMethod(
-                    quant_config, layer.moe_config, layer_name
+                    weight_quant, input_quant, layer.moe_config
                 )
         elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
             return CompressedTensorsW4A4Nvfp4MoEMethod(layer.moe_config, layer_name)
@@ -176,15 +196,15 @@ def get_moe_method(
             or quant_config._is_fp8_w8a8(weight_quant, input_quant)
         ):
             return CompressedTensorsW8A8Fp8MoEMethod(
-                quant_config, layer.moe_config, layer_name
+                weight_quant, input_quant, layer.moe_config
             )
         elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
             return CompressedTensorsW8A8Int8MoEMethod(
-                quant_config, layer.moe_config, layer_name
+                weight_quant, input_quant, layer.moe_config
             )
         elif quant_config._is_dynamic_token_w4a8_int(weight_quant, input_quant):
             return CompressedTensorsW4A8Int8MoEMethod(
-                quant_config, layer.moe_config, layer_name
+                weight_quant, input_quant, layer.moe_config
             )
         else:
             raise RuntimeError(
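The dispatch above now hands each MoE method the already-matched QuantizationArgs instead of the whole CompressedTensorsConfig, so the methods no longer re-read quant_config.target_scheme_map["Linear"]. A minimal, self-contained sketch of that constructor shape (the classes below are hypothetical stand-ins, not vLLM's or compressed-tensors' real types; the moe argument is omitted for brevity):

from dataclasses import dataclass

@dataclass
class FakeQuantArgs:
    # Stand-in mirroring the QuantizationArgs fields the MoE methods read.
    num_bits: int = 8
    strategy: str = "tensor"
    symmetric: bool = True
    group_size: int | None = None
    dynamic: bool = False

class FakeMoEMethod:
    # Mirrors the new (weight_quant, input_quant, ...) argument order from the diff.
    def __init__(self, weight_quant: FakeQuantArgs, input_quant: FakeQuantArgs | None):
        self.weight_quant = weight_quant
        self.input_quant = input_quant

# The caller passes the scheme entries it already looked up.
scheme_dict = {"weights": FakeQuantArgs(num_bits=4, strategy="group", group_size=128),
               "input_activations": None}
method = FakeMoEMethod(scheme_dict["weights"], scheme_dict["input_activations"])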
@@ -650,17 +670,19 @@ def apply(
 class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self,
-        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
+        weight_quant: QuantizationArgs,
+        input_quant: QuantizationArgs,
         moe: FusedMoEConfig,
         layer_name: str | None = None,
     ):
-        super().__init__(moe)
-        self.quant_config = quant_config
-        self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
-        self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
-            "input_activations"
+        from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
+            CompressedTensorsConfig,
         )
 
+        super().__init__(moe)
+        self.weight_quant = weight_quant
+        self.input_quant = input_quant
+
         per_tensor = (
             self.weight_quant.strategy == QuantizationStrategy.TENSOR
             and self.input_quant.strategy == QuantizationStrategy.TENSOR
@@ -698,11 +720,13 @@ def __init__(
         self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
 
         # cutlass path
-        self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100(
+        self.is_fp8_w8a8_sm100 = CompressedTensorsConfig._is_fp8_w8a8_sm100(
             self.weight_quant, self.input_quant
         )
         self.use_cutlass = not self.block_quant and (
-            quant_config._is_fp8_w8a8_sm90(self.weight_quant, self.input_quant)
+            CompressedTensorsConfig._is_fp8_w8a8_sm90(
+                self.weight_quant, self.input_quant
+            )
             or self.is_fp8_w8a8_sm100
         )
         self.disable_expert_map = False
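Two details of the refactored Fp8 constructor: the SM90/SM100 helpers are now called on the CompressedTensorsConfig class with the stored QuantizationArgs rather than on a retained quant_config instance, and CompressedTensorsConfig is imported inside __init__ to avoid a circular import. A simplified sketch of that pattern (assumes vLLM is installed; the real use_cutlass check also requires not self.block_quant):

def fp8_cutlass_supported(weight_quant, input_quant) -> bool:
    # Function-local import so this module can be imported before compressed_tensors.py.
    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
        CompressedTensorsConfig,
    )
    return CompressedTensorsConfig._is_fp8_w8a8_sm90(
        weight_quant, input_quant
    ) or CompressedTensorsConfig._is_fp8_w8a8_sm100(weight_quant, input_quant)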
@@ -1261,16 +1285,14 @@ def supports_eplb(self) -> bool:
 class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self,
-        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
+        weight_quant: QuantizationArgs,
+        input_quant: QuantizationArgs,
         moe: FusedMoEConfig,
         layer_name: str | None = None,
     ):
         super().__init__(moe)
-        self.quant_config = quant_config
-        self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
-        self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
-            "input_activations"
-        )
+        self.weight_quant = weight_quant
+        self.input_quant = input_quant
 
         per_channel = (
             self.weight_quant.strategy == QuantizationStrategy.CHANNEL
@@ -1414,36 +1436,27 @@ def apply(
 class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self,
-        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
+        weight_quant: QuantizationArgs,
+        input_quant: QuantizationArgs | None,
         moe: FusedMoEConfig,
         layer_name: str | None = None,
     ):
         super().__init__(moe)
-        self.quant_config = quant_config
-        # TODO: @dsikka: refactor this to use schemes as other kernels
-        # are supported + check if the layer is being ignored.
-        config = self.quant_config.target_scheme_map["Linear"].get("weights")
-        self.num_bits = config.num_bits
-        self.packed_factor = 32 // config.num_bits
-        self.strategy = config.strategy
-        self.group_size = config.group_size
-        self.actorder = config.actorder
-        self.layer_name = layer_name
-        self.marlin_input_dtype = get_marlin_input_dtype(layer_name)
-        assert config.symmetric, "Only symmetric quantization is supported for MoE"
+        self.weight_quant = weight_quant
+        self.input_quant = input_quant
+        assert weight_quant.symmetric, (
+            "Only symmetric quantization is supported for MoE"
+        )
+        # Extract properties from weight_quant
+        self.num_bits = weight_quant.num_bits
+        self.packed_factor = 32 // weight_quant.num_bits
+        self.strategy = weight_quant.strategy
+        self.group_size = weight_quant.group_size
+        self.actorder = weight_quant.actorder
 
-        if not (
-            self.quant_config.quant_format == CompressionFormat.pack_quantized.value
-            and self.num_bits in WNA16_SUPPORTED_BITS
-        ):
-            raise ValueError(
-                "For Fused MoE layers, only ",
-                f"{CompressionFormat.pack_quantized.value} ",
-                "is supported for the following bits: ",
-                f"{WNA16_SUPPORTED_BITS}",
-            )
         self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits]
         self.use_marlin = True
+        self.marlin_input_dtype = get_marlin_input_dtype(layer_name)
 
     def create_weights(
         self,
@@ -1812,35 +1825,26 @@ def apply(
 class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self,
-        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
+        weight_quant: QuantizationArgs,
+        input_quant: QuantizationArgs | None,
         moe: FusedMoEConfig,
         layer_name: str | None = None,
     ):
         super().__init__(moe)
-        self.quant_config = quant_config
-        # TODO: @dsikka: refactor this to use schemes as other kernels
-        # are supported + check if the layer is being ignored.
-        config = self.quant_config.target_scheme_map["Linear"].get("weights")
-        self.num_bits = config.num_bits
-        self.packed_factor = 32 // config.num_bits
-        self.strategy = config.strategy
+        self.weight_quant = weight_quant
+        self.input_quant = input_quant
+        # Extract properties from weight_quant
+        self.num_bits = weight_quant.num_bits
+        self.packed_factor = 32 // weight_quant.num_bits
+        self.strategy = weight_quant.strategy
         # channelwise is not supported by this kernel
-        assert config.strategy == "group"
-        self.group_size = config.group_size
+        assert weight_quant.strategy == "group"
+        self.group_size = weight_quant.group_size
         # grouped actorder isn't supported by this kernel
-        assert config.actorder != "group"
-        assert config.symmetric, "Only symmetric quantization is supported for MoE"
-
-        if not (
-            self.quant_config.quant_format == CompressionFormat.pack_quantized.value
-            and self.num_bits in WNA16_SUPPORTED_BITS
-        ):
-            raise ValueError(
-                "For Fused MoE layers, only ",
-                f"{CompressionFormat.pack_quantized.value} ",
-                "is supported for the following bits: ",
-                f"{WNA16_SUPPORTED_BITS}",
-            )
+        assert weight_quant.actorder != "group"
+        assert weight_quant.symmetric, (
+            "Only symmetric quantization is supported for MoE"
+        )
 
     def create_weights(
         self,
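With the format/bits guard removed from this constructor and from the Marlin variant above, that validation now happens once in get_moe_method, keyed on the scheme-level "format" instead of quant_config.quant_format. A standalone sketch of the consolidated check (the (4, 8) default and the "pack-quantized" literal are assumptions standing in for WNA16_SUPPORTED_BITS and CompressionFormat.pack_quantized.value):

def check_wna16_scheme(num_bits: int, fmt: str, supported_bits=(4, 8)) -> None:
    # One guard shared by both WNA16 MoE methods instead of a copy in each __init__.
    if not (num_bits in supported_bits and fmt == "pack-quantized"):
        raise ValueError(
            f"For Fused MoE layers, only format 'pack-quantized' and bits "
            f"{supported_bits} are supported, got format {fmt!r} and bits {num_bits}."
        )

check_wna16_scheme(4, "pack-quantized")  # passes silently for a supported scheme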
@@ -2065,28 +2069,33 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
 
     def __init__(
         self,
-        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
+        weight_quant: QuantizationArgs,
+        input_quant: QuantizationArgs,
         moe: FusedMoEConfig,
         layer_name: str | None = None,
     ):
         super().__init__(moe)
         self.has_bias = self.moe.has_bias
-        self.quant_config = quant_config
+        self.weight_quant = weight_quant
+        self.input_quant = input_quant
 
         # Validate scheme: weights=W4 (channel or group),
         # activations=dynamic TOKEN (A8)
-        wq = self.quant_config.target_scheme_map["Linear"].get("weights")
-        aq = self.quant_config.target_scheme_map["Linear"].get("input_activations")
 
         # Must be dynamic per-token activations
-        if aq.strategy != QuantizationStrategy.TOKEN or not aq.dynamic:
+        if (
+            input_quant.strategy != QuantizationStrategy.TOKEN
+            or not input_quant.dynamic
+        ):
             raise ValueError(
                 "W4A8-int MoE needs dynamic per-token activation quantization."
             )
 
         # Weight can be channel-wise (group_size=None) or group-wise
-        self.group_size = wq.group_size if (wq.group_size is not None) else -1
-        if wq.num_bits != 4:
+        self.group_size = (
+            weight_quant.group_size if (weight_quant.group_size is not None) else -1
+        )
+        if weight_quant.num_bits != 4:
             raise ValueError("This method only supports 4-bit weights (num_bits=4).")
 
         # CPU only
