
Commit 4a45093

Authored by yiliu30
Add export support for TEQ (#1910)
Signed-off-by: yiliu30 <[email protected]>
1 parent 16a7b11 commit 4a45093

File tree

  • neural_compressor/torch/algorithms/weight_only

1 file changed: +63 −11 lines changed

neural_compressor/torch/algorithms/weight_only/teq.py

Lines changed: 63 additions & 11 deletions
@@ -21,7 +21,7 @@
 import torch
 
 from neural_compressor.torch.algorithms.base_algorithm import Quantizer
-from neural_compressor.torch.utils import get_accelerator, is_transformers_imported, logger
+from neural_compressor.torch.utils import get_accelerator, get_model_device, is_transformers_imported, logger
 
 from .modules import MulLinear, TEQLinearFakeQuant
 from .utility import get_module, quant_tensor, set_module
@@ -265,18 +265,70 @@ def transform(self):
         set_module(self.model, n, m.orig_layer)
 
     @torch.no_grad()
-    def quantize(self):
+    def quantize(self, **kwargs):
         """quantization."""
-
-        for n, m in self.model.named_modules():
-            if self.weight_config.get(n) is None:  # pragma: no cover
-                logger.info(f"quantize layer {n} not in weight config, skip.")
+        use_optimum_format = kwargs.get("use_optimum_format", True)
+        device = get_accelerator().current_device_name()
+        model_device = get_model_device(self.model)  # return model on the same device
+        model = self.model
+        for name, m in model.named_modules():
+            if self.weight_config.get(name) is None:  # pragma: no cover
+                logger.info(f"quantize layer {name} not in weight config, skip.")
                 continue
-            num_bits = self.weight_config[n]["bits"]
-            group_size = self.weight_config[n]["group_size"]
-            scheme = self.weight_config[n]["scheme"]
+            num_bits = self.weight_config[name]["bits"]
+            group_size = self.weight_config[name]["group_size"]
+            scheme = self.weight_config[name]["scheme"]
+            group_dim = self.weight_config[name].get("group_dim", 1)
+            # for only group_dim is 0 or only `transformers.Conv1D`, we need transpose weight.
+            if is_transformers_imported():
+                transpose = (group_dim == 0) ^ (isinstance(m, transformers.Conv1D))
+            else:  # pragma: no cover
+                transpose = group_dim == 0
+            if transpose:  # pragma: no cover
+                weight = m.weight.detach().T.contiguous()
+            else:
+                weight = m.weight.detach()
             if isinstance(m, torch.nn.Linear):  # pragma: no cover
-                quant_tensor(m.weight.data, num_bits=num_bits, group_size=group_size, scheme=scheme)
+                int_weight, scale, zp = quant_tensor(
+                    weight.data,
+                    num_bits=num_bits,
+                    group_size=group_size,
+                    scheme=scheme,
+                    return_int=True,
+                )
+                int_weight = int_weight.t_().contiguous() if transpose else int_weight
+                scale = scale.t_().contiguous() if transpose else scale
+                zp = zp.t_().contiguous() if transpose and zp is not None else zp
+                if isinstance(m, torch.nn.Linear):
+                    in_features = m.in_features
+                    out_features = m.out_features
+                elif is_transformers_imported() and isinstance(m, transformers.Conv1D):
+                    in_features = m.weight.shape[0]
+                    out_features = m.weight.shape[1]
+                    int_weight = int_weight.t_().contiguous()
+                    scale = scale.t_().contiguous()
+                    zp = zp.t_().contiguous() if zp is not None else zp
+                from .modules import WeightOnlyLinear
+
+                new_module = WeightOnlyLinear(
+                    in_features,
+                    out_features,
+                    bits=num_bits,
+                    group_size=group_size,
+                    zp=zp is not None,
+                    bias=m.bias is not None,
+                    use_optimum_format=use_optimum_format,
+                    device=device,
+                )
+                new_module.pack(int_weight, scale, zp, m.bias)
+                if name == "":
+                    return new_module
+                else:
+                    set_module(model, name, new_module)
+                # Move modules back to the model device layer-by-layer
+                m.to(model_device)
+                new_module.to(model_device)
+        self.model = model
 
     def save(self, save_scale_file="", save_state_dict_file=""):
         """
@@ -328,6 +380,6 @@ def convert(self, model, *args: Any, **kwargs: Any):
             setattr(self._quantizer, attr, getattr(model, self._quantizer._PREPARE_ATTRS_PREFIX + attr, None))
         self._quantizer.model = model
         self._quantizer.transform()
-        self._quantizer.quantize()
+        self._quantizer.quantize(**kwargs)
         logger.info("TEQ quantizing done.")
         return self._quantizer.model
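
As a usage note, the only behavioral change in `convert` is that keyword arguments are now forwarded to `quantize`, so a caller can opt into the optimum export layout. A hedged, self-contained sketch of that plumbing with toy stand-ins (the class names below are illustrative, not the library API):

class ToyQuantizer:
    def quantize(self, **kwargs):
        # mirrors the new signature: the flag defaults to True, as in the diff
        use_optimum_format = kwargs.get("use_optimum_format", True)
        print(f"packing weights, optimum layout = {use_optimum_format}")

class ToyTEQWrapper:
    def __init__(self):
        self._quantizer = ToyQuantizer()

    def convert(self, model, *args, **kwargs):
        # after transform(), kwargs are forwarded unchanged to quantize()
        self._quantizer.quantize(**kwargs)
        return model

ToyTEQWrapper().convert(model=None, use_optimum_format=False)
# prints: packing weights, optimum layout = False
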
