 
 from neural_compressor.torch.algorithms import Quantizer
 from neural_compressor.torch.utils import logger
+from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
 
 from .utility import (
     CpuInfo,
     cfg_to_qconfig,
     dump_model_op_stats,
+    generate_xpu_qconfig,
     get_ipex_version,
     get_quantizable_ops_recursively,
     ipex_config_path,
@@ -56,6 +58,7 @@ def __init__(self, quant_config: OrderedDict = {}):
         """
         super().__init__(quant_config)
         self.user_cfg = OrderedDict()
+        self.device = auto_detect_accelerator().current_device()
 
     def prepare(self, model, example_inputs, inplace=True, *args, **kwargs):
         """Prepares a given model for quantization.
@@ -70,43 +73,61 @@ def prepare(self, model, example_inputs, inplace=True, *args, **kwargs):
         """
         assert example_inputs is not None, "Please provide example_inputs for static quantization."
 
-        _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, _ = get_quantizable_ops_recursively(
-            model, example_inputs
-        )
-        # update json file in ipex_config_path; map ipex op_name to pt op_name
-        self.user_cfg = cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)
-        model.eval()
+        if self.device == "cpu":
+            _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, _ = get_quantizable_ops_recursively(
+                model, example_inputs
+            )
+            # update json file in ipex_config_path; map ipex op_name to pt op_name
+            self.user_cfg = cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)
+        else:  # pragma: no cover
+            model = model.to("xpu")
 
-        use_bf16 = self.quant_config.get("use_bf16", None)
+        model.eval()
 
         # Check save_qconf_summary part is a workaround for IPEX bug.
-        # Sometimes the prepared model from get_op_capablitiy loss this attribute
-        if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"):
-            from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig
-
-            if ipex_ver.release >= Version("2.1").release:
-                # HistogramObserver will cause a performance issue.
-                # static_qconfig = ipex.quantization.default_static_qconfig_mapping
-                qconfig = QConfig(
-                    activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
-                    weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric),
-                )
-                from torch.ao.quantization import QConfigMapping
-
-                static_qconfig = QConfigMapping().set_global(qconfig)
-            else:
-                static_qconfig = QConfig(
-                    activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
-                    weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric),
-                )
-            if isinstance(example_inputs, dict):
-                model = ipex.quantization.prepare(
-                    model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=inplace
-                )
+        # Sometimes the prepared model from get_op_capability loses these attributes
+        if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"):  # pragma: no cover
+            from torch.ao.quantization import HistogramObserver, MinMaxObserver, PerChannelMinMaxObserver, QConfig
+
+            if self.device != "cpu":  # pragma: no cover
+                from torch.quantization.quantize_jit import prepare_jit
+
+                with torch.no_grad():
+                    modelJit = torch.jit.trace(model, example_inputs)
+                qconfig = generate_xpu_qconfig(self.quant_config)
+                model = prepare_jit(modelJit, qconfig, inplace)
             else:
-                model = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=inplace)
+                if ipex_ver.release >= Version("2.1").release:
+                    # HistogramObserver will cause a performance issue.
+                    # static_qconfig = ipex.quantization.default_static_qconfig_mapping
+                    qconfig = QConfig(
+                        activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
+                        weight=PerChannelMinMaxObserver.with_args(
+                            dtype=torch.qint8, qscheme=torch.per_channel_symmetric
+                        ),
+                    )
+                    from torch.ao.quantization import QConfigMapping
+
+                    static_qconfig = QConfigMapping().set_global(qconfig)
+                else:  # pragma: no cover
+                    static_qconfig = QConfig(
+                        activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
+                        weight=PerChannelMinMaxObserver.with_args(
+                            dtype=torch.qint8, qscheme=torch.per_channel_symmetric
+                        ),
+                    )
+                if isinstance(example_inputs, dict):
+                    model = ipex.quantization.prepare(
+                        model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=inplace
+                    )
+                else:
+                    model = ipex.quantization.prepare(
+                        model, static_qconfig, example_inputs=example_inputs, inplace=inplace
+                    )
+
+        if self.device == "cpu":
+            model.load_qconf_summary(qconf_summary=ipex_config_path)
 
-        model.load_qconf_summary(qconf_summary=ipex_config_path)
         return model
 
     def convert(self, model, example_inputs, inplace=True, *args, **kwargs):
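
For context on the new XPU branch in the hunk above: it defers to TorchScript graph-mode quantization (`prepare_jit`/`convert_jit`) instead of the IPEX CPU path. Below is a minimal standalone sketch of that flow, assuming a toy model, example input, and a hand-built global qconfig dict standing in for whatever `generate_xpu_qconfig` actually returns; it is an illustration under those assumptions, not code from this patch.

```python
import torch
from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig
from torch.quantization.quantize_jit import convert_jit, prepare_jit

# Toy eval-mode model and input, standing in for the user's model/example_inputs.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
example_inputs = torch.randn(1, 3, 32, 32)

# Global qconfig dict; the observer choices mirror the CPU branch above (assumption).
qconfig = {
    "": QConfig(
        activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
        weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric),
    )
}

with torch.no_grad():
    traced = torch.jit.trace(model, example_inputs)   # JIT-trace the eager model
    prepared = prepare_jit(traced, qconfig, False)    # insert observers into the traced graph
    for _ in range(2):                                # calibration passes to collect observer statistics
        prepared(example_inputs)
    quantized = convert_jit(prepared, False)          # fold observers into quantized ops
```

The calibration passes here play the same role as the inference runs the quantizer expects between `prepare()` and `convert()`.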
@@ -124,18 +145,27 @@ def convert(self, model, example_inputs, inplace=True, *args, **kwargs):
 
         from neural_compressor.torch.algorithms.static_quant import save
 
-        model.save_qconf_summary(qconf_summary=ipex_config_path)
-        model = _ipex_post_quant_process(model, example_inputs, use_bf16, inplace=inplace)
+        if self.device != "cpu":  # pragma: no cover
+            from torch.quantization.quantize_jit import convert_jit
 
-        with open(ipex_config_path, "r") as f:
-            model.tune_cfg = json.load(f)
-        model.ipex_config_path = ipex_config_path
+            model = convert_jit(model, inplace)
+            simple_inference(model, example_inputs, iterations=2)
+            model.qconfig = self.quant_config["op"]
+            dump_model_op_stats(model.qconfig)
+        else:
+            model.save_qconf_summary(qconf_summary=ipex_config_path)
+            model = _ipex_post_quant_process(model, example_inputs, use_bf16, inplace=inplace)
 
-        dump_model_op_stats(self.user_cfg)
+            with open(ipex_config_path, "r") as f:
+                model.tune_cfg = json.load(f)
+            model.ipex_config_path = ipex_config_path
+
+            dump_model_op_stats(self.user_cfg)
 
-        logger.info("Static quantization done.")
         model.ori_save = model.save
         model.save = MethodType(save, model)
+
+        logger.info("Static quantization done.")
         return model
 
 
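
For completeness, a hedged end-to-end sketch of how this quantizer is typically driven through the high-level `prepare`/`convert` entry points in `neural_compressor.torch.quantization`; the toy model, example inputs, and calibration loop below are illustrative assumptions, not part of this diff.

```python
import torch
from neural_compressor.torch.quantization import StaticQuantConfig, convert, prepare

# Toy model and example inputs; real models/data are assumed to be supplied by the caller.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).eval()
example_inputs = torch.randn(4, 8)

quant_config = StaticQuantConfig()                                   # default static-quant settings (assumption)
prepared = prepare(model, quant_config=quant_config, example_inputs=example_inputs)
for _ in range(4):                                                   # calibration: populate observer statistics
    prepared(example_inputs)
q_model = convert(prepared)                                          # dispatches to the quantizer's convert() shown above
```

After `convert()`, the returned model also exposes the patched `save` method (`model.save = MethodType(save, model)` in the diff) for persisting the quantized result.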