@@ -183,7 +183,7 @@ def quantize(x, scale, zero, maxq):
     return scale * (q - zero)
 
 
-class GPTQuantizer(object):
+class RAWGPTQuantizer(object):
     """Main API for GPTQ algorithm.
 
     Please refer to:
@@ -195,15 +195,14 @@ def __init__(
         self,
         model,
         weight_config={},
-        dataloader=None,
         nsamples=128,
         use_max_length=True,
         max_seq_length=2048,
         device=None,
         export_compressed_model=False,
         use_layer_wise=False,
         model_path="",
-        run_fn=None,
+        dataloader=None,
         *args,
         **kwargs,
     ):
@@ -226,7 +225,6 @@ def __init__(
             export_compressed_model (bool, optional): Choose return fp32 or int32 model. Defaults to False.
             use_layer_wise (bool): Enables quantize model per layer. Defaults to False.
             model_path (str): Model path that is used to load state_dict per layer.
-            run_fn: a function to run model inference for collecting input information.
             device: cpu or cuda
         """
         # model
@@ -271,9 +269,7 @@ def __init__(
         self.dataloader_original = dataloader
         self.dataloader = []
         self.nsamples = nsamples
-        self.run_fn = run_fn
-        self.run_args = kwargs.get("run_args", None)
-        if run_fn is None:
+        if dataloader is not None:
             self.prepare_dataloader()
 
     def prepare_dataloader(self):
@@ -489,7 +485,7 @@ def track_hidden_states(self, data):
         return data[0]
 
     @torch.no_grad()
-    def pre_quantization(self):
+    def prepare_for_calibration(self):
         """Prepare input calibration data and other attributes which are critical for gptq execution."""
         try:
             self.cache_key_arguments = {
@@ -532,34 +528,13 @@ def forward(layer, *args, **kwargs):
         # Step2: modify the first transformer block's forward function to obtain inputs for calibration
         if not self.use_layer_wise:
             self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].to(self.device)
-        forward_cache = self.gptq_related_blocks["transformers"][0].forward
+        self.forward_cache = self.gptq_related_blocks["transformers"][0].forward
         self.gptq_related_blocks["transformers"][0].forward = partial(
             forward, self.gptq_related_blocks["transformers"][0]
         )
 
-        # Step3: run forward to obtain calibration datasets
-        logger.info("Collecting calibration inputs...")
-        logger.info("Collecting calibration inputs by running the run_fn provided by user.")
-        if self.run_fn:
-            if self.run_args:
-                self.run_fn(self.model, *self.run_args)
-                accelerator.mark_step()
-            else:
-                self.run_fn(self.model)
-                accelerator.mark_step()
-        else:
-            for batch in tqdm(self.dataloader):
-                if not self.use_layer_wise:
-                    batch = move_input_to_device(batch, self.device)
-                try:
-                    if isinstance(batch, tuple) or isinstance(batch, list):
-                        self.model(batch[0])
-                    elif isinstance(batch, dict):
-                        self.model(**batch)
-                    else:
-                        self.model(batch)
-                except ValueError:
-                    pass
+    @torch.no_grad()
+    def remove_prepare_for_calibration(self):
         # output inp data shape
         logger.info("All calibration data's shape =>")
         # check all hidden_states shape
@@ -571,7 +546,7 @@ def forward(layer, *args, **kwargs):
         logger.info("Done.")
 
         # Step 4: restore original forward function, relocate layers back to cpu.
-        self.gptq_related_blocks["transformers"][0].forward = forward_cache
+        self.gptq_related_blocks["transformers"][0].forward = self.forward_cache
         if not self.use_layer_wise:
             self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].cpu()
         for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items():
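Note on the two hunks above: forward_cache is promoted to self.forward_cache because hooking and restoring the first transformer block's forward now happen in two separate methods (prepare_for_calibration and remove_prepare_for_calibration), so the cached bound method has to survive on the instance between the two calls. Below is a minimal standalone sketch of that cache-and-restore pattern; all names are hypothetical and it does not depend on neural_compressor:

import functools

class TinyBlock:
    """Hypothetical stand-in for the first transformer block."""
    def forward(self, hidden_states):
        return hidden_states + 1

class CalibrationHook:
    """Hypothetical helper: hook() and unhook() are separate methods, so the
    original forward must be cached on the instance, mirroring how the diff
    turns forward_cache into self.forward_cache."""
    def __init__(self, block):
        self.block = block
        self.captured_inputs = []

    def hook(self):
        self.forward_cache = self.block.forward  # cache the original bound method
        def forward(block, hidden_states):
            self.captured_inputs.append(hidden_states)  # record calibration input
            raise ValueError  # stop execution after the first block
        self.block.forward = functools.partial(forward, self.block)

    def unhook(self):
        self.block.forward = self.forward_cache  # restore the original forward

block = TinyBlock()
hooker = CalibrationHook(block)
hooker.hook()
try:
    block.forward(3)  # caller-driven calibration pass
except ValueError:
    pass
hooker.unhook()
assert hooker.captured_inputs == [3] and block.forward(3) == 4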
@@ -606,7 +581,6 @@ def execute_quantization(self, means=None, stds=None):
         # Step1: prepare quantization (calibration datasets)
 
         logger.info("Begin ====>")
-        self.pre_quantization()
         model_path = self.model_path
 
         # Step2: run gptq quantization in a transformer block-wise manner.
@@ -1144,41 +1118,57 @@ def ready(self):
         return torch.all(self.scale != 0)
 
 
-def gptq_quantize(
-    model,
-    weight_config={},
-    dataloader=None,
-    nsamples=128,
-    max_seq_length=2048,
-    use_max_length=True,
-    device=None,
-    export_compressed_model=False,
-    use_layer_wise=False,
-    model_path=None,
-    run_fn=None,
-    run_args=None,
-):
-    """Run weight-only quantization with."""
-    # TODO: unify weight_config keys, add docstring, and support default config
-    assert isinstance(model, torch.nn.Module), "only support torch module"
-    if use_layer_wise:
-        assert model_path is not None, "model_path should not be None when use layer wise mode"
-    from .gptq import GPTQuantizer
-
-    gptq_quantizer = GPTQuantizer(
+from neural_compressor.torch.algorithms import Quantizer as INCQuantizer
+
+
+class GPTQuantizer(INCQuantizer):
+    def __init__(self, quant_config={}):
+        """Init a GPTQuantizer object.
+
+        Args:
+            quant_config (OrderedDict, optional): quantization config for ops. Defaults to {}.
+        """
+        super().__init__(quant_config)
+
+    @torch.no_grad()
+    def prepare(
+        self,
         model,
-        weight_config,
-        dataloader,
-        nsamples,
-        use_max_length,
-        max_seq_length,
-        device,
-        export_compressed_model=export_compressed_model,
-        use_layer_wise=use_layer_wise,
-        model_path=model_path,
-        run_fn=run_fn,
-        run_args=run_args,
-    )
-    fp32_modified_model, gptq_config = gptq_quantizer.execute_quantization()
-    logger.info("GPTQ quantizing done.")
-    return fp32_modified_model, gptq_config
+        nsamples=128,
+        max_seq_length=2048,
+        use_max_length=True,
+        device=None,
+        export_compressed_model=False,
+        use_layer_wise=False,
+        model_path=None,
+        *args,
+        **kwargs,
+    ):
+        """Run weight-only quantization with GPTQ."""
+        # TODO: unify weight_config keys, add docstring, and support default config
+        assert isinstance(model, torch.nn.Module), "only support torch module"
+        if use_layer_wise:
+            assert model_path is not None, "model_path should not be None when use layer wise mode"
+
+        self.gptq_quantizer = RAWGPTQuantizer(
+            model,
+            weight_config=self.quant_config,
+            nsamples=nsamples,
+            use_max_length=use_max_length,
+            max_seq_length=max_seq_length,
+            device=device,
+            export_compressed_model=export_compressed_model,
+            use_layer_wise=use_layer_wise,
+            model_path=model_path,
+        )
+        self.gptq_quantizer.prepare_for_calibration()
+        return self.gptq_quantizer.model
+
+    @torch.no_grad()
+    def convert(self, model, *args, **kwargs):
+        self.gptq_quantizer.model = model
+        self.gptq_quantizer.remove_prepare_for_calibration()
+        q_model, gptq_config = self.gptq_quantizer.execute_quantization()
+        q_model.gptq_config = gptq_config
+        logger.info("GPTQ quantizing done.")
+        return q_model
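Taken together, these hunks replace the one-shot gptq_quantize() helper with a prepare/convert pair: prepare() wraps the model in RAWGPTQuantizer and hooks the first transformer block, the caller then drives calibration by running forward passes itself, and convert() restores the block and executes the block-wise GPTQ pass. A hedged usage sketch under assumed names (my_weight_config and calib_dataloader are illustrative and not defined in this diff):

import torch

quantizer = GPTQuantizer(quant_config=my_weight_config)  # my_weight_config: hypothetical per-op config
model = quantizer.prepare(model, nsamples=128, max_seq_length=2048, device="cpu")

# Calibration is now caller-driven: run forward passes so the hooked first
# transformer block can record its inputs; the hook raises ValueError to stop
# execution after that block, so the loop swallows it.
with torch.no_grad():
    for batch in calib_dataloader:  # calib_dataloader: hypothetical
        try:
            model(batch)
        except ValueError:
            pass

q_model = quantizer.convert(model)  # restores the block's forward and runs block-wise GPTQ
gptq_config = q_model.gptq_config   # per-layer quantization results attached by convert()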