
Commit db6164a

Add fallback support for HQQ (#1848)
Signed-off-by: yiliu30 <[email protected]>
1 parent 12b8f41 commit db6164a

File tree: 3 files changed, +24 −0 lines changed


neural_compressor/torch/algorithms/weight_only/hqq/quantizer.py

Lines changed: 3 additions & 0 deletions
@@ -149,5 +149,8 @@ def _parse_hqq_configs_mapping(self, configs_mapping):
             if quant_config.skip_lm_head and "lm_head" in op_name:
                 logger.warning("Skip quantizing %s due to `skip_lm_head` is True.", op_name)
                 continue
+            if quant_config is not None and quant_config.dtype == "fp32":
+                logger.warning("Fallback %s.", op_name)
+                continue
             qconfig_mapping[op_name] = self._convert_hqq_module_config(quant_config)
         return qconfig_mapping
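In effect, any op whose resolved config carries dtype == "fp32" is dropped from the HQQ qconfig mapping, so its module keeps the original float weights. A minimal standalone sketch of that filtering behavior (FakeConfig and build_qconfig_mapping are illustrative stand-ins, not the library's actual classes):

    from dataclasses import dataclass

    @dataclass
    class FakeConfig:
        """Illustrative stand-in for HQQConfig (only the fields used here)."""
        dtype: str = "int"
        bits: int = 4

    def build_qconfig_mapping(configs_mapping):
        """Keep only ops to be HQQ-quantized; dtype='fp32' marks a fallback op."""
        qconfig_mapping = {}
        for op_name, cfg in configs_mapping.items():
            if cfg is not None and cfg.dtype == "fp32":
                print(f"Fallback {op_name}.")  # the real code logs a warning instead
                continue
            qconfig_mapping[op_name] = cfg
        return qconfig_mapping

    # "fc1" is excluded from quantization, "fc2" is kept.
    mapping = {"fc1": FakeConfig(dtype="fp32"), "fc2": FakeConfig()}
    assert list(build_qconfig_mapping(mapping)) == ["fc2"]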

neural_compressor/torch/quantization/config.py

Lines changed: 2 additions & 0 deletions
@@ -1179,6 +1179,7 @@ class HQQConfig(BaseConfig):
 
     def __init__(
         self,
+        dtype: str = "int",
         bits: int = 4,
         group_size: int = 64,
         quant_zero: bool = True,
@@ -1188,6 +1189,7 @@ def __init__(
         white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
     ):
         super().__init__(white_list=white_list)
+        self.dtype = dtype
         self.bits = bits
         self.group_size = group_size
         self.quant_zero = quant_zero
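A minimal sketch of the new constructor argument, assuming only that HQQConfig is importable from neural_compressor.torch.quantization as in the test below:

    from neural_compressor.torch.quantization import HQQConfig

    default_cfg = HQQConfig()               # dtype defaults to "int", i.e. quantize with HQQ
    fallback_cfg = HQQConfig(dtype="fp32")  # dtype="fp32" requests falling back to FP32
    print(default_cfg.dtype, fallback_cfg.dtype)  # -> int fp32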

test/3x/torch/quantization/weight_only/hqq/test_hqq_cpu.py

Lines changed: 19 additions & 0 deletions
@@ -87,6 +87,25 @@ def test_hqq_quant(self, force_use_cpu, force_not_half):
             q_label_1.eq(q_label_2)
         ), "The results of calling `convert` + `prepare` and calling `quantize` should be equal."
 
+    def test_hqq_fallback(self, force_use_cpu, force_not_half):
+        from neural_compressor.torch.quantization import HQQConfig, convert, prepare
+
+        class ToyModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc1 = torch.nn.Linear(128, 1024)
+                self.fc2 = torch.nn.Linear(1024, 512)
+
+            def forward(self, x):
+                x = self.fc1(x)
+                x = self.fc2(x)
+                return x
+
+        quant_config = HQQConfig().set_local("fc1", HQQConfig(dtype="fp32"))
+        qmodel = convert(prepare(model=ToyModel(), quant_config=quant_config))
+        assert type(qmodel.fc1).__name__ == torch.nn.Linear.__name__, f"Expect fallback fc1, but get {type(qmodel.fc1)}"
+        assert type(qmodel.fc2).__name__ != torch.nn.Linear.__name__, f"Expect quantize fc2, but get {type(qmodel.fc2)}"
+
     @pytest.mark.parametrize(
         "nbits, group_size, quant_zero, quant_scale, scale_quant_group_size",
         [
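The test's assertions rely on a fallen-back layer keeping its original torch.nn.Linear class, while a quantized layer is replaced by HQQ's module. A small helper in the same spirit (the name list_fallback_linears is illustrative, not part of the library) could inspect any converted model:

    import torch

    def list_fallback_linears(model: torch.nn.Module):
        """Return names of Linear layers left unquantized (i.e., fallen back)."""
        return [
            name
            for name, module in model.named_modules()
            if type(module) is torch.nn.Linear
        ]

    # e.g. for the ToyModel above: list_fallback_linears(qmodel) -> ["fc1"]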
