 
 import torch
 from compressed_tensors import CompressionFormat
-from compressed_tensors.quantization import ActivationOrdering, QuantizationStrategy
+from compressed_tensors.quantization import (
+    ActivationOrdering,
+    QuantizationArgs,
+    QuantizationStrategy,
+)
 from torch.nn.parameter import Parameter
 
 import vllm.envs as envs
@@ -142,10 +146,26 @@ def get_moe_method(
         # are supported + check if the layer is being ignored.
         weight_quant = scheme_dict.get("weights")
         input_quant = scheme_dict.get("input_activations")
+        format = scheme_dict.get("format")
 
         if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
             # group_size=None means channelwise
             group_size = weight_quant.group_size or -1
+
+            valid_format_and_bits = (
+                weight_quant.num_bits in WNA16_SUPPORTED_BITS
+                and format == CompressionFormat.pack_quantized.value
+            )
+
+            if not valid_format_and_bits:
+                raise ValueError(
+                    "For Fused MoE layers, only format: "
+                    f"{CompressionFormat.pack_quantized.value} "
+                    f"and bits: {WNA16_SUPPORTED_BITS} are supported, "
+                    f"but got format: {format} "
+                    f"and bits: {weight_quant.num_bits}"
+                )
+
             # Prefer to use the MarlinMoE kernel when it is supported.
             if (
                 not check_moe_marlin_supports_layer(layer, group_size)
@@ -161,12 +181,12 @@ def get_moe_method(
                 )
                 logger.info_once("Using CompressedTensorsWNA16MoEMethod")
                 return CompressedTensorsWNA16MoEMethod(
-                    quant_config, layer.moe_config, layer_name
+                    weight_quant, input_quant, layer.moe_config
                 )
             else:
                 logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
                 return CompressedTensorsWNA16MarlinMoEMethod(
-                    quant_config, layer.moe_config, layer_name
+                    weight_quant, input_quant, layer.moe_config
                 )
         elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
             return CompressedTensorsW4A4Nvfp4MoEMethod(layer.moe_config, layer_name)
@@ -176,15 +196,15 @@ def get_moe_method(
176196 or quant_config ._is_fp8_w8a8 (weight_quant , input_quant )
177197 ):
178198 return CompressedTensorsW8A8Fp8MoEMethod (
179- quant_config , layer .moe_config , layer_name
199+ weight_quant , input_quant , layer .moe_config
180200 )
181201 elif quant_config ._is_dynamic_token_w8a8 (weight_quant , input_quant ):
182202 return CompressedTensorsW8A8Int8MoEMethod (
183- quant_config , layer .moe_config , layer_name
203+ weight_quant , input_quant , layer .moe_config
184204 )
185205 elif quant_config ._is_dynamic_token_w4a8_int (weight_quant , input_quant ):
186206 return CompressedTensorsW4A8Int8MoEMethod (
187- quant_config , layer .moe_config , layer_name
207+ weight_quant , input_quant , layer .moe_config
188208 )
189209 else :
190210 raise RuntimeError (
@@ -650,17 +670,19 @@ def apply(
 class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self,
-        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
+        weight_quant: QuantizationArgs,
+        input_quant: QuantizationArgs,
         moe: FusedMoEConfig,
         layer_name: str | None = None,
     ):
-        super().__init__(moe)
-        self.quant_config = quant_config
-        self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
-        self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
-            "input_activations"
+        from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
+            CompressedTensorsConfig,
         )
 
+        super().__init__(moe)
+        self.weight_quant = weight_quant
+        self.input_quant = input_quant
+
         per_tensor = (
             self.weight_quant.strategy == QuantizationStrategy.TENSOR
             and self.input_quant.strategy == QuantizationStrategy.TENSOR
@@ -698,11 +720,13 @@ def __init__(
         self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
 
         # cutlass path
-        self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100(
+        self.is_fp8_w8a8_sm100 = CompressedTensorsConfig._is_fp8_w8a8_sm100(
             self.weight_quant, self.input_quant
         )
         self.use_cutlass = not self.block_quant and (
-            quant_config._is_fp8_w8a8_sm90(self.weight_quant, self.input_quant)
+            CompressedTensorsConfig._is_fp8_w8a8_sm90(
+                self.weight_quant, self.input_quant
+            )
             or self.is_fp8_w8a8_sm100
         )
         self.disable_expert_map = False
@@ -1261,16 +1285,14 @@ def supports_eplb(self) -> bool:
 class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self,
-        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
+        weight_quant: QuantizationArgs,
+        input_quant: QuantizationArgs,
         moe: FusedMoEConfig,
         layer_name: str | None = None,
     ):
         super().__init__(moe)
-        self.quant_config = quant_config
-        self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
-        self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
-            "input_activations"
-        )
+        self.weight_quant = weight_quant
+        self.input_quant = input_quant
 
         per_channel = (
             self.weight_quant.strategy == QuantizationStrategy.CHANNEL
@@ -1414,36 +1436,27 @@ def apply(
 class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self,
-        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
+        weight_quant: QuantizationArgs,
+        input_quant: QuantizationArgs | None,
         moe: FusedMoEConfig,
         layer_name: str | None = None,
     ):
         super().__init__(moe)
-        self.quant_config = quant_config
-        # TODO: @dsikka: refactor this to use schemes as other kernels
-        # are supported + check if the layer is being ignored.
-        config = self.quant_config.target_scheme_map["Linear"].get("weights")
-        self.num_bits = config.num_bits
-        self.packed_factor = 32 // config.num_bits
-        self.strategy = config.strategy
-        self.group_size = config.group_size
-        self.actorder = config.actorder
-        self.layer_name = layer_name
-        self.marlin_input_dtype = get_marlin_input_dtype(layer_name)
-        assert config.symmetric, "Only symmetric quantization is supported for MoE"
+        self.weight_quant = weight_quant
+        self.input_quant = input_quant
+        assert weight_quant.symmetric, (
+            "Only symmetric quantization is supported for MoE"
+        )
+        # Extract properties from weight_quant
+        self.num_bits = weight_quant.num_bits
+        self.packed_factor = 32 // weight_quant.num_bits
+        self.strategy = weight_quant.strategy
+        self.group_size = weight_quant.group_size
+        self.actorder = weight_quant.actorder
 
-        if not (
-            self.quant_config.quant_format == CompressionFormat.pack_quantized.value
-            and self.num_bits in WNA16_SUPPORTED_BITS
-        ):
-            raise ValueError(
-                "For Fused MoE layers, only ",
-                f"{CompressionFormat.pack_quantized.value}",
-                "is supported for the following bits: ",
-                f"{WNA16_SUPPORTED_BITS}",
-            )
         self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits]
         self.use_marlin = True
+        self.marlin_input_dtype = get_marlin_input_dtype(layer_name)
 
     def create_weights(
         self,
@@ -1812,35 +1825,26 @@ def apply(
 class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self,
-        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
+        weight_quant: QuantizationArgs,
+        input_quant: QuantizationArgs | None,
        moe: FusedMoEConfig,
         layer_name: str | None = None,
     ):
         super().__init__(moe)
-        self.quant_config = quant_config
-        # TODO: @dsikka: refactor this to use schemes as other kernels
-        # are supported + check if the layer is being ignored.
-        config = self.quant_config.target_scheme_map["Linear"].get("weights")
-        self.num_bits = config.num_bits
-        self.packed_factor = 32 // config.num_bits
-        self.strategy = config.strategy
+        self.weight_quant = weight_quant
+        self.input_quant = input_quant
+        # Extract properties from weight_quant
+        self.num_bits = weight_quant.num_bits
+        self.packed_factor = 32 // weight_quant.num_bits
+        self.strategy = weight_quant.strategy
         # channelwise is not supported by this kernel
-        assert config.strategy == "group"
-        self.group_size = config.group_size
+        assert weight_quant.strategy == "group"
+        self.group_size = weight_quant.group_size
         # grouped actorder isn't supported by this kernel
-        assert config.actorder != "group"
-        assert config.symmetric, "Only symmetric quantization is supported for MoE"
-
-        if not (
-            self.quant_config.quant_format == CompressionFormat.pack_quantized.value
-            and self.num_bits in WNA16_SUPPORTED_BITS
-        ):
-            raise ValueError(
-                "For Fused MoE layers, only ",
-                f"{CompressionFormat.pack_quantized.value}",
-                "is supported for the following bits: ",
-                f"{WNA16_SUPPORTED_BITS}",
-            )
+        assert weight_quant.actorder != "group"
+        assert weight_quant.symmetric, (
+            "Only symmetric quantization is supported for MoE"
+        )
 
     def create_weights(
         self,
@@ -2065,28 +2069,33 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
 
     def __init__(
         self,
-        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
+        weight_quant: QuantizationArgs,
+        input_quant: QuantizationArgs,
         moe: FusedMoEConfig,
         layer_name: str | None = None,
     ):
         super().__init__(moe)
         self.has_bias = self.moe.has_bias
-        self.quant_config = quant_config
+        self.weight_quant = weight_quant
+        self.input_quant = input_quant
 
         # Validate scheme: weights=W4 (channel or group),
         # activations=dynamic TOKEN (A8)
-        wq = self.quant_config.target_scheme_map["Linear"].get("weights")
-        aq = self.quant_config.target_scheme_map["Linear"].get("input_activations")
 
         # Must be dynamic per-token activations
-        if aq.strategy != QuantizationStrategy.TOKEN or not aq.dynamic:
+        if (
+            input_quant.strategy != QuantizationStrategy.TOKEN
+            or not input_quant.dynamic
+        ):
             raise ValueError(
                 "W4A8-int MoE needs dynamic per-token activation quantization."
             )
 
         # Weight can be channel-wise (group_size=None) or group-wise
-        self.group_size = wq.group_size if (wq.group_size is not None) else -1
-        if wq.num_bits != 4:
+        self.group_size = (
+            weight_quant.group_size if (weight_quant.group_size is not None) else -1
+        )
+        if weight_quant.num_bits != 4:
             raise ValueError("This method only supports 4-bit weights (num_bits=4).")
 
         # CPU only
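A minimal usage sketch of the calling convention this refactor establishes, not part of the diff itself; `scheme_dict` and `layer` are assumed to be the resolved per-layer scheme mapping and the MoE layer handled in `get_moe_method` above:

    # Hypothetical call site, mirroring the dispatch logic in get_moe_method:
    weight_quant = scheme_dict.get("weights")            # QuantizationArgs for weights
    input_quant = scheme_dict.get("input_activations")   # QuantizationArgs (or None) for activations

    # MoE methods now receive the scheme args directly instead of the full
    # CompressedTensorsConfig, e.g. for the FP8 W8A8 path:
    moe_method = CompressedTensorsW8A8Fp8MoEMethod(
        weight_quant, input_quant, layer.moe_config
    )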