 # Copied from neural_compressor/adaptor/torch_utils/awq.py
 
 import copy
+from collections import OrderedDict
 
 import torch
 
+from neural_compressor.torch.algorithms import Quantizer
 from neural_compressor.torch.utils import get_device, logger
 
 from .modules import MulLinear
 from .utility import (
     fetch_module,
     get_absorb_layers,
     get_block_prefix,
     get_example_input,
-    get_hidden_states,
     get_module_input_output,
-    model_forward,
+    recover_forward,
+    replace_forward,
     set_module,
 )
 
-__all__ = ["awq_quantize"]
+__all__ = ["AWQQuantizer"]
 
 
 def _get_absorb_per_block(model, example_inputs, folding=False, weight_config={}):
@@ -113,15 +115,15 @@ def __init__(
         self,
         model,
         example_inputs=None,
-        calib_func=None,
         dataloader=None,
-        n_samples=128,
         data_type="int",
         bits=4,
         group_size=32,
         scheme="asym",
         use_full_range=False,
         weight_config={},
+        total_block_args=[],
+        total_block_kwargs=[],
     ):
 
         self.example_inputs = example_inputs
@@ -130,11 +132,9 @@ def __init__(
             assert dataloader is not None, "datalaoder or example_inputs is required."
             self.example_inputs = get_example_input(dataloader)
         self._move_model_and_data_to_device()
-        # Step 1: get hidden states and kwargs of first block.
-        self.total_block_args, self.total_block_kwargs = get_hidden_states(
-            model, dataloader=dataloader, n_samples=n_samples, calib_func=calib_func
-        )
-        # Step 2: get block list and block prefix, number
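+        # the first block's args/kwargs are now captured during prepare() and handed in via the constructor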
+        self.total_block_args = total_block_args
+        self.total_block_kwargs = total_block_kwargs
+        # get the block list, block prefix, and number of blocks
         self.block_prefix, self.block_num = get_block_prefix(model)
         self.block_list = fetch_module(model, self.block_prefix)
         self.data_type = data_type
@@ -429,14 +429,15 @@ def apply_quantize_with_clip(self, return_int=False):
         """
         # apply quantization and clip
         logger.info("Quantizing the AWQ optimized fp32 model")
-        from .rtn import rtn_quantize
+        from .rtn import RTNQuantizer
+
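+        # round-to-nearest (RTN) performs the final weight quantization, driven by the per-op AWQ weight_config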
+        rtn_quantizer = RTNQuantizer(quant_config=self.weight_config)
 
-        self.model = rtn_quantize(
+        self.model = rtn_quantizer.quantize(
             self.model,
-            num_bits=self.bits,
+            bits=self.bits,
             group_size=self.group_size,
             scheme=self.scheme,
-            weight_config=self.weight_config,
             return_int=return_int,
             use_full_range=self.use_full_range,
         )
@@ -492,78 +493,90 @@ def module_inference(self, model, inputs):
         return total_out
 
 
-@torch.no_grad()
-def awq_quantize(
-    model,
-    bits=4,
-    group_size=32,
-    scheme="asym",
-    weight_config={},
-    example_inputs=None,
-    dataloader=None,
-    n_samples=128,
-    calib_func=None,
-    use_auto_scale=True,
-    use_mse_search=True,
-    folding=False,
-    return_int=False,
-    use_full_range=False,
-    data_type="int",
-):
-    """Quant the model with Activation-aware Weight quantization(AWQ) method.
+class AWQQuantizer(Quantizer):
+    def __init__(self, quant_config: OrderedDict = {}):
+        """Init an AWQQuantizer object.
 
-    Args:
-        model (torch.nn.Module): torch model.
-        example_inputs: example_inputs.
-        weight_config (dict, optional): contains all info required by AWQ. Defaults to {}.
-            For example,
-                weight_config={
-                    'fc2':
-                        {
-                            # 'absorb_layer': 'fc1',
-                            'bits': 4,
-                            'group_size': 32,
-                            'scheme': 'sym'
-                        }
-                }
-        absorb_dict (dict, optional): contains all absorb info required by AWQ.. Defaults to {}.
-            For example,
-                absorb_dict = {
-                    # 'absorb_layer': absorbed_layer
-                    'fc1': ['fc1', 'fc2', 'fc3']
-                }  # in this case, fc2 and fc3 need to share the same scale. fc1 is self absorbed.
-                # self absorb module will replace with MulLinear, which contains torch.mul and module.
-        n_samples: calibration sample number.
-        use_auto_scale (bool, optional): whether enable scale for salient weight. Defaults to True.
-        use_mse_search (bool, optional): whether enable clip for weight by checking mse. Defaults to True.
-        calib_func: a custom inference function to replace dataloader and iters.
-        n_blocks: split model into block number to avoid OOM.
-        return_int (bool, optional): Choose return fp32 or int32 model.
-            Defaults to False.
-        use_full_range (bool, optional): Choose sym range whether use -2**(bits-1).
+        Args:
+            quant_config (OrderedDict, optional): quantization config for ops. Defaults to {}.
+        """
+        super().__init__(quant_config)
 
-    Returns:
-        model: fake quantized model
-    """
+    @torch.no_grad()
+    def prepare(self, model, *args, **kwargs):
+        """Prepare a given model to get hidden states and kwargs of first block.
+
+        Args:
+            model: A float torch model.
 
-    assert isinstance(model, torch.nn.Module), "only support torch module"
-    awq = ActAwareWeightQuant(
+        Returns:
+            A prepared model.
+        """
+        assert isinstance(model, torch.nn.Module), "AWQ algorithm only supports torch module"
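+        # swap in a capturing forward so a calibration run records the first block's args and kwargs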
+        model = replace_forward(model)
+        return model
+
+    @torch.no_grad()
+    def convert(
+        self,
         model,
-        example_inputs=example_inputs,
-        calib_func=calib_func,
-        dataloader=dataloader,
-        n_samples=n_samples,
-        bits=bits,
-        group_size=group_size,
-        scheme=scheme,
-        use_full_range=use_full_range,
-        weight_config=weight_config,
-        data_type=data_type,
-    )
-    qdq_model = awq.quantize(
-        use_auto_scale=use_auto_scale,
-        use_mse_search=use_mse_search,
-        folding=folding,
-        return_int=return_int,
-    )
-    return qdq_model
+        bits=4,
+        group_size=32,
+        scheme="asym",
+        example_inputs=None,
+        dataloader=None,
+        use_auto_scale=True,
+        use_mse_search=True,
+        folding=False,
+        return_int=False,
+        use_full_range=False,
+        data_type="int",
+        *args,
+        **kwargs,
+    ):
+        """Convert a prepared model to a quantized model.
+
+        Args:
+            model: torch model.
+            bits: number of bits. Defaults to 4.
+            group_size: how many elements share one scale/zero-point pair. Defaults to 32.
+            scheme: "sym" or "asym". Defaults to "asym".
+            example_inputs: example inputs. Defaults to None.
+            dataloader: calibration dataloader; either dataloader or example_inputs is required.
+                Defaults to None.
+            use_auto_scale: whether to enable scaling for salient weights. Defaults to True.
+            use_mse_search: whether to enable weight clipping by checking MSE. Defaults to True.
+            folding: if False, allows inserting a mul before a linear op when the scale cannot be
+                absorbed by the last layer; otherwise it won't. Defaults to False.
+            return_int: choose whether to return an fp32 or int32 model. Defaults to False.
+            use_full_range: choose whether the symmetric range uses -2**(bits-1). Defaults to False.
+            data_type: data type. Defaults to "int".
+
+        Returns:
+            model: fake quantized model
+        """
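+        # restore the original forward and pop the calibration data stashed on the model by prepare()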
+        model = recover_forward(model)
+        total_block_args = getattr(model, "total_block_args", [])
+        total_block_kwargs = getattr(model, "total_block_kwargs", [])
+        delattr(model, "total_block_args")
+        delattr(model, "total_block_kwargs")
+
+        awq = ActAwareWeightQuant(
+            model,
+            example_inputs=example_inputs,
+            dataloader=dataloader,
+            data_type=data_type,
+            bits=bits,
+            group_size=group_size,
+            scheme=scheme,
+            use_full_range=use_full_range,
+            weight_config=self.quant_config,
+            total_block_args=total_block_args,
+            total_block_kwargs=total_block_kwargs,
+        )
+        qdq_model = awq.quantize(
+            use_auto_scale=use_auto_scale,
+            use_mse_search=use_mse_search,
+            folding=folding,
+            return_int=return_int,
+        )
+        return qdq_model
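
The commit replaces the one-shot `awq_quantize` entry point with a two-step `prepare`/`convert` flow, with calibration run by the caller in between. A minimal usage sketch under stated assumptions: the module path (`neural_compressor.torch.algorithms.weight_only.awq`) is inferred from the relative imports, `model` and `calib_dataloader` are placeholders, and the `quant_config` shape mirrors the old `weight_config` docstring example:

```python
from collections import OrderedDict

from neural_compressor.torch.algorithms.weight_only.awq import AWQQuantizer

# placeholders: a transformer-style torch.nn.Module and an iterable of
# calibration batches; neither is defined by this commit
model = ...             # e.g. a decoder-only language model
calib_dataloader = ...  # yields input batches for calibration

# per-op config shaped like the old weight_config docstring example (assumed shape)
quant_config = OrderedDict({"fc2": {"bits": 4, "group_size": 32, "scheme": "sym"}})

quantizer = AWQQuantizer(quant_config=quant_config)

# step 1: prepare() swaps in a forward that records the first block's inputs
model = quantizer.prepare(model)

# step 2: run calibration data through the prepared model
for batch in calib_dataloader:
    model(batch)

# step 3: convert() restores the forward and applies AWQ scaling/clipping + RTN
q_model = quantizer.convert(
    model,
    bits=4,
    group_size=32,
    scheme="asym",
    dataloader=calib_dataloader,
)
```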