# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, List, Tuple

import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
import torch.ao.quantization.quantizer.xnnpack_quantizer as xpq
from torch.fx import subgraph_rewriter
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.subgraph_rewriter import Match
from typing_extensions import TypeAlias

from neural_compressor.common import utils

# =============================================================================
# Search and replace patterns
# =============================================================================
TorchFuncType: TypeAlias = Callable[..., Any]


@dataclass
class PatternPair:
    fn: TorchFuncType
    search_pattern: torch.fx.GraphModule
    replace_pattern: torch.fx.GraphModule


# key: torch func
# value: a list of example-argument tuples, one per supported signature
FuncArgsMappingType: TypeAlias = Dict[TorchFuncType, List[Tuple[torch.Tensor, ...]]]


# Align with https://pytorch.org/docs/stable/amp.html#cpu-ops-that-can-autocast-to-bfloat16
# TODO: complete the mapping
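# The zero-sized example tensors below carry no real data; they are only used by
# `make_fx` to trace the aten-level pattern graphs.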
FN_ARGS_MAPPING: FuncArgsMappingType = {
    torch.nn.functional.linear: [
        (torch.randn(0, 0), torch.randn(0, 0)),  # linear w/o bias
        (torch.randn(0, 0), torch.randn(0, 0), torch.randn(0)),  # linear w/ bias
    ],
}
# TODO: complete the mapping
FN_ATEN_OPS_MAPPING = {
    torch.nn.functional.linear: torch.ops.aten.linear.default,
}

SUPPORTED_OPERATORS = FN_ATEN_OPS_MAPPING.values()


PatternRegistryType: TypeAlias = Dict[TorchFuncType, List[PatternPair]]
HALF_PRECISION_PATTERN_REGISTRY: Dict[torch.dtype, PatternRegistryType] = {torch.float16: {}, torch.bfloat16: {}}

# FP16_PATTERN_REGISTRY: PatternRegistryType = HALF_PRECISION_PATTERN_REGISTRY[torch.float16]
# BF16_PATTERN_REGISTRY: PatternRegistryType = HALF_PRECISION_PATTERN_REGISTRY[torch.bfloat16]


def pattern_factory(fn: TorchFuncType, fn_args: Tuple[torch.Tensor, ...], target_dtype: torch.dtype = torch.float16):
    """Create a pair of search and replace patterns for a given torch function and example arguments."""
    assert target_dtype in [
        torch.float16,
        torch.bfloat16,
    ], f"target_dtype should either be `torch.float16` or `torch.bfloat16`, but got {target_dtype}"

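    # The replacement casts the inputs to `target_dtype`, runs `fn`, and casts the result
    # back to float32, so the output dtype seen by the rest of the graph is unchanged.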
    def replace_fn_wrapper(fn_args, fn):
        converted_args = [arg.to(target_dtype) for arg in fn_args]
        target_dtype_out = fn(*converted_args)
        return target_dtype_out.float()

    replace_fn = partial(replace_fn_wrapper, fn=fn)

    search_pattern_gm = make_fx(fn, pre_dispatch=True)(*fn_args)
    # TODO: double-check whether `replace_fn` should be traced with `*fn_args` or `fn_args`
    replace_pattern_gm = make_fx(replace_fn, pre_dispatch=True)(fn_args)

    pattern_pair = PatternPair(fn, search_pattern_gm, replace_pattern_gm)

    return pattern_pair


def _register_pattern_pair(dtype: torch.dtype) -> None:
    for fn, fn_args_list in FN_ARGS_MAPPING.items():
        for fn_args in fn_args_list:
            pattern_pair = pattern_factory(fn, fn_args, target_dtype=dtype)
            HALF_PRECISION_PATTERN_REGISTRY[dtype].setdefault(fn, []).append(pattern_pair)
    num_patterns = sum(len(pairs) for pairs in HALF_PRECISION_PATTERN_REGISTRY[dtype].values())
    utils.logger.info(f"Registered {num_patterns} search and replace patterns for {dtype}.")


_register_pattern_pair(torch.float16)


def get_filter_fn(node_list, fn):
    target_op = FN_ATEN_OPS_MAPPING[fn]

    def is_target_node_in_candidate_list(match, original_graph, pattern_graph):
        """Find the node with the target operator in the match and check whether it is in `node_list`."""
        target_node = None
        for node in pattern_graph.nodes:
            if node.target == target_op:
                target_node = node
                break
        if target_node is None:
            return False
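        # `match.nodes_map` maps each node of the pattern graph to the corresponding
        # node of the original graph.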
        matched_node = match.nodes_map[target_node]
        return matched_node in node_list

    return is_target_node_in_candidate_list


def apply_single_pattern_pair(gm: torch.fx.GraphModule, pattern_pair: PatternPair, node_list):
    filter_fn = get_filter_fn(node_list, pattern_pair.fn)
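    # `replace_pattern_with_filters` only rewrites matches for which every filter
    # returns True, so the rewrite is restricted to the candidate nodes in `node_list`.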
    match_and_replacements = subgraph_rewriter.replace_pattern_with_filters(
        gm=gm,
        pattern=pattern_pair.search_pattern,
        replacement=pattern_pair.replace_pattern,
        match_filters=[filter_fn],
    )
    utils.logger.info(f"Found {len(match_and_replacements)} matches.")

    match_list = [Match(anchor=m.anchor, nodes_map=m.nodes_map) for m in match_and_replacements]
    return match_list


def get_unquantized_node_set(gm: torch.fx.GraphModule):
    unquantized_node_set = set()
    for node in gm.graph.nodes:
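        # Nodes already annotated by the quantizer carry a `QUANT_ANNOTATION_KEY`
        # entry in `node.meta`; those are skipped below.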
        if meta := getattr(node, "meta", None):
            if quantization_annotation := meta.get(xiq.QUANT_ANNOTATION_KEY):
                if quantization_annotation._annotated:
                    continue
            unquantized_node_set.add(node)
    return unquantized_node_set


def transformation(
    gm: torch.fx.GraphModule, node_candidate_list: List[torch.fx.Node], target_dtype: torch.dtype = torch.float16
):
    """Convert the nodes in `node_candidate_list` to `target_dtype` if possible."""
    for pattern_pairs in HALF_PRECISION_PATTERN_REGISTRY[target_dtype].values():
        for pattern_pair in pattern_pairs:
            apply_single_pattern_pair(gm, pattern_pair, node_candidate_list)
    utils.logger.info("Half precision conversion is done:")
    gm.print_readable(True)


# =============================================================================
# Utils to parse the node candidate set for half precision conversion
# =============================================================================


def _parse_node_candidate_set_from_user_config(config, gm):
    """Parse the node candidate set from user config."""
    op_type_configs, op_name_configs = config._get_op_name_op_type_config()
    op_type_filters = []
    op_name_filters = []
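    # Reuse the module-type / module-name filter helpers from the XNNPACK quantizer
    # to select the FX nodes that the user config marks as "fp16".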
    for op_type_name, op_type_config in op_type_configs.items():
        op_type = getattr(torch.nn, op_type_name)
        if op_type_config.act_dtype == "fp16":
            module_type_filter = xpq._get_module_type_filter(op_type)
            op_type_filters.append(module_type_filter)
    for op_name, op_name_config in op_name_configs.items():
        if op_name_config.act_dtype == "fp16":
            module_name_filter = xpq._get_module_name_filter(op_name)
            op_name_filters.append(module_name_filter)
    node_set_from_user_config = set()
    all_filters = op_type_filters + op_name_filters
    for node in gm.graph.nodes:
        if any(filter_fn(node) for filter_fn in all_filters):
            node_set_from_user_config.add(node)
    return node_set_from_user_config


def get_half_precision_node_set(gm, config):
    """Get the set of nodes to convert to half precision.

    The result is the intersection of the unquantized node set and the node set derived
    from the user config, restricted to the supported operators.
    """
    node_set_from_user_config = _parse_node_candidate_set_from_user_config(config, gm)
    unquantized_node_set = get_unquantized_node_set(gm)
    possible_node_set = unquantized_node_set.intersection(node_set_from_user_config)
    half_precision_node_set = set()
    for node in possible_node_set:
        if node.target in SUPPORTED_OPERATORS:
            half_precision_node_set.add(node)
    utils.logger.info(f"Found {len(half_precision_node_set)} nodes to convert to half precision.")
    return half_precision_node_set
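

# Illustrative end-to-end usage (a sketch, not part of this module's API):
# `exported_gm` and `quant_config` are placeholder names for an exported FX
# GraphModule and a user config whose "fp16" entries mark the ops to convert.
#
#   candidates = get_half_precision_node_set(exported_gm, quant_config)
#   transformation(exported_gm, list(candidates), target_dtype=torch.float16)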