@@ -72,9 +72,26 @@ class OperatorConfig(NamedTuple):
     valid_func_list: List[Callable] = []


+class TorchBaseConfig(BaseConfig):
+    # Override _get_op_name_op_type_config to fall back to string op_types,
+    # because the IPEX backend uses fused op_types such as `Linear&Relu`, `Linear&add`, ...
+    def _get_op_name_op_type_config(self):
+        op_type_config_dict = dict()
+        op_name_config_dict = dict()
+        for name, config in self.local_config.items():
+            if self._is_op_type(name):
+                # Convert the Callable to a string.
+                new_name = self._op_type_to_str(name)
+                op_type_config_dict[new_name] = config
+            else:
+                op_name_config_dict[name] = config
+                op_type_config_dict[name] = config
+        return op_type_config_dict, op_name_config_dict
+
+
 ######################## RTN Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=RTN, priority=PRIORITY_RTN)
-class RTNConfig(BaseConfig):
+class RTNConfig(TorchBaseConfig):
     """Config class for round-to-nearest weight-only quantization."""

     name = RTN
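
A quick sketch of how the new fallback plays out, not part of the diff; it assumes `set_local` records entries in `local_config` the way the method above iterates over it, and the values shown are hypothetical:

```python
# Sketch only: illustrates TorchBaseConfig's string fallback.
import torch

config = RTNConfig(bits=8)
config.set_local(torch.nn.Linear, RTNConfig(bits=4))  # Callable op_type
config.set_local("Linear&add", RTNConfig(bits=4))     # IPEX fused op_type, a plain string

op_type_cfg, op_name_cfg = config._get_op_name_op_type_config()
# torch.nn.Linear passes _is_op_type and is converted via _op_type_to_str;
# "Linear&add" fails the Callable check, so it is kept in BOTH dicts,
# letting the IPEX backend still match it as an op_type.
```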
@@ -242,7 +259,7 @@ def get_default_double_quant_config(type="BNB_NF4"):

 ######################## GPTQ Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=GPTQ, priority=PRIORITY_GPTQ)
-class GPTQConfig(BaseConfig):
+class GPTQConfig(TorchBaseConfig):
     """Config class for GPTQ.

     GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers.
@@ -397,7 +414,7 @@ def get_default_gptq_config(processor_type: Optional[Union[str, torch_utils.Proc

 ######################## AWQ Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=AWQ, priority=PRIORITY_AWQ)
-class AWQConfig(BaseConfig):
+class AWQConfig(TorchBaseConfig):
     """Config class for AWQ.

     AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration.
@@ -539,7 +556,7 @@ def get_default_awq_config() -> AWQConfig:

 ######################## TEQ Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=TEQ, priority=PRIORITY_TEQ)
-class TEQConfig(BaseConfig):
+class TEQConfig(TorchBaseConfig):
     """Config class for TEQ.

     TEQ: Activation-aware Weight Quantization for LLM Compression and Acceleration.
@@ -677,7 +694,7 @@ def get_default_teq_config() -> TEQConfig:

 ######################## AUTOROUND Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=AUTOROUND, priority=PRIORITY_AUTOROUND)
-class AutoRoundConfig(BaseConfig):
+class AutoRoundConfig(TorchBaseConfig):
     """Config class for AUTOROUND.

     AUTOROUND: Optimize Weight Rounding via Signed Gradient Descent for the Quantization of LLMs.
@@ -827,7 +844,7 @@ def get_default_AutoRound_config(processor_type: Optional[Union[str, torch_utils

 ######################## MX Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=MX_QUANT)
-class MXQuantConfig(BaseConfig):
+class MXQuantConfig(TorchBaseConfig):
     """Config class for MX quantization."""

     supported_configs: List[OperatorConfig] = []
@@ -940,7 +957,7 @@ def get_default_mx_config() -> MXQuantConfig:

 ######################## Dynamic Quant Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=PT2E_DYNAMIC_QUANT)
-class DynamicQuantConfig(BaseConfig):
+class DynamicQuantConfig(TorchBaseConfig):
     """Config class for dynamic quantization."""

     name = PT2E_DYNAMIC_QUANT
@@ -1014,7 +1031,7 @@ def get_default_dynamic_config() -> DynamicQuantConfig:

 ######################## Static Quant Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=STATIC_QUANT)
-class StaticQuantConfig(BaseConfig):
+class StaticQuantConfig(TorchBaseConfig):
     """Config class for static quantization."""

     name = STATIC_QUANT
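
Since the fused op_types named in `TorchBaseConfig` (`Linear&Relu`, `Linear&add`) come from IPEX static quantization, a hedged usage sketch follows; the `w_dtype` argument is an assumption about this config's parameters:

```python
# Sketch: with StaticQuantConfig now deriving from TorchBaseConfig,
# a fused IPEX op_type can be targeted by its string name.
quant_config = StaticQuantConfig()
quant_config.set_local("Linear&Relu", StaticQuantConfig(w_dtype="int8"))
```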
@@ -1103,7 +1120,7 @@ def get_default_static_config() -> StaticQuantConfig:

 ######################## Smooth Quant Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=SMOOTH_QUANT)
-class SmoothQuantConfig(BaseConfig):
+class SmoothQuantConfig(TorchBaseConfig):
     """Config class for smooth quantization."""

     name = SMOOTH_QUANT
@@ -1217,7 +1234,7 @@ def get_default_sq_config() -> SmoothQuantConfig:

 ######################## HQQ Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=HQQ, priority=PRIORITY_HQQ)
-class HQQConfig(BaseConfig):
+class HQQConfig(TorchBaseConfig):
     # Half-Quadratic Quantization (HQQ), more details:
     # Blog: https://mobiusml.github.io/hqq_blog/
     # Code: https://github.com/mobiusml/hqq
@@ -1298,7 +1315,7 @@ def get_default_hqq_config() -> HQQConfig:

 ######################## FP8 Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=FP8_QUANT)
-class FP8Config(BaseConfig):
+class FP8Config(TorchBaseConfig):
     """Config class for FP8 quantization."""

     name = FP8_QUANT
@@ -1393,7 +1410,7 @@ def get_default_fp8_config_set() -> FP8Config:

 ######################## MixPrecision Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=MIX_PRECISION)
-class MixPrecisionConfig(BaseConfig):
+class MixPrecisionConfig(TorchBaseConfig):
     """Config class for mix-precision."""

     name = MIX_PRECISION