1717 prepare ,
1818 quantize ,
1919)
20- from neural_compressor .torch .utils import TORCH_VERSION_2_2_2 , get_torch_version
20+ from neural_compressor .torch .utils import GT_TORCH_VERSION_2_3_2 , TORCH_VERSION_2_2_2 , get_torch_version
2121
2222torch .manual_seed (0 )
2323
@@ -119,6 +119,42 @@ def calib_fn(model):
119119 logger .warning ("out shape is %s" , out .shape )
120120 assert out is not None
121121
@pytest.mark.skipif(not GT_TORCH_VERSION_2_3_2, reason="Requires torch>=2.3.2")
def test_quantize_simple_model_with_set_local(self, force_not_import_ipex):
    """Quantize the simple model with a per-op override set through ``set_local``.

    The default static config is applied to the model, except that ``fc1`` is
    overridden with fp32 weight/activation dtypes. The test then checks:

    * the number of ``quantize_per_tensor`` nodes in the quantized graph,
    * that the quantized output stays close to the float output (atol=1e-2),
    * that the quantized model can be compiled with ``torch.compile`` and run.
    """
    model, example_inputs = self.build_simple_torch_model_and_example_inputs()
    # Reference output taken from the float model before quantization.
    float_model_output = model(*example_inputs)

    def calib_fn(model):
        # Calibration: four forward passes over the same example inputs.
        for _ in range(4):
            model(*example_inputs)

    quant_config = get_default_static_config()
    quant_config.set_local("fc1", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
    q_model = quantize(model=model, quant_config=quant_config, run_fn=calib_fn)

    # Check how often the quantization ops occur in the resulting graph.
    # NOTE(review): the original dict literal listed `quantize_per_tensor.default`
    # twice; the duplicate key collapsed into this single entry. A
    # `dequantize_per_tensor.default` check was presumably intended -- confirm the
    # expected count before adding it.
    expected_node_occurrence = {
        # Only quantize the `fc2`
        torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
    }
    expected_node_occurrence = {
        torch_test_quant_common.NodeSpec.call_function(k): v for k, v in expected_node_occurrence.items()
    }
    node_in_graph = self.get_node_in_graph(q_model)
    for node, cnt in expected_node_occurrence.items():
        # Look the count up once with a default so a missing node produces a
        # readable assertion failure instead of a KeyError while formatting
        # the message.
        actual = node_in_graph.get(node, 0)
        assert actual == cnt, f"Node {node} should occur {cnt} times, but occurred {actual} times"

    from torch._inductor import config

    # This sets the flag on the shared inductor config module, not a local copy.
    config.freezing = True
    q_model_out = q_model(*example_inputs)
    assert torch.allclose(float_model_output, q_model_out, atol=1e-2), "Quantization failed!"
    opt_model = torch.compile(q_model)
    out = opt_model(*example_inputs)
    assert out is not None
157+
122158 @pytest .mark .skipif (get_torch_version () <= TORCH_VERSION_2_2_2 , reason = "Requires torch>=2.3.0" )
123159 @pytest .mark .parametrize ("is_dynamic" , [False , True ])
124160 def test_prepare_and_convert_on_simple_model (self , is_dynamic , force_not_import_ipex ):
@@ -193,9 +229,9 @@ def get_node_in_graph(graph_module):
193229 nodes_in_graph [n ] += 1
194230 else :
195231 nodes_in_graph [n ] = 1
196- return
232+ return nodes_in_graph
197233
198- @pytest .mark .skipif (get_torch_version () <= TORCH_VERSION_2_2_2 , reason = "Requires torch>=2.3.0" )
234+ @pytest .mark .skipif (not GT_TORCH_VERSION_2_3_2 , reason = "Requires torch>=2.3.0" )
199235 def test_mixed_fp16_and_int8 (self , force_not_import_ipex ):
200236 model , example_inputs = self .build_model_include_conv_and_linear ()
201237 model = export (model , example_inputs = example_inputs )
@@ -221,9 +257,7 @@ def test_mixed_fp16_and_int8(self, force_not_import_ipex):
221257 }
222258 node_in_graph = self .get_node_in_graph (converted_model )
223259 for node , cnt in expected_node_occurrence .items ():
224- assert (
225- expected_node_occurrence .get (node , 0 ) == cnt
226- ), f"Node { node } should occur { cnt } times, but { node_in_graph [node ]} "
260+ assert node_in_graph .get (node , 0 ) == cnt , f"Node { node } should occur { cnt } times, but { node_in_graph [node ]} "
227261
228262 # inference
229263 from torch ._inductor import config
0 commit comments