Skip to content

Commit 9150181

Browse files
authored
Port torch GPTQ to 3.x (#1408)
Signed-off-by: yiliu30 <[email protected]>
1 parent bb60e33 commit 9150181

File tree

15 files changed

+1595
-121
lines changed

15 files changed

+1595
-121
lines changed

.azure-pipelines/scripts/ut/run_itrex.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ echo "run itrex ut..."
66

77
# prepare itrex
88
git clone https://github.com/intel/intel-extension-for-transformers.git /intel-extension-for-transformers
9+
cd /intel-extension-for-transformers && git rev-parse --short HEAD
910
bash /intel-extension-for-transformers/.github/workflows/script/prepare_env.sh
1011
bash /intel-extension-for-transformers/.github/workflows/script/install_binary.sh
1112

neural_compressor/common/base_config.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -118,12 +118,16 @@ def from_dict(cls, config_dict, str2operator=None):
118118
Returns:
119119
The constructed config.
120120
"""
121-
config = cls(**config_dict.get(GLOBAL, {}))
122-
operator_config = config_dict.get(LOCAL, {})
123-
if operator_config:
124-
for op_name, op_config in operator_config.items():
125-
config.set_local(op_name, cls(**op_config))
126-
return config
121+
if GLOBAL not in config_dict and LOCAL not in config_dict:
122+
config = cls(**config_dict)
123+
return config
124+
else:
125+
config = cls(**config_dict.get(GLOBAL, {}))
126+
operator_config = config_dict.get(LOCAL, {})
127+
if operator_config:
128+
for op_name, op_config in operator_config.items():
129+
config.set_local(op_name, cls(**op_config))
130+
return config
127131

128132
@classmethod
129133
def to_diff_dict(cls, instance) -> Dict[str, Any]:
@@ -201,11 +205,11 @@ def to_config_mapping(
201205
global_config = config.global_config
202206
op_type_config_dict, op_name_config_dict = config._get_op_name_op_type_config()
203207
for op_name, op_type in model_info:
204-
config_mapping.setdefault(op_type, OrderedDict())[op_name] = global_config
208+
config_mapping[(op_type, op_name)] = global_config
205209
if op_type in op_type_config_dict:
206-
config_mapping[op_type][op_name] = op_name_config_dict[op_type]
210+
config_mapping[(op_type, op_name)] = op_name_config_dict[op_type]
207211
if op_name in op_name_config_dict:
208-
config_mapping[op_type][op_name] = op_name_config_dict[op_name]
212+
config_mapping[(op_type, op_name)] = op_name_config_dict[op_name]
209213
return config_mapping
210214

211215
@staticmethod
@@ -234,9 +238,15 @@ def to_dict(self, params_list=[], operator2str=None):
234238
return result
235239

236240
@classmethod
237-
def from_dict(cls, config_dict, str2operator=None):
238-
# TODO(Yi)
239-
pass
241+
def from_dict(cls, config_dict: "OrderedDict[str, Dict]", config_registry: "Dict[str, BaseConfig]"):
    """Rebuild a composed config from an ordered mapping of per-config dicts.

    Each entry of ``config_dict`` is turned back into a config object by
    looking up its name in ``config_registry`` and calling that entry's
    ``from_dict`` on the stored value. The resulting objects are combined
    with ``+=`` in the mapping's iteration order.

    Args:
        config_dict: Mapping from config name to that config's dict form.
            Must contain at least one entry.
        config_registry: Mapping from config name to an object exposing
            ``from_dict(value)``.

    Returns:
        The first config combined (via ``+=``) with every following one.

    Raises:
        AssertionError: If ``config_dict`` is empty.
        KeyError: If a config name is missing from ``config_registry``.
    """
    assert len(config_dict) >= 1, "The config dict must include at least one configuration."
    # Create the iterator once: calling iter() anew on every step would
    # yield the first entry each time and never reach the remaining ones.
    entries = iter(config_dict.items())
    name, value = next(entries)
    config = config_registry[name].from_dict(value)
    for name, value in entries:
        config += config_registry[name].from_dict(value)
    return config
240250

241251
def to_json_string(self, use_diff: bool = False) -> str:
242252
return json.dumps(self.to_dict(), indent=2) + "\n"

neural_compressor/common/utility.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,5 @@
2626
BASE_CONFIG = "base_config"
2727
COMPOSABLE_CONFIG = "composable_config"
2828
RTN_WEIGHT_ONLY_QUANT = "rtn_weight_only_quant"
29+
GPTQ = "gptq"
2930
DUMMY_CONFIG = "dummy_config"

neural_compressor/torch/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
1313
# limitations under the License.
1414

1515
from neural_compressor.torch.utils import register_algo
16-
from neural_compressor.torch.algorithms import rtn_quantize_entry
16+
from neural_compressor.torch.algorithms import rtn_quantize_entry, gptq_quantize_entry
1717

1818
from neural_compressor.torch.quantization import (
1919
quantize,
2020
RTNWeightQuantConfig,
2121
get_default_rtn_config,
2222
DummyConfig,
2323
get_default_dummy_config,
24+
GPTQConfig,
25+
get_default_gptq_config,
2426
)

neural_compressor/torch/algorithms/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,5 @@
1313
# limitations under the License.
1414

1515

16-
from neural_compressor.torch.algorithms.rtn_quantize import rtn_quantize_entry
16+
from neural_compressor.torch.algorithms.weight_only_algos import rtn_quantize_entry
17+
from neural_compressor.torch.algorithms.weight_only_algos import gptq_quantize_entry

0 commit comments

Comments
 (0)