
Commit fa961e1

Support mixed INT8 + FP16 in one model (#1798)
Signed-off-by: yiliu30 <[email protected]>
1 parent 4a24a6a

File tree: 6 files changed, +347 −12 lines changed
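In effect, ops whose local config requests fp16 are left out of INT8 quantization and handed to the new half-precision rewriter, so a single pt2e model can mix INT8 and FP16 kernels. A minimal sketch of that intent from the user side (`StaticQuantConfig` and `set_local` are assumed INC 3.x API names, inferred from the filter logic in this diff rather than shown in it):

# Hedged sketch: the config API names here are assumptions, not part of this commit.
from neural_compressor.torch.quantization import StaticQuantConfig

quant_config = StaticQuantConfig()  # global default: INT8 static quantization
# Keep `fc2` out of INT8; the half-precision rewriter converts it to FP16 instead.
quant_config.set_local("fc2", StaticQuantConfig(w_dtype="fp16", act_dtype="fp16"))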

neural_compressor/torch/algorithms/pt2e_quant/core.py

Lines changed: 8 additions & 0 deletions
@@ -25,6 +25,7 @@
 from neural_compressor.common.utils import logger
 from neural_compressor.torch.algorithms.base_algorithm import Quantizer
+from neural_compressor.torch.algorithms.pt2e_quant import half_precision_rewriter as hp_rewriter
 from neural_compressor.torch.utils import create_xiq_quantizer_from_pt2e_config
@@ -61,4 +62,11 @@ def convert(self, model: GraphModule, *args: Any, **kwargs: Any) -> GraphModule:
         fold_quantize = kwargs.get("fold_quantize", False)
         converted_model = convert_pt2e(model, fold_quantize=fold_quantize)
         logger.warning("Converted the model in qdq mode, please compile it to accelerate inference.")
+        if self.quant_config:
+            self.half_precision_transformation(converted_model, self.quant_config)
         return converted_model
+
+    def half_precision_transformation(self, model, config):
+        half_precision_node_set = hp_rewriter.get_half_precision_node_set(model, config)
+        logger.info("Try to convert %d nodes to half precision.", len(half_precision_node_set))
+        hp_rewriter.transformation(model, half_precision_node_set)
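The new hook rewrites each eligible node so that its inputs are cast to float16, the op runs in half precision, and the result is cast back to float32. A minimal sketch of that replacement pattern, traced the same way `pattern_factory` in `half_precision_rewriter.py` (next file) traces it; the function name and the zero-sized example tensors are illustrative:

import torch
from torch.fx.experimental.proxy_tensor import make_fx


def fp16_linear(x, w, b):
    # Mirrors `replace_fn_wrapper` below: cast in, compute in fp16, cast back out.
    return torch.nn.functional.linear(x.to(torch.float16), w.to(torch.float16), b.to(torch.float16)).float()


# Zero-sized tensors, the same trick FN_ARGS_MAPPING uses to avoid real compute.
gm = make_fx(fp16_linear, pre_dispatch=True)(torch.randn(0, 0), torch.randn(0, 0), torch.randn(0))
gm.print_readable()  # shows the aten.to casts around aten.linear.default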
neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py

Lines changed: 193 additions & 0 deletions
@@ -0,0 +1,193 @@
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, List, Tuple

import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
import torch.ao.quantization.quantizer.xnnpack_quantizer as xpq
from torch.fx import subgraph_rewriter
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.subgraph_rewriter import Match
from typing_extensions import TypeAlias

from neural_compressor.common import utils

# =============================================================================
# Search and replace patterns
# =============================================================================
TorchFuncType: TypeAlias = Callable[..., Any]


@dataclass
class PatternPair:
    fn: TorchFuncType
    search_pattern: torch.fx.GraphModule
    replace_pattern: torch.fx.GraphModule


# key: torch func
# value: the tuple of args
FuncArgsMappingType: TypeAlias = Dict[TorchFuncType, Tuple[torch.Tensor, ...]]


# Align with https://pytorch.org/docs/stable/amp.html#cpu-ops-that-can-autocast-to-bfloat16
# TODO: complete the mapping
FN_ARGS_MAPPING: FuncArgsMappingType = {
    torch.nn.functional.linear: (torch.randn(0, 0), torch.randn(0, 0)),  # linear w/o bias
    torch.nn.functional.linear: (torch.randn(0, 0), torch.randn(0, 0), torch.randn(0)),  # linear w/ bias
}
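# NOTE: dict keys must be unique, so the second `linear` entry above overwrites
# the first; as written, only the with-bias pattern is actually registered.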
# TODO: complete the mapping
FN_ATEN_OPS_MAPPING = {
    torch.nn.functional.linear: torch.ops.aten.linear.default,
}

SUPPORTED_OPERATORS = FN_ATEN_OPS_MAPPING.values()


PatternRegistryType: TypeAlias = Dict[TorchFuncType, PatternPair]
HALF_PRECISION_PATTERN_REGISTRY: Dict[torch.dtype, PatternRegistryType] = {torch.float16: {}, torch.bfloat16: {}}

# FP16_PATTERN_REGISTRY: PatternRegistryType = HALF_PRECISION_PATTERN_REGISTRY[torch.float16]
# BF16_PATTERN_REGISTRY: PatternRegistryType = HALF_PRECISION_PATTERN_REGISTRY[torch.bfloat16]


def pattern_factory(fn: TorchFuncType, fn_arg: Tuple[torch.Tensor, ...], target_dtype: torch.dtype = torch.float16):
    """Create a pair of search and replace patterns for a given torch function and its arguments."""
    assert target_dtype in [
        torch.float16,
        torch.bfloat16,
    ], f"target_dtype should either be `torch.float16` or `torch.bfloat16`, but got {target_dtype}"

    def replace_fn_wrapper(fn_args, fn):
        converted_args = [arg.to(target_dtype) for arg in fn_args]
        target_dtype_out = fn(*converted_args)
        return target_dtype_out.float()

    replace_fn = partial(replace_fn_wrapper, fn=fn)

    search_pattern_gm = make_fx(fn, pre_dispatch=True)(*fn_arg)
    # TODO: double-check `*fn_args` or `fn_args`
    replace_pattern_gm = make_fx(replace_fn, pre_dispatch=True)(fn_arg)

    pattern_pair = PatternPair(fn, search_pattern_gm, replace_pattern_gm)

    return pattern_pair
def _register_pattern_pair(dtype: torch.dtype) -> None:
    for fn, fn_args in FN_ARGS_MAPPING.items():
        pattern_pair = pattern_factory(fn, fn_args)
        HALF_PRECISION_PATTERN_REGISTRY[dtype][fn] = pattern_pair
    utils.logger.info(
        f"Registered {len(HALF_PRECISION_PATTERN_REGISTRY[dtype])} search and replace patterns for {dtype}."
    )


_register_pattern_pair(torch.float16)


def get_filter_fn(node_list, fn):
    target_op = FN_ATEN_OPS_MAPPING[fn]

    def is_target_node_in_candidate_list(match, original_graph, pattern_graph):
        """Filter the node with target operator in match and check if it is in `node_list`."""
        target_node = None
        for node in pattern_graph.nodes:
            if node.target == target_op:
                target_node = node
                break
        if target_node is None:
            return False
        matched_node = match.nodes_map[target_node]
        return matched_node in node_list

    return is_target_node_in_candidate_list


def apply_single_pattern_pair(gm: torch.fx.GraphModule, pattern_pair: PatternPair, node_list):
    filter_fn = get_filter_fn(node_list, pattern_pair.fn)
    match_and_replacements = subgraph_rewriter.replace_pattern_with_filters(
        gm=gm,
        pattern=pattern_pair.search_pattern,
        replacement=pattern_pair.replace_pattern,
        match_filters=[filter_fn],
    )
    utils.logger.info(f"Found {len(match_and_replacements)} matches.")

    match_list = [Match(anchor=m.anchor, nodes_map=m.nodes_map) for m in match_and_replacements]
    return match_list


def get_unquantized_node_set(gm: torch.fx.GraphModule):
    unquantized_node_set = set()
    for node in gm.graph.nodes:
        if meta := getattr(node, "meta"):
            if quantization_annotation := meta.get(xiq.QUANT_ANNOTATION_KEY):
                if quantization_annotation._annotated:
                    continue
        unquantized_node_set.add(node)
    return unquantized_node_set


def transformation(gm: torch.fx.GraphModule, node_candidate_list: List[str], target_dtype: torch.dtype = torch.float16):
    """Convert the nodes in `node_candidate_list` to `target_dtype` if possible."""
    for pattern_pair in HALF_PRECISION_PATTERN_REGISTRY[target_dtype].values():
        apply_single_pattern_pair(gm, pattern_pair, node_candidate_list)
    utils.logger.info("Half precision conversion is done:")
    gm.print_readable(True)


# =============================================================================
# Utils to parse the node candidate set for half precision conversion
# =============================================================================


def _parse_node_candidate_set_from_user_config(config, gm):
    """Parse the node candidate set from user config."""
    op_type_configs, op_name_configs = config._get_op_name_op_type_config()
    op_type_filters = []
    op_name_filters = []
    for op_type_name, config in op_type_configs.items():
        op_type = getattr(torch.nn, op_type_name)
        if config.act_dtype == "fp16":
            filter = xpq._get_module_type_filter(op_type)
            op_type_filters.append(filter)
    for op_name, config in op_name_configs.items():
        if config.act_dtype == "fp16":
            filter = xpq._get_module_name_filter(op_name)
            op_name_filters.append(filter)
    node_set_from_user_config = set()
    all_filters = op_type_filters + op_name_filters
    for node in gm.graph.nodes:
        if any([filter(node) for filter in all_filters]):
            node_set_from_user_config.add(node)
    return node_set_from_user_config


def get_half_precision_node_set(gm, config):
    """Intersection between `unquantized_node_set` and `node_set_from_user_config`."""
    # TODO: implement it, current return all unquantized_node_set

    node_set_from_user_config = _parse_node_candidate_set_from_user_config(config, gm)
    unquantized_node_set = get_unquantized_node_set(gm)
    possible_node_set = unquantized_node_set.intersection(node_set_from_user_config)
    half_precision_node_set = set()
    for node in possible_node_set:
        if node.target in SUPPORTED_OPERATORS:
            half_precision_node_set.add(node)
    utils.logger.info(f"Found {len(half_precision_node_set)} nodes to convert to half precision.")
    return half_precision_node_set
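Driving the rewriter with an explicit candidate set shows how the INT8/FP16 split is enforced: `get_filter_fn` rejects any match whose linear node is not in the set. A small self-contained sketch (the `Tiny` model and the first-linear selection are illustrative; the real flow derives the set from the user config via `get_half_precision_node_set`):

import torch

from neural_compressor.torch import export
from neural_compressor.torch.algorithms.pt2e_quant import half_precision_rewriter as hp_rewriter


class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(10, 10)
        self.fc2 = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.fc2(self.fc1(x))


gm = export.export_model_for_pt2e_quant(model=Tiny(), example_inputs=(torch.randn(2, 10),))

# Hand-pick only the first linear; matches outside this set are filtered out.
linear_nodes = [n for n in gm.graph.nodes if n.target == torch.ops.aten.linear.default]
hp_rewriter.transformation(gm, {linear_nodes[0]})  # prints the rewritten module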

neural_compressor/torch/utils/utility.py

Lines changed: 1 addition & 7 deletions
@@ -226,11 +226,5 @@ def create_xiq_quantizer_from_pt2e_config(config, is_dynamic=False) -> X86InductorQuantizer:
     # set global
     global_config = _map_inc_config_to_torch_quant_config(config, is_dynamic)
     quantizer.set_global(global_config)
-    # set local
-    for module_or_func_name, local_config in config.local_config.items():
-        local_quant_config = _map_inc_config_to_torch_quant_config(local_config, is_dynamic)
-        if isinstance(module_or_func_name, torch.nn.Module):
-            quantizer.set_module_type_qconfig(module_or_func_name, local_quant_config)
-        else:
-            quantizer.set_function_type_qconfig(module_or_func_name, local_quant_config)
+    # Skip the local config for now (need torch 2.4)
     return quantizer
Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
import pytest
import torch
import torch.testing._internal.common_quantization as torch_test_quant_common

from neural_compressor.torch import export
from neural_compressor.torch import utils as torch_utils
from neural_compressor.torch.algorithms.pt2e_quant import half_precision_rewriter


class TestHalfPrecisionConverter(torch_test_quant_common.QuantizationTestCase):

    @staticmethod
    def build_simple_torch_model_and_example_inputs():
        class SimpleModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = torch.nn.Linear(10, 20)
                self.fc2 = torch.nn.Linear(20, 10)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                x = self.fc1(x)
                x = torch.nn.functional.relu(x)
                x = self.fc2(x)
                return x

        model = SimpleModel()
        example_inputs = (torch.randn(10, 10),)
        return model, example_inputs

    @pytest.mark.skipif(
        torch_utils.get_torch_version() <= torch_utils.TORCH_VERSION_2_2_2, reason="Requires torch>=2.3.0"
    )
    def test_quantizer_on_simple_model(self):
        model, example_inputs = self.build_simple_torch_model_and_example_inputs()
        exported_model = export.export_model_for_pt2e_quant(model=model, example_inputs=example_inputs)
        print("Exported model:")
        exported_model.print_readable()
        unquantized_node_set = half_precision_rewriter.get_unquantized_node_set(exported_model)
        print("Before apply half precision rewriter:")
        exported_model.print_readable(True)
        half_precision_rewriter.transformation(exported_model, unquantized_node_set)
        print("After apply half precision rewriter:")
        exported_model.print_readable(True)
        expected_node_occurrence = {
            # 4 `aten.to` for each `aten.linear`
            torch.ops.aten.to.dtype: 8,
            torch.ops.aten.linear.default: 2,
        }
        expected_node_occurrence = {
            torch_test_quant_common.NodeSpec.call_function(k): v for k, v in expected_node_occurrence.items()
        }
        self.checkGraphModuleNodes(exported_model, expected_node_occurrence=expected_node_occurrence)
