Commit 02958dd

[1/N] Integrate PT2E static PTQ (#1739)
* add x86InductorQuantizer (static)

Signed-off-by: yiliu30 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 7be355d commit 02958dd

File tree: 4 files changed (+224 −0 lines changed)
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Note - The `W8A8StaticQuantizer` is aligned with the pytorch-labs/ao unified quantization API.
# https://github.com/pytorch-labs/ao/blob/5401df093564825c06691f4c2c10cdcf1a32a40c/torchao/quantization/unified.py#L15-L26
# Some code snippets are taken from the X86InductorQuantizer tutorial.
# https://pytorch.org/tutorials/prototype/pt2e_quant_x86_inductor.html


from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
from torch.fx.graph_module import GraphModule

from neural_compressor.common.utils import logger
from neural_compressor.torch.utils import TORCH_VERSION_2_2_2, get_torch_version


class W8A8StaticQuantizer:

    @staticmethod
    def update_quantizer_based_on_quant_config(quantizer: X86InductorQuantizer, quant_config) -> X86InductorQuantizer:
        # TODO: add the logic to update the quantizer based on the quant_config
        quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
        return quantizer

    @staticmethod
    def export_model(
        model,
        example_inputs: Tuple[Any],
        dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ) -> Optional[GraphModule]:
        exported_model = None
        try:
            with torch.no_grad():
                # Note 1: `capture_pre_autograd_graph` is a short-term API; it will be
                # updated to use the official `torch.export` API when that is ready.
                cur_version = get_torch_version()
                if cur_version <= TORCH_VERSION_2_2_2:  # pragma: no cover
                    logger.warning(
                        (
                            "`dynamic_shapes` is not supported in the current version (%s) of PyTorch. "
                            "If you want to use `dynamic_shapes` to export the model, "
                            "please upgrade to 2.3.0 or later."
                        ),
                        cur_version,
                    )
                    exported_model = capture_pre_autograd_graph(model, args=example_inputs)
                else:  # pragma: no cover
                    exported_model = capture_pre_autograd_graph(  # pylint: disable=E1123
                        model, args=example_inputs, dynamic_shapes=dynamic_shapes
                    )
        except Exception as e:
            logger.error(f"Failed to export the model: {e}")
        return exported_model

    def prepare(
        self, model: torch.nn.Module, quant_config, example_inputs: Tuple[Any], *args: Any, **kwargs: Any
    ) -> GraphModule:
        """Prepare the model for calibration.

        There are two steps in this process:
            1) export the eager model into a model with ATen IR.
            2) create the `quantizer` according to the `quant_config`, and insert the observers accordingly.
        """
        assert isinstance(example_inputs, tuple), f"Expected `example_inputs` to be a tuple, got {type(example_inputs)}"
        # Set the model to eval mode
        model = model.eval()

        # 1) Capture the FX Graph to be quantized
        dynamic_shapes = kwargs.get("dynamic_shapes", None)
        exported_model = self.export_model(model, example_inputs, dynamic_shapes=dynamic_shapes)
        logger.info("Exported the model to ATen IR successfully.")
        if exported_model is None:
            return

        # 2) Create the `quantizer` according to the `quant_config`, and insert the observers accordingly.
        quantizer = X86InductorQuantizer()
        quantizer = self.update_quantizer_based_on_quant_config(quantizer, quant_config)
        prepared_model = prepare_pt2e(exported_model, quantizer)
        return prepared_model

    def convert(self, model: GraphModule, *args: Any, **kwargs: Any) -> GraphModule:
        """Convert the calibrated model into QDQ mode."""
        fold_quantize = kwargs.get("fold_quantize", False)
        converted_model = convert_pt2e(model, fold_quantize=fold_quantize)
        logger.warning("Converted the model to QDQ mode; please compile it to accelerate inference.")
        return converted_model
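
For orientation, here is a minimal end-to-end usage sketch of the new quantizer. It is not part of this commit: the toy model, batch size, and calibration loop are illustrative placeholders, and the import path follows the new test file added below.

import torch
from torch._inductor import config as inductor_config

from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8StaticQuantizer

# Illustrative toy model and example inputs (placeholders, not from the commit).
model = torch.nn.Sequential(torch.nn.Linear(10, 20), torch.nn.ReLU(), torch.nn.Linear(20, 10))
example_inputs = (torch.randn(4, 10),)

quantizer = W8A8StaticQuantizer()
# 1) Export the eager model to ATen IR and insert observers.
prepared_model = quantizer.prepare(model, quant_config=None, example_inputs=example_inputs)
# 2) Calibrate by running representative inputs through the prepared model.
for _ in range(4):
    prepared_model(*example_inputs)
# 3) Convert the calibrated model into QDQ form, then compile it with Inductor.
converted_model = quantizer.convert(prepared_model)
inductor_config.freezing = True
optimized_model = torch.compile(converted_model)
out = optimized_model(*example_inputs)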

neural_compressor/torch/utils/environ.py

Lines changed: 3 additions & 0 deletions
@@ -53,6 +53,9 @@ def get_ipex_version():
         return None


+TORCH_VERSION_2_2_2 = Version("2.2.2")
+
+
 def get_torch_version():
     try:
         torch_version = torch.__version__.split("+")[0]
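
For context, a short sketch (not from the commit) of how the new constant is consumed elsewhere in this change: both the quantizer and the tests compare the running PyTorch version against it to decide whether `dynamic_shapes` can be forwarded to the export call. It assumes, as the comparison in core.py implies, that `get_torch_version()` returns a `packaging.version.Version`.

from neural_compressor.torch.utils import TORCH_VERSION_2_2_2, get_torch_version

# get_torch_version() is compared directly against the Version constant.
if get_torch_version() <= TORCH_VERSION_2_2_2:
    # torch <= 2.2.2: export without `dynamic_shapes`.
    supports_dynamic_shapes = False
else:
    # torch >= 2.3.0: `dynamic_shapes` can be passed to capture_pre_autograd_graph.
    supports_dynamic_shapes = True
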
Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
import unittest
from unittest.mock import patch

import pytest
import torch

from neural_compressor.common.utils import logger
from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8StaticQuantizer
from neural_compressor.torch.utils import TORCH_VERSION_2_2_2, get_torch_version


class TestW8A8StaticQuantizer:

    @staticmethod
    def get_toy_model():
        class Bar(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
                x = a / (torch.abs(a) + 1)
                if b.sum() < 0:
                    b = b * -1
                return x * b

        inp1 = torch.randn(10)
        inp2 = torch.randn(10)
        example_inputs = (inp1, inp2)
        bar = Bar()
        return bar, example_inputs

    @staticmethod
    def build_simple_torch_model_and_example_inputs():
        class SimpleModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = torch.nn.Linear(10, 20)
                self.fc2 = torch.nn.Linear(20, 10)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                x = self.fc1(x)
                x = torch.nn.functional.relu(x)
                x = self.fc2(x)
                return x

        model = SimpleModel()
        example_inputs = (torch.randn(10, 10),)
        return model, example_inputs

    @pytest.mark.skipif(get_torch_version() <= TORCH_VERSION_2_2_2, reason="Requires torch>=2.3.0")
    def test_quantizer_on_simple_model(self):
        model, example_inputs = self.build_simple_torch_model_and_example_inputs()
        quant_config = None
        w8a8_static_quantizer = W8A8StaticQuantizer()
        # prepare
        prepare_model = w8a8_static_quantizer.prepare(model, quant_config, example_inputs=example_inputs)
        # calibrate
        for i in range(2):
            prepare_model(*example_inputs)
        # convert
        converted_model = w8a8_static_quantizer.convert(prepare_model)
        # inference
        from torch._inductor import config

        config.freezing = True
        opt_model = torch.compile(converted_model)
        out = opt_model(*example_inputs)
        logger.warning("out shape is %s", out.shape)
        assert out is not None

    @pytest.mark.skipif(get_torch_version() <= TORCH_VERSION_2_2_2, reason="Requires torch>=2.3.0")
    def test_quantizer_on_llm(self):
        from transformers import AutoModelForCausalLM, AutoTokenizer

        model_name = "facebook/opt-125m"
        model = AutoModelForCausalLM.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        input_ids = tokenizer("Hello, my dog is cute", return_tensors="pt")["input_ids"]
        example_inputs = (input_ids,)
        quant_config = None
        w8a8_static_quantizer = W8A8StaticQuantizer()
        # prepare
        prepare_model = w8a8_static_quantizer.prepare(model, quant_config, example_inputs=example_inputs)
        # calibrate
        for i in range(2):
            prepare_model(*example_inputs)
        # convert
        converted_model = w8a8_static_quantizer.convert(prepare_model)
        # inference
        from torch._inductor import config

        config.freezing = True
        opt_model = torch.compile(converted_model)
        out = opt_model(*example_inputs)
        assert out.logits is not None

    @patch("neural_compressor.torch.algorithms.pt2e_quant.core.logger.error")
    def test_export_model_failed(self, mock_error):
        model, example_inputs = self.get_toy_model()
        w8a8_static_quantizer = W8A8StaticQuantizer()
        # export model
        exported_model = w8a8_static_quantizer.export_model(model, example_inputs=example_inputs)
        assert exported_model is None
        call_args_list = mock_error.call_args_list
        assert any(["Failed to export the model" in msg for msg in [info[0][0] for info in call_args_list]])
