
Commit 071ab31

Enable the Combination of Multiple Algorithms within a Single Model (#1616)
Signed-off-by: yiliu30 <[email protected]>
1 parent ec91109 commit 071ab31

File tree

4 files changed: +154 additions, −10 deletions

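In practice, the change lets two algorithm configs be combined with `+` and handed to `quantize` as a single config. A minimal usage sketch, assuming a causal-LM whose lm_head should be RTN-quantized while the transformer blocks go through GPTQ (the module patterns mirror the new test added below; the model and run_fn are placeholders, not part of this diff):

    from neural_compressor.torch.quantization import GPTQConfig, RTNConfig, quantize

    # One config per algorithm, each scoped to its own ops via white_list.
    rtn_config = RTNConfig(white_list=["lm_head"])
    gptq_config = GPTQConfig(white_list=["transformer.*"])
    combined_config = rtn_config + gptq_config  # composed config holding both

    # quantize() resolves the composed config into a per-op mapping, and each
    # algorithm entry only picks up the ops whose config carries its name.
    # q_model = quantize(model, combined_config, run_fn=run_fn)  # model/run_fn assumed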

neural_compressor/common/base_config.py

Lines changed: 19 additions & 2 deletions

@@ -457,9 +457,20 @@ def __repr__(self) -> str:
         return f"{self.__class__.__name__} {self.to_json_string()}"

     def to_config_mapping(
-        self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
+        self, config_list: List[BaseConfig] = None, model_info: Dict[str, Any] = None
     ) -> OrderedDict[str, BaseConfig]:
-        return super().to_config_mapping(self.config_list, model_info)
+        config_mapping = OrderedDict()
+        for config in self.config_list:
+            global_config = config.global_config
+            op_type_config_dict, op_name_config_dict = config._get_op_name_op_type_config()
+            single_config_model_info = model_info.get(config.name, None)
+            for op_name, op_type in single_config_model_info:
+                if op_type in op_type_config_dict:
+                    config_mapping[(op_name, op_type)] = op_name_config_dict[op_type]
+                for op_name_pattern in op_name_config_dict:
+                    if re.match(op_name_pattern, op_name):
+                        config_mapping[(op_name, op_type)] = op_name_config_dict[op_name_pattern]
+        return config_mapping

     @classmethod
     def register_supported_configs(cls):

@@ -471,6 +482,12 @@ def get_config_set_for_tuning(cls) -> None:
         # TODO (Yi) handle the composable config in `tuning_config`
         return None

+    def get_model_info(self, model, *args, **kwargs):
+        model_info_dict = dict()
+        for config in self.config_list:
+            model_info_dict.update({config.name: config.get_model_info(model, *args, **kwargs)})
+        return model_info_dict
+

 def get_all_config_set_from_config_registry(fwk_name: str) -> Union[BaseConfig, List[BaseConfig]]:
     all_registered_config_cls: List[BaseConfig] = config_registry.get_all_config_cls_by_fwk_name(fwk_name)
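In short, `get_model_info` on a composed config now returns a dict keyed by each member config's algorithm name, and `to_config_mapping` matches every (op_name, op_type) pair in that member's model info against the member's own op-type and op-name rules. A hedged illustration of the resulting shapes (op names and the "rtn"/"gptq" keys are illustrative, not taken from this diff):

    # Assumed shape of model_info for a combined RTN + GPTQ config:
    model_info = {
        "rtn": [("lm_head", "Linear")],
        "gptq": [("transformer.h.0.mlp.fc_in", "Linear")],
    }
    # to_config_mapping(model_info=model_info) would then yield roughly:
    # OrderedDict({
    #     ("lm_head", "Linear"): <RTNConfig ...>,
    #     ("transformer.h.0.mlp.fc_in", "Linear"): <GPTQConfig ...>,
    # })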

neural_compressor/torch/quantization/algorithm_entry.py

Lines changed: 8 additions & 0 deletions

@@ -41,6 +41,8 @@ def rtn_entry(
     # rebuild weight_config for rtn_quantize function
     weight_config = {}
     for (op_name, op_type), quant_config in configs_mapping.items():
+        if quant_config.name != RTN:
+            continue
         weight_config[op_name] = {
             "dtype": quant_config.dtype,
             "bits": quant_config.bits,

@@ -74,6 +76,8 @@ def gptq_entry(
     # rebuild weight_config for gptq_quantize function
     weight_config = {}
     for (op_name, op_type), quant_config in configs_mapping.items():
+        if quant_config.name != GPTQ:
+            continue
         weight_config[op_name] = {
             "dtype": quant_config.dtype,
             "bits": quant_config.bits,

@@ -120,6 +124,8 @@ def static_quant_entry(
     cfgs = deepcopy(configs_mapping)
     quant_config_mapping["op"] = cfgs
     for (op_name, op_type), cfg in cfgs.items():
+        if cfg.name != STATIC_QUANT:
+            continue
         quant_config_mapping["op"][(op_name, op_type)] = {
             "weight": {
                 "dtype": cfg.w_dtype,

@@ -161,6 +167,8 @@ def awq_quantize_entry(

     weight_config = {}
     for (op_name, op_type), op_config in configs_mapping.items():
+        if op_config.name != AWQ:
+            continue
         if op_config.dtype == "fp32":
             weight_config[op_name] = {
                 "bits": -1,

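All four entries add the same guard: when the composed mapping mixes algorithms, an entry rebuilds per-op settings only for configs carrying its own algorithm name and skips the rest. The shared pattern, written out once (ALGO_NAME stands in for the entry's constant such as RTN, GPTQ, STATIC_QUANT, or AWQ):

    for (op_name, op_type), quant_config in configs_mapping.items():
        if quant_config.name != ALGO_NAME:  # op was assigned to a different algorithm
            continue
        # ...rebuild this entry's per-op config as before...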
test/3x/common/test_common.py

Lines changed: 87 additions & 8 deletions

@@ -40,24 +40,40 @@

 logger = Logger().get_logger()

-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union

 from neural_compressor.common.base_config import (
     BaseConfig,
-    ComposableConfig,
+    config_registry,
     get_all_config_set_from_config_registry,
     register_config,
+    register_supported_configs_for_fwk,
 )
 from neural_compressor.common.base_tuning import ConfigLoader, ConfigSet, SequentialSampler
 from neural_compressor.common.tuning_param import TuningParam
 from neural_compressor.common.utils import DEFAULT_WHITE_LIST, OP_NAME_OR_MODULE_TYPE

 PRIORITY_FAKE_ALGO = 100
 FAKE_CONFIG_NAME = "fake"
+PRIORITY_FAKE_ALGO_1 = 90
+FAKE_CONFIG_NAME_1 = "fake_one"
 DEFAULT_WEIGHT_BITS = [4, 6]

 FAKE_FRAMEWORK_NAME = "FAKE_FWK"

+FAKE_MODEL_INFO = [("OP1_NAME", "OP_TYPE1"), ("OP2_NAME", "OP_TYPE1"), ("OP3_NAME", "OP_TYPE2")]
+
+
+class FakeModel:
+    def __init__(self) -> None:
+        self.name = "fake_model"
+
+    def __call__(self, x) -> Any:
+        return x
+
+    def __repr__(self) -> str:
+        return "FakeModel"
+

 @register_config(framework_name=FAKE_FRAMEWORK_NAME, algo_name=FAKE_CONFIG_NAME, priority=PRIORITY_FAKE_ALGO)
 class FakeAlgoConfig(BaseConfig):

@@ -102,17 +118,14 @@ def register_supported_configs(cls) -> List:
         pass

     @staticmethod
-    def get_model_info(model: Any) -> List[Tuple[str, Callable]]:
-        pass
+    def get_model_info(model: Any) -> List[Tuple[str, Any]]:
+        return FAKE_MODEL_INFO

     @classmethod
     def get_config_set_for_tuning(cls) -> Union[None, "FakeAlgoConfig", List["FakeAlgoConfig"]]:
         return FakeAlgoConfig(weight_bits=DEFAULT_WEIGHT_BITS)


-FakeAlgoConfig.register_supported_configs()
-
-
 def get_default_fake_config() -> FakeAlgoConfig:
     """Generate the default fake config.

@@ -122,10 +135,64 @@ def get_default_fake_config() -> FakeAlgoConfig:
     return FakeAlgoConfig()


+@register_config(framework_name=FAKE_FRAMEWORK_NAME, algo_name=FAKE_CONFIG_NAME_1, priority=PRIORITY_FAKE_ALGO_1)
+class FakeAlgoOneConfig(BaseConfig):
+    """Config class for fake algo."""
+
+    supported_configs: List = []
+    params_list = [
+        "weight_dtype",
+        "weight_bits",
+        TuningParam("target_op_type_list", tunable_type=List[List[str]]),
+    ]
+    name = FAKE_CONFIG_NAME_1
+
+    def __init__(
+        self,
+        weight_dtype: str = "int",
+        weight_bits: int = 4,
+        target_op_type_list: List[str] = ["Conv", "Gemm"],
+        white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
+    ):
+        """Init fake config.
+
+        Args:
+            weight_dtype (str): Data type for weights, default is "int".
+            weight_bits (int): Number of bits used to represent weights, default is 4.
+        """
+        super().__init__(white_list=white_list)
+        self.weight_bits = weight_bits
+        self.weight_dtype = weight_dtype
+        self.target_op_type_list = target_op_type_list
+        self._post_init()
+
+    def to_dict(self):
+        return super().to_dict()
+
+    @classmethod
+    def from_dict(cls, config_dict):
+        return super(FakeAlgoOneConfig, cls).from_dict(config_dict=config_dict)
+
+    @classmethod
+    def register_supported_configs(cls) -> List:
+        pass
+
+    @staticmethod
+    def get_model_info(model: Any) -> List[Tuple[str, Any]]:
+        return FAKE_MODEL_INFO
+
+    @classmethod
+    def get_config_set_for_tuning(cls) -> Union[None, "FakeAlgoOneConfig", List["FakeAlgoOneConfig"]]:
+        return FakeAlgoOneConfig(weight_bits=DEFAULT_WEIGHT_BITS)
+
+
 def get_all_config_set() -> Union[BaseConfig, List[BaseConfig]]:
     return get_all_config_set_from_config_registry(fwk_name=FAKE_FRAMEWORK_NAME)


+register_supported_configs_for_fwk(fwk_name=FAKE_FRAMEWORK_NAME)
+
+
 class TestBaseConfig(unittest.TestCase):
     @classmethod
     def setUpClass(self):

@@ -143,7 +210,7 @@ def test_api(self):
         fake_default_config = get_default_fake_config()
         self.assertEqual(fake_default_config.weight_dtype, "int")
         config_set = get_all_config_set()
-        self.assertEqual(len(config_set), 1)
+        self.assertEqual(len(config_set), len(config_registry.get_all_config_cls_by_fwk_name(FAKE_FRAMEWORK_NAME)))
         self.assertEqual(config_set[0].weight_bits, DEFAULT_WEIGHT_BITS)

     def test_config_expand_complex_tunable_type(self):

@@ -154,6 +221,18 @@ def test_config_expand_complex_tunable_type(self):
         for i in range(len(configs_list)):
             self.assertEqual(configs_list[i].target_op_type_list, target_op_type_list_options[i])

+    def test_mixed_two_algos(self):
+        model = FakeModel()
+        OP1_NAME = "OP1_NAME"
+        OP2_NAME = "OP2_NAME"
+        fake_config = FakeAlgoConfig(weight_bits=4, white_list=[OP1_NAME])
+        fake1_config = FakeAlgoOneConfig(weight_bits=2, white_list=[OP2_NAME])
+        mixed_config = fake_config + fake1_config
+        model_info = mixed_config.get_model_info(model)
+        config_mapping = mixed_config.to_config_mapping(model_info=model_info)
+        self.assertIn(OP1_NAME, [op_info[0] for op_info in config_mapping])
+        self.assertIn(OP2_NAME, [op_info[0] for op_info in config_mapping])
+

 class TestConfigSet(unittest.TestCase):
     def setUp(self):
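Since both fake configs report the same FAKE_MODEL_INFO, the expectation behind test_mixed_two_algos is roughly the following (a hedged reading of the new to_config_mapping logic; the test itself only asserts membership of OP1_NAME and OP2_NAME):

    # model_info returned by the composed config, keyed by algorithm name:
    # {"fake": FAKE_MODEL_INFO, "fake_one": FAKE_MODEL_INFO}
    #
    # config_mapping then routes OP1_NAME to FakeAlgoConfig(weight_bits=4) via
    # white_list=["OP1_NAME"] and OP2_NAME to FakeAlgoOneConfig(weight_bits=2);
    # OP3_NAME matches neither white list, so it is expected to be absent.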
Lines changed: 40 additions & 0 deletions

@@ -0,0 +1,40 @@
+import copy
+from unittest.mock import patch
+
+import pytest
+import torch
+import transformers
+
+from neural_compressor.common.utils import logger
+from neural_compressor.torch.quantization import GPTQConfig, RTNConfig, quantize
+
+
+def run_fn(model):
+    # GPTQ uses ValueError to reduce computation when collecting input data of the first block
+    # It's special for UTs, no need to add this wrapper in examples.
+    with pytest.raises(ValueError):
+        model(torch.tensor([[10, 20, 30]], dtype=torch.long))
+        model(torch.tensor([[40, 50, 60]], dtype=torch.long))
+
+
+class TestMixedTwoAlgo:
+    def test_mixed_gptq_and_rtn(self):
+        with patch.object(logger, "info") as mock_info:
+            rtn_config = RTNConfig(white_list=["lm_head"])
+            gptq_config = GPTQConfig(double_quant_bits=4, white_list=["transformer.*"])
+            combined_config = rtn_config + gptq_config
+            logger.info(combined_config)
+
+            self.tiny_gptj = transformers.AutoModelForCausalLM.from_pretrained(
+                "hf-internal-testing/tiny-random-GPTJForCausalLM",
+            )
+            self.example_inputs = torch.tensor([[10, 20, 30, 40, 50, 60]], dtype=torch.long)
+            # record label for comparison
+            out_original_model = self.tiny_gptj(self.example_inputs)[0]
+            model = copy.deepcopy(self.tiny_gptj)
+            q_model = quantize(model, combined_config, run_fn=run_fn)
+            out_q_model = q_model(self.example_inputs)[0]
+            rtn_log = "Start to apply rtn on the model."
+            gptq_log = "Start to apply gptq on the model."
+            assert rtn_log in [_call[0][0] for _call in mock_info.call_args_list]
+            assert gptq_log in [_call[0][0] for _call in mock_info.call_args_list]
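Design note on the torch test: patching logger.info and scanning mock_info.call_args_list is a lightweight way to confirm that both the RTN and GPTQ entries actually ran on the combined config; out_original_model and out_q_model are computed here but not compared against each other in this test.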
