 import pytest
 import torch
 
-from neural_compressor.torch.quantization import MXQuantConfig, get_default_mx_config, quantize
+from neural_compressor.torch.quantization import MXQuantConfig, convert, get_default_mx_config, prepare
 
 
 def build_simple_torch_model():
@@ -40,20 +40,35 @@ def teardown_class(self):
     def test_mx_quant_default(self):
         fp32_model = copy.deepcopy(self.fp32_model)
         quant_config = get_default_mx_config()
-        q_model = quantize(fp32_model, quant_config=quant_config)
+        fp32_model = prepare(model=fp32_model, quant_config=quant_config)
+        q_model = convert(model=fp32_model)
         assert q_model is not None, "Quantization failed!"
 
     @pytest.mark.parametrize(
-        "w_dtype, weight_only",
+        "w_dtype, weight_only, round_method, out_dtype",
         [
-            ("fp4", True),
-            ("fp8_e5m2", False),
+            ("fp4", True, "dither", "float32"),
+            ("fp8_e5m2", False, "floor", "bfloat16"),
+            ("int8", False, "even", "float16"),
+            ("int4", False, "nearest", "float32"),
+            ("int2", False, "dither", "bfloat16"),
+            ("fp8_e4m3", False, "floor", "float16"),
+            ("fp6_e3m2", False, "even", "float32"),
+            ("fp6_e2m3", False, "nearest", "bfloat16"),
+            ("float16", False, "dither", "float16"),
+            ("bfloat16", False, "floor", "float32"),
         ],
     )
-    def test_mx_quant_params(self, w_dtype, weight_only):
+    def test_mx_quant_params(self, w_dtype, weight_only, round_method, out_dtype):
         fp32_model = copy.deepcopy(self.fp32_model)
-        quant_config = MXQuantConfig(w_dtype=w_dtype, weight_only=weight_only)
-        q_model = quantize(fp32_model, quant_config=quant_config)
+        quant_config = MXQuantConfig(
+            w_dtype=w_dtype,
+            weight_only=weight_only,
+            round_method=round_method,
+            out_dtype=out_dtype,
+        )
+        fp32_model = prepare(model=fp32_model, quant_config=quant_config)
+        q_model = convert(model=fp32_model)
         assert q_model is not None, "Quantization failed!"
 
     def test_mx_quant_accuracy(self):
@@ -72,8 +87,10 @@ def forward(self, x):
         fp32_model = copy.deepcopy(model)
         fp32_model.linear.weight = torch.nn.Parameter(torch.tensor([[0.0, 1.0], [1.0, 0.0]]))
         example_inputs = torch.zeros(3, 2)
+
         quant_config = MXQuantConfig()
-        q_model = quantize(fp32_model, quant_config=quant_config)
+        fp32_model = prepare(model=fp32_model, quant_config=quant_config)
+        q_model = convert(model=fp32_model)
         output1 = fp32_model(example_inputs)
         output2 = q_model(example_inputs)
         # set a big atol to avoid random issue
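
For reference, a minimal sketch of the two-step flow these tests migrate to, assuming the `neural_compressor.torch.quantization` API exactly as imported in the diff (`prepare`/`convert` with `MXQuantConfig`); the toy model below is hypothetical and not part of this change:

import torch

from neural_compressor.torch.quantization import MXQuantConfig, convert, prepare

# Hypothetical toy model; any torch.nn.Module would work the same way.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())

# Build an MX config; the parameter names and values are taken from the
# parametrized test above.
quant_config = MXQuantConfig(
    w_dtype="int8",
    weight_only=False,
    round_method="nearest",
    out_dtype="float32",
)

# Step 1: prepare attaches the quantization config to the FP32 model.
prepared_model = prepare(model=model, quant_config=quant_config)

# Step 2: convert produces the MX-quantized model.
q_model = convert(model=prepared_model)
assert q_model is not None, "Quantization failed!"

The key difference from the old one-shot `quantize()` call is that preparation and conversion are now separate steps, which is why every call site in these tests changes from one line to two.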