
Commit 1a45090

Migrate RTN, HQQ and AWQ to Torch new 3.x API (#1765)
Migrate RTN, HQQ and AWQ to Torch new 3.x API

---------

Signed-off-by: yuwenzho <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 84d7055 commit 1a45090

13 files changed: +778, -410 lines

neural_compressor/torch/algorithms/base_algorithm.py

Lines changed: 9 additions & 10 deletions

@@ -85,19 +85,18 @@ def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any):
         Returns:
             A quantized model.
         """
+        model = self.prepare(model, *args, **kwargs)
+
         run_fn = kwargs.get("run_fn", None)
-        run_args = kwargs.get("run_args", None)
-        assert run_fn is not None, (
-            "Can't find run_func. Please provide run_func to quantize API "
-            "or overwrite quantize member function in your Quantizer class."
-        )
+        if run_fn is not None:
+            run_args = kwargs.get("run_args", None)
+            if run_args:
+                run_fn(model, *run_args)
+            else:
+                run_fn(model)
 
-        model = self.prepare(model, *args, **kwargs)
-        if run_args:
-            run_fn(model, *run_args)
-        else:
-            run_fn(model)
         model = self.convert(model, *args, **kwargs)
+
         return model
 
     def execute(self, model: torch.nn.Module, mode, *args: Any, **kwargs: Any):  # pragma: no cover
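
With this change, the base `quantize()` entry point always calls `prepare()` first, treats `run_fn` as optional calibration instead of asserting on it, and finishes with `convert()`. A minimal sketch of how a subclass plugs into that flow after this commit (the `MyQuantizer` class, the toy model, and `calib_fn` below are illustrative, not part of the commit):

import torch
from neural_compressor.torch.algorithms import Quantizer


class MyQuantizer(Quantizer):
    """Illustrative subclass: only prepare() and convert() are implemented."""

    def prepare(self, model, *args, **kwargs):
        # attach observers / hooks here
        return model

    def convert(self, model, *args, **kwargs):
        # swap observed modules for quantized ones here
        return model


def calib_fn(model):
    # calibration pass; only invoked because run_fn is passed to quantize()
    model(torch.randn(2, 8))


quantizer = MyQuantizer({})
q_model = quantizer.quantize(torch.nn.Linear(8, 8), run_fn=calib_fn)
# run_fn is now optional: quantizer.quantize(model) alone is prepare() + convert()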

neural_compressor/torch/algorithms/static_quant/static_quant.py

Lines changed: 3 additions & 5 deletions

@@ -53,6 +53,7 @@ def __init__(self, quant_config: OrderedDict = {}):
             quant_config (OrderedDict, optional): quantization config for ops. Defaults to {}.
         """
         super().__init__(quant_config)
+        self.user_cfg = OrderedDict()
 
     def prepare(self, model, example_inputs, inplace=True, *args, **kwargs):
         """Prepares a given model for quantization.
@@ -71,7 +72,7 @@ def prepare(self, model, example_inputs, inplace=True, *args, **kwargs):
             model, example_inputs
         )
         # update json file in ipex_config_path; map ipex op_name to pt op_name
-        user_cfg = cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)
+        self.user_cfg = cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)
         model.eval()
 
         # Check save_qconf_summary part is a workaround for IPEX bug.
@@ -94,7 +95,6 @@ def prepare(self, model, example_inputs, inplace=True, *args, **kwargs):
         model = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=inplace)
 
         model.load_qconf_summary(qconf_summary=ipex_config_path)
-        setattr(model, "user_cfg", user_cfg)
         return model
 
     def convert(self, model, example_inputs, inplace=True, *args, **kwargs):
@@ -110,16 +110,14 @@ def convert(self, model, example_inputs, inplace=True, *args, **kwargs):
         """
         from neural_compressor.torch.algorithms.static_quant import save
 
-        user_cfg = getattr(model, "user_cfg", OrderedDict())
-
         model.save_qconf_summary(qconf_summary=ipex_config_path)
         model = _ipex_post_quant_process(model, example_inputs, inplace=inplace)
 
         with open(ipex_config_path, "r") as f:
             model.tune_cfg = json.load(f)
         model.ipex_config_path = ipex_config_path
 
-        dump_model_op_stats(user_cfg)
+        dump_model_op_stats(self.user_cfg)
 
         logger.info("Static quantization done.")
         model.ori_save = model.save
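
The point of this static-quant change is where calibration-derived state lives: the op config produced in `prepare()` is now kept on the quantizer instance (`self.user_cfg`) instead of being attached to the model with `setattr`, so `convert()` no longer has to fetch it back off the model. A minimal sketch of that pattern under the same prepare/convert contract (the class and config values are illustrative, not the commit's code):

from collections import OrderedDict


class StatefulQuantizer:
    """Illustrative: keep prepare-time state on the quantizer, not on the model."""

    def __init__(self):
        self.user_cfg = OrderedDict()  # survives from prepare() to convert()

    def prepare(self, model):
        # derived during calibration setup; previously stored via setattr(model, "user_cfg", ...)
        self.user_cfg["fc1"] = {"dtype": "int8"}
        return model

    def convert(self, model):
        # no getattr(model, "user_cfg", OrderedDict()) fallback needed any more
        print(f"converting with {len(self.user_cfg)} op configs")
        return model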

neural_compressor/torch/algorithms/weight_only/awq.py

Lines changed: 98 additions & 85 deletions

@@ -15,9 +15,11 @@
 # Copied from neural_compressor/adaptor/torch_utils/awq.py
 
 import copy
+from collections import OrderedDict
 
 import torch
 
+from neural_compressor.torch.algorithms import Quantizer
 from neural_compressor.torch.utils import get_device, logger
 
 from .modules import MulLinear
@@ -26,13 +28,13 @@
     get_absorb_layers,
     get_block_prefix,
     get_example_input,
-    get_hidden_states,
     get_module_input_output,
-    model_forward,
+    recover_forward,
+    replace_forward,
     set_module,
 )
 
-__all__ = ["awq_quantize"]
+__all__ = ["AWQQuantizer"]
 
 
 def _get_absorb_per_block(model, example_inputs, folding=False, weight_config={}):
@@ -113,15 +115,15 @@ def __init__(
         self,
         model,
         example_inputs=None,
-        calib_func=None,
         dataloader=None,
-        n_samples=128,
         data_type="int",
         bits=4,
         group_size=32,
         scheme="asym",
         use_full_range=False,
         weight_config={},
+        total_block_args=[],
+        total_block_kwargs=[],
     ):
 
         self.example_inputs = example_inputs
@@ -130,11 +132,9 @@ def __init__(
             assert dataloader is not None, "datalaoder or example_inputs is required."
             self.example_inputs = get_example_input(dataloader)
         self._move_model_and_data_to_device()
-        # Step 1: get hidden states and kwargs of first block.
-        self.total_block_args, self.total_block_kwargs = get_hidden_states(
-            model, dataloader=dataloader, n_samples=n_samples, calib_func=calib_func
-        )
-        # Step 2: get block list and block prefix, number
+        self.total_block_args = total_block_args
+        self.total_block_kwargs = total_block_kwargs
+        # get block list and block prefix, number
         self.block_prefix, self.block_num = get_block_prefix(model)
         self.block_list = fetch_module(model, self.block_prefix)
         self.data_type = data_type
@@ -429,14 +429,15 @@ def apply_quantize_with_clip(self, return_int=False):
         """
         # apply quantization and clip
         logger.info("Quantizing the AWQ optimized fp32 model")
-        from .rtn import rtn_quantize
+        from .rtn import RTNQuantizer
+
+        rtn_quantizer = RTNQuantizer(quant_config=self.weight_config)
 
-        self.model = rtn_quantize(
+        self.model = rtn_quantizer.quantize(
            self.model,
-            num_bits=self.bits,
+            bits=self.bits,
            group_size=self.group_size,
            scheme=self.scheme,
-            weight_config=self.weight_config,
            return_int=return_int,
            use_full_range=self.use_full_range,
        )
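
The hunk above carries the RTN half of the migration: the functional `rtn_quantize(...)` call becomes an `RTNQuantizer` object whose per-op options come in through `quant_config`, and the `num_bits` argument is renamed to `bits`. A hedged sketch of calling the new class directly, outside of AWQ (the toy model and the config keys/values are illustrative; the exact accepted config shape follows the RTN weight config, not this diff):

import torch
from neural_compressor.torch.algorithms.weight_only.rtn import RTNQuantizer

model = torch.nn.Sequential(torch.nn.Linear(16, 16))
# per-op weight config, the same kind of object AWQ now forwards as quant_config
weight_config = {"0": {"bits": 4, "group_size": 32, "scheme": "asym"}}

rtn = RTNQuantizer(quant_config=weight_config)
# RTN needs no calibration data, so no run_fn is required after this commit
q_model = rtn.quantize(model, bits=4, group_size=32, scheme="asym")
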
@@ -492,78 +493,90 @@ def module_inference(self, model, inputs):
         return total_out
 
 
-@torch.no_grad()
-def awq_quantize(
-    model,
-    bits=4,
-    group_size=32,
-    scheme="asym",
-    weight_config={},
-    example_inputs=None,
-    dataloader=None,
-    n_samples=128,
-    calib_func=None,
-    use_auto_scale=True,
-    use_mse_search=True,
-    folding=False,
-    return_int=False,
-    use_full_range=False,
-    data_type="int",
-):
-    """Quant the model with Activation-aware Weight quantization(AWQ) method.
+class AWQQuantizer(Quantizer):
+    def __init__(self, quant_config: OrderedDict = {}):
+        """Init an AWQQuantizer object.
 
-    Args:
-        model (torch.nn.Module): torch model.
-        example_inputs: example_inputs.
-        weight_config (dict, optional): contains all info required by AWQ. Defaults to {}.
-            For example,
-                weight_config={
-                    'fc2':
-                        {
-                            # 'absorb_layer': 'fc1',
-                            'bits': 4,
-                            'group_size': 32,
-                            'scheme': 'sym'
-                        }
-                }
-        absorb_dict (dict, optional): contains all absorb info required by AWQ.. Defaults to {}.
-            For example,
-                absorb_dict = {
-                    # 'absorb_layer': absorbed_layer
-                    'fc1': ['fc1', 'fc2', 'fc3']
-                }  # in this case, fc2 and fc3 need to share the same scale. fc1 is self absorbed.
-                # self absorb module will replace with MulLinear, which contains torch.mul and module.
-        n_samples: calibration sample number.
-        use_auto_scale (bool, optional): whether enable scale for salient weight. Defaults to True.
-        use_mse_search (bool, optional): whether enable clip for weight by checking mse. Defaults to True.
-        calib_func: a custom inference function to replace dataloader and iters.
-        n_blocks: split model into block number to avoid OOM.
-        return_int (bool, optional): Choose return fp32 or int32 model.
-            Defaults to False.
-        use_full_range (bool, optional): Choose sym range whether use -2**(bits-1).
+        Args:
+            quant_config (OrderedDict, optional): quantization config for ops. Defaults to {}.
+        """
+        super().__init__(quant_config)
 
-    Returns:
-        model: fake quantized model
-    """
+    @torch.no_grad()
+    def prepare(self, model, *args, **kwargs):
+        """Prepare a given model to get hidden states and kwargs of first block.
+
+        Args:
+            model: A float torch model.
 
-    assert isinstance(model, torch.nn.Module), "only support torch module"
-    awq = ActAwareWeightQuant(
+        Returns:
+            A prepared model.
+        """
+        assert isinstance(model, torch.nn.Module), "AWQ algorithm only supports torch module"
+        model = replace_forward(model)
+        return model
+
+    @torch.no_grad()
+    def convert(
+        self,
         model,
-        example_inputs=example_inputs,
-        calib_func=calib_func,
-        dataloader=dataloader,
-        n_samples=n_samples,
-        bits=bits,
-        group_size=group_size,
-        scheme=scheme,
-        use_full_range=use_full_range,
-        weight_config=weight_config,
-        data_type=data_type,
-    )
-    qdq_model = awq.quantize(
-        use_auto_scale=use_auto_scale,
-        use_mse_search=use_mse_search,
-        folding=folding,
-        return_int=return_int,
-    )
-    return qdq_model
+        bits=4,
+        group_size=32,
+        scheme="asym",
+        example_inputs=None,
+        dataloader=None,
+        use_auto_scale=True,
+        use_mse_search=True,
+        folding=False,
+        return_int=False,
+        use_full_range=False,
+        data_type="int",
+        *args,
+        **kwargs,
+    ):
+        """Converts a prepared model to a quantized model.
+
+        Args:
+            model: torch model.
+            bits: num bits. Defaults to 4.
+            group_size: how many elements share one scale/zp. Defaults to 32.
+            scheme: sym or asym. Defaults to "asym".
+            example_inputs: example_inputs. Defaults to None.
+            dataloader: datalaoder or example_inputs is required. Defaults to None.
+            use_auto_scale: whether enable scale for salient weight. Defaults to True.
+            use_mse_search: whether enable clip for weight by checking mse. Defaults to True.
+            folding: False will allow insert mul before linear when the scale cannot be absorbed
+                by last layer, else won't. Defaults to False.
+            return_int: Choose return fp32 or int32 model. Defaults to False.
+            use_full_range: Choose sym range whether use -2**(bits-1). Defaults to False.
+            data_type: data type. Defaults to "int".
+
+        Returns:
+            model: fake quantized model
+        """
+        model = recover_forward(model)
+        total_block_args = getattr(model, "total_block_args", [])
+        total_block_kwargs = getattr(model, "total_block_kwargs", [])
+        delattr(model, "total_block_args")
+        delattr(model, "total_block_kwargs")
+
+        awq = ActAwareWeightQuant(
+            model,
+            example_inputs=example_inputs,
+            dataloader=dataloader,
+            data_type=data_type,
+            bits=bits,
+            group_size=group_size,
+            scheme=scheme,
+            use_full_range=use_full_range,
+            weight_config=self.quant_config,
+            total_block_args=total_block_args,
+            total_block_kwargs=total_block_kwargs,
+        )
+        qdq_model = awq.quantize(
+            use_auto_scale=use_auto_scale,
+            use_mse_search=use_mse_search,
+            folding=folding,
+            return_int=return_int,
+        )
+        return qdq_model
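
Taken together, the AWQ migration moves calibration out of the algorithm and into the caller: `prepare()` installs a capturing forward via `replace_forward()`, the caller runs its own forward passes so the first block's inputs are recorded into `model.total_block_args` / `model.total_block_kwargs`, and `convert()` restores the original forward with `recover_forward()` and hands the captured inputs to `ActAwareWeightQuant`. A hedged usage sketch of that flow (the helper below and its arguments are placeholders; `model` is assumed to be a decoder-style transformer and `weight_config` a per-layer dict of `bits`/`group_size`/`scheme`):

import torch
from neural_compressor.torch.algorithms.weight_only.awq import AWQQuantizer


def awq_with_new_api(model, calib_dataloader, example_inputs, weight_config):
    """Sketch of the prepare -> calibrate -> convert flow this commit introduces."""
    quantizer = AWQQuantizer(quant_config=weight_config)

    # prepare(): replace_forward() hooks the model so the first block's args/kwargs get captured
    model = quantizer.prepare(model)

    # caller-driven calibration replaces the old calib_func / n_samples arguments
    for batch in calib_dataloader:
        model(batch)

    # convert(): recover_forward(), then ActAwareWeightQuant consumes the captured inputs
    return quantizer.convert(
        model,
        bits=4,
        group_size=32,
        scheme="asym",
        example_inputs=example_inputs,
        dataloader=calib_dataloader,
    )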

neural_compressor/torch/algorithms/weight_only/hqq/__init__.py

Lines changed: 0 additions & 1 deletion

@@ -14,4 +14,3 @@
 
 from .quantizer import HQQuantizer
 from .config import HQQModuleConfig, QTensorConfig
-from .quant_api import hqq_quantize

neural_compressor/torch/algorithms/weight_only/hqq/quant_api.py

Lines changed: 0 additions & 63 deletions
This file was deleted.
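
For HQQ the commit only removes the functional entry point: `quant_api.py` and its `hqq_quantize` are deleted, leaving the `HQQuantizer` class exported from `.quantizer` as the single public surface, in line with the class-based API used for RTN and AWQ above. A hedged sketch of the class-based call (the toy model and empty config are placeholders; real configs are built from `HQQModuleConfig` / `QTensorConfig`, and HQQ is data-free so no run_fn is needed):

import torch
from neural_compressor.torch.algorithms.weight_only.hqq import HQQuantizer

# illustrative toy module standing in for a real transformer
model = torch.nn.Sequential(torch.nn.Linear(64, 64))

quantizer = HQQuantizer(quant_config={})
q_model = quantizer.quantize(model)  # prepare() + convert(); no calibration run in between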
