# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Note - The `W8A8StaticQuantizer` is aligned with the pytorch-labs/ao unified quantization API:
# https://github.com/pytorch-labs/ao/blob/5401df093564825c06691f4c2c10cdcf1a32a40c/torchao/quantization/unified.py#L15-L26
# Some code snippets are taken from the X86InductorQuantizer tutorial:
# https://pytorch.org/tutorials/prototype/pt2e_quant_x86_inductor.html
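#
# For reference, the unified API linked above roughly defines a two-step
# quantizer protocol (paraphrased here from that file, not imported by this module):
#
#     class TwoStepQuantizer:
#         def prepare(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module: ...
#         def convert(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module: ...
#
# `W8A8StaticQuantizer` below follows this prepare/convert split.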


from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
from torch.fx.graph_module import GraphModule

from neural_compressor.common.utils import logger
from neural_compressor.torch.utils import TORCH_VERSION_2_2_2, get_torch_version


class W8A8StaticQuantizer:
    """W8A8 static post-training quantizer following the two-step (prepare/convert) flow."""

    @staticmethod
    def update_quantizer_based_on_quant_config(quantizer: X86InductorQuantizer, quant_config) -> X86InductorQuantizer:
        # TODO: add the logic to update the quantizer based on the quant_config
        quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
        return quantizer

    @staticmethod
    def export_model(
        model,
        example_inputs: Tuple[Any],
        dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ) -> Optional[GraphModule]:
        """Export the eager model into a graph with Aten IR; return None if the export fails."""
        exported_model = None
        try:
            with torch.no_grad():
                # Note 1: `capture_pre_autograd_graph` is a short-term API; it will be
                # updated to use the official `torch.export` API when that is ready.
                cur_version = get_torch_version()
                if cur_version <= TORCH_VERSION_2_2_2:  # pragma: no cover
                    logger.warning(
                        (
                            "`dynamic_shapes` is not supported in the current version (%s) of PyTorch. "
                            "If you want to use `dynamic_shapes` to export the model, "
                            "please upgrade to 2.3.0 or later."
                        ),
                        cur_version,
                    )
                    exported_model = capture_pre_autograd_graph(model, args=example_inputs)
                else:  # pragma: no cover
                    exported_model = capture_pre_autograd_graph(  # pylint: disable=E1123
                        model, args=example_inputs, dynamic_shapes=dynamic_shapes
                    )
        except Exception as e:
            logger.error(f"Failed to export the model: {e}")
        return exported_model

    def prepare(
        self, model: torch.nn.Module, quant_config, example_inputs: Tuple[Any], *args: Any, **kwargs: Any
    ) -> Optional[GraphModule]:
        """Prepare the model for calibration.

        There are two steps in this process:
            1) export the eager model into a graph with Aten IR.
            2) create the `quantizer` according to the `quant_config`, and insert the observers accordingly.
        """
        assert isinstance(example_inputs, tuple), f"Expected `example_inputs` to be a tuple, got {type(example_inputs)}"
        # Set the model to eval mode
        model = model.eval()

        # 1) Capture the FX Graph to be quantized
        dynamic_shapes = kwargs.get("dynamic_shapes", None)
        exported_model = self.export_model(model, example_inputs, dynamic_shapes=dynamic_shapes)
        if exported_model is None:
            return None
        logger.info("Exported the model to Aten IR successfully.")

        # 2) Create the `quantizer` according to the `quant_config`, and insert the observers accordingly.
        quantizer = X86InductorQuantizer()
        quantizer = self.update_quantizer_based_on_quant_config(quantizer, quant_config)
        prepared_model = prepare_pt2e(exported_model, quantizer)
        return prepared_model

    def convert(self, model: GraphModule, *args: Any, **kwargs: Any) -> GraphModule:
        """Convert the calibrated model into QDQ mode."""
        fold_quantize = kwargs.get("fold_quantize", False)
        converted_model = convert_pt2e(model, fold_quantize=fold_quantize)
        logger.warning("Converted the model into QDQ mode; please compile it to accelerate inference.")
        return converted_model