from neural_compressor.common.utils import logger
from neural_compressor.torch.algorithms.weight_only.gptq import move_input_to_device
-from neural_compressor.torch.quantization import GPTQConfig, get_default_rtn_config, quantize
+from neural_compressor.torch.quantization import (
+    AWQConfig,
+    GPTQConfig,
+    get_default_awq_config,
+    get_default_rtn_config,
+    get_default_teq_config,
+    quantize,
+)
+
+
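+# Shared helper: every test below loads the same tiny random GPT-J checkpoint.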
+def get_gpt_j():
+    tiny_gptj = transformers.AutoModelForCausalLM.from_pretrained(
+        "hf-internal-testing/tiny-random-GPTJForCausalLM",
+        torchscript=True,
+    )
+    return tiny_gptj


class GPTQDataloaderPreprocessor:
@@ -213,9 +228,7 @@ def run_fn_for_gptq(model, dataloader_for_calibration, *args):
                    pass
            return

-        user_model = transformers.AutoModelForCausalLM.from_pretrained(
-            "hf-internal-testing/tiny-random-GPTJForCausalLM",
-        )
+        user_model = get_gpt_j()

        user_model = quantize(
            model=user_model, quant_config=quant_config, run_fn=run_fn_for_gptq, run_args=dataloader_for_calibration
@@ -228,9 +241,7 @@ class TestRTNQuant:
    @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
    def test_rtn(self):
-        self.tiny_gptj = transformers.AutoModelForCausalLM.from_pretrained(
-            "hf-internal-testing/tiny-random-GPTJForCausalLM",
-        )
+        self.tiny_gptj = get_gpt_j()
        self.example_inputs = torch.tensor([[10, 20, 30, 40, 50, 60]], dtype=torch.long)
        model = self.tiny_gptj
        # record label for comparison
@@ -239,3 +250,129 @@ def test_rtn(self):
        quant_config = get_default_rtn_config()
        q_model = quantize(model, quant_config)
        assert "cuda" in str(q_model.device), f"Expect qmodel device is cuda, got {q_model.device}"
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
+class TestAWQOnCuda:
+
+    def test_awq(self):
+        self.lm_input = torch.ones([1, 10], dtype=torch.long)
+        self.gptj = get_gpt_j()
+        example_inputs = torch.ones([1, 10], dtype=torch.long)
+
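+        # Calibration: a couple of forward passes on the dummy prompt let AWQ collect activation statistics.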
+        def calib_func(model):
+            for i in range(2):
+                model(self.lm_input.to(model.device))
+
+        quant_config = get_default_awq_config()
+        logger.info(f"Test quantization with config: {quant_config}")
+        q_model = quantize(
+            model=self.gptj, quant_config=quant_config, example_inputs=self.lm_input, run_fn=calib_func, inplace=False
+        )
+        out2 = q_model(example_inputs.to(q_model.device))
+        assert "cuda" in str(q_model.device), f"Expect qmodel device is cuda, got {q_model.device}"
+        assert "cuda" in str(out2[0].device), f"Expect out2 device is cuda, got {out2[0].device}"
+
+
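+# Build a tiny dummy corpus of (input, target) pairs used to drive the TEQ training loop below.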
+def generate_random_corpus(nsamples=32):
+    meta_data = []
+    for _ in range(nsamples):
+        inp = torch.ones([1, 512], dtype=torch.long)
+        tar = torch.ones([1, 512], dtype=torch.long)
+        meta_data.append((inp, tar))
+    return meta_data
+
+
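+# Lightweight training loop; TestTEQOnCuda passes it to quantize() as the run_fn for TEQ calibration.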
+def train(
+    model,
+    train_steps=1000,
+    lr=1e-3,
+    warmup_ratio=0.05,
+    gradient_accumulation_steps=1,
+    logging_steps=10,
+    betas=[0.9, 0.9],
+    weight_decay=0,
+    lr_scheduler_type="linear",
+):
+    """Train function."""
+    trained_alphas_list = [torch.ones([128], requires_grad=True)]
+    optimizer = torch.optim.Adam(trained_alphas_list, lr=lr, weight_decay=weight_decay, betas=betas)
+
+    lr_scheduler = transformers.get_scheduler(  # pylint: disable=E1111
+        name=lr_scheduler_type,
+        optimizer=optimizer,
+        num_warmup_steps=int(train_steps * warmup_ratio) // gradient_accumulation_steps,
+        num_training_steps=train_steps // gradient_accumulation_steps,
+    )
+
+    logger.info("start training")
+    model.train()
+    global_steps = 0
+    dataloader = generate_random_corpus()
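+    # Loop over the dummy corpus until train_steps forward/backward passes have been taken.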
+    while global_steps <= train_steps:
+        for inputs in dataloader:
+            if isinstance(inputs, torch.Tensor):
+                input_id = inputs
+            elif isinstance(inputs, dict):
+                input_id = inputs["input_ids"]
+            else:
+                input_id = inputs[0]
+            output = model(input_id.to(model.device), labels=input_id.to(model.device))
+            loss = output[0] / gradient_accumulation_steps
+            loss.backward()
+            global_steps += 1
+
+            if global_steps % logging_steps == 0:
+                logger.info("steps: {}, loss: {}".format(global_steps, loss.detach().cpu().item()))
+
+            if global_steps % gradient_accumulation_steps == 0:
+                optimizer.step()
+                optimizer.zero_grad()
+                lr_scheduler.step()
+
+            if global_steps >= train_steps:  # pragma: no cover
+                break
+
+    logger.info("finish training")
+    model.eval()
+    return None
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
+class TestTEQOnCuda:
+
+    def test_teq(self):
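+        # Per-op TEQ config: keep the model fp32 globally, quantize fc_in to int8 (group_size=-1) and fc_out to int4 (group_size=32).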
+        quant_config = {
+            "teq": {
+                "global": {
+                    "dtype": "fp32",
+                },
+                "local": {
+                    "transformer.h.0.mlp.fc_in": {
+                        "dtype": "int",
+                        "bits": 8,
+                        "group_size": -1,
+                        "use_sym": True,
+                        "folding": True,
+                        "absorb_to_layer": {"transformer.h.0.mlp.fc_in": ["transformer.h.0.mlp.fc_out"]},
+                    },
+                    "transformer.h.0.mlp.fc_out": {
+                        "dtype": "int",
+                        "bits": 4,
+                        "group_size": 32,
+                        "use_sym": False,
+                        "folding": True,
+                        "absorb_to_layer": {"transformer.h.0.mlp.fc_in": ["transformer.h.0.mlp.fc_out"]},
+                    },
+                },
+            }
+        }
+        example_inputs = torch.ones([1, 512], dtype=torch.long)
+        test_input = torch.ones([1, 512], dtype=torch.long)
+        model = get_gpt_j()
+
+        qdq_model = quantize(model=model, quant_config=quant_config, run_fn=train, example_inputs=example_inputs)
+        assert isinstance(qdq_model, torch.nn.Module), "Expect qdq_model is a torch module"
+        out2 = qdq_model(test_input.to(qdq_model.device))
+        assert "cuda" in str(qdq_model.device), f"Expect qmodel device is cuda, got {qdq_model.device}"
+        assert "cuda" in str(out2[0].device), f"Expect out2 device is cuda, got {out2[0].device}"