Commit efea089

Refine base config for 3.x (#1595)
* Refine base config for 3.x
* Fix the case where group_dim is 0

Signed-off-by: Mengni Wang <[email protected]>
Signed-off-by: yiliu30 <[email protected]>
Co-authored-by: yiliu30 <[email protected]>
1 parent c6f9cca commit efea089

3 files changed: 13 additions & 5 deletions


neural_compressor/common/base_config.py

Lines changed: 3 additions & 2 deletions
@@ -210,8 +210,9 @@ def to_dict(self):

     def get_params_dict(self):
         result = dict()
-        for param in self.params_list:
-            result[param] = getattr(self, param)
+        for param, value in self.__dict__.items():
+            if param not in ["_global_config", "_local_config", "_white_list"]:
+                result[param] = value
         return result

     @classmethod
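
The change above swaps the fixed params_list walk for a scan of the instance __dict__, skipping the internal bookkeeping attributes, so any attribute assigned on the config is exported automatically. A minimal standalone sketch of that behavior (DemoConfig is a hypothetical stand-in, not the real BaseConfig):

class DemoConfig:
    def __init__(self, bits=4, group_size=32):
        self._global_config = None   # internal bookkeeping, excluded from the dump
        self._local_config = {}      # internal bookkeeping, excluded from the dump
        self._white_list = []        # internal bookkeeping, excluded from the dump
        self.bits = bits             # public params are picked up automatically
        self.group_size = group_size

    def get_params_dict(self):
        result = dict()
        for param, value in self.__dict__.items():
            if param not in ["_global_config", "_local_config", "_white_list"]:
                result[param] = value
        return result

print(DemoConfig().get_params_dict())  # {'bits': 4, 'group_size': 32}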

neural_compressor/onnxrt/quantization/config.py

Lines changed: 2 additions & 1 deletion
@@ -224,7 +224,8 @@ def __init__(
             https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/quantize.py#L78
         """
         BaseConfig.__init__(self)
-        StaticQuantConfig.__init__(self, calibration_data_reader=None, **kwargs)
+        kwargs.update({"calibration_data_reader": None})
+        StaticQuantConfig.__init__(self, **kwargs)
         self.alpha = alpha
         self.folding = folding
         self.op_types = op_types
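
The likely motivation for merging the key into kwargs first: passing calibration_data_reader=None explicitly while also expanding **kwargs raises a TypeError whenever the caller already supplied that key, whereas updating kwargs beforehand overrides it without a conflict. A small sketch with a hypothetical parent_init standing in for StaticQuantConfig.__init__:

def parent_init(calibration_data_reader=None, per_channel=False):
    print(calibration_data_reader, per_channel)

kwargs = {"calibration_data_reader": "some_reader", "per_channel": True}

# parent_init(calibration_data_reader=None, **kwargs)
#   -> TypeError: got multiple values for keyword argument 'calibration_data_reader'

kwargs.update({"calibration_data_reader": None})  # approach used in the commit
parent_init(**kwargs)                             # prints: None True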

neural_compressor/torch/algorithms/weight_only/rtn.py

Lines changed: 8 additions & 2 deletions
@@ -122,7 +122,10 @@ def rtn_quantize(
             continue
         logger.debug(f"RTN quantized module:{name, m}")
         logger.debug(log_msg)
-        weight = m.weight.t_().contiguous() if group_dim == 0 else m.weight
+        if group_dim == 0:
+            weight = m.weight.t_().contiguous()
+        else:
+            weight = m.weight
         if use_mse_search:
             quantile = search_clip(m, bits, group_size, scheme, dtype, use_full_range)
         if export_compressed_model:
@@ -169,6 +172,9 @@ def rtn_quantize(
                 full_range=use_full_range,
                 **double_quant_config,
             )
-            weight = weight.t_().contiguous() if group_dim == 0 else weight
+            if group_dim == 0:
+                # for group_dim is 0, we need to transpose the quantized tensor and module's weight back
+                weight = weight.t_().contiguous()
+                m.weight.t_().contiguous()
             m.weight.data.copy_(weight)
     return model
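
Because the quantization kernel groups along the last weight dimension, a group_dim of 0 is handled by transposing the weight in place before quantization; the second hunk adds the matching in-place transpose of the module's weight back, so the final copy_ sees tensors of the same shape. A standalone sketch of that round trip (PyTorch assumed; fake_quantize is a placeholder, not the real RTN quantizer):

import torch

def fake_quantize(w):
    # placeholder for the real round-to-nearest kernel
    return torch.round(w * 4) / 4

weight_param = torch.randn(4, 8)      # stands in for m.weight
group_dim = 0

if group_dim == 0:
    # in-place transpose, mirroring m.weight.t_().contiguous()
    weight = weight_param.t_().contiguous()
else:
    weight = weight_param

weight = fake_quantize(weight)

if group_dim == 0:
    # transpose the quantized tensor and the original weight back
    weight = weight.t_().contiguous()
    weight_param.t_().contiguous()

weight_param.data.copy_(weight)
print(weight_param.shape)             # torch.Size([4, 8])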
