|
21 | 21 | import torch |
22 | 22 |
|
23 | 23 | from neural_compressor.torch.algorithms.base_algorithm import Quantizer |
24 | | -from neural_compressor.torch.utils import get_accelerator, is_transformers_imported, logger |
| 24 | +from neural_compressor.torch.utils import get_accelerator, get_model_device, is_transformers_imported, logger |
25 | 25 |
|
26 | 26 | from .modules import MulLinear, TEQLinearFakeQuant |
27 | 27 | from .utility import get_module, quant_tensor, set_module |
@@ -265,18 +265,70 @@ def transform(self): |
265 | 265 | set_module(self.model, n, m.orig_layer) |
266 | 266 |
|
267 | 267 | @torch.no_grad() |
268 | | - def quantize(self): |
| 268 | + def quantize(self, **kwargs): |
269 | 269 | """quantization.""" |
270 | | - |
271 | | - for n, m in self.model.named_modules(): |
272 | | - if self.weight_config.get(n) is None: # pragma: no cover |
273 | | - logger.info(f"quantize layer {n} not in weight config, skip.") |
| 270 | + use_optimum_format = kwargs.get("use_optimum_format", True) |
| 271 | + device = get_accelerator().current_device_name() |
| 272 | + model_device = get_model_device(self.model)  # record the model's original device so layers can be moved back after packing |
| 273 | + model = self.model |
| 274 | + for name, m in model.named_modules(): |
| 275 | + if self.weight_config.get(name) is None: # pragma: no cover |
| 276 | + logger.info(f"quantize layer {name} not in weight config, skip.") |
274 | 277 | continue |
275 | | - num_bits = self.weight_config[n]["bits"] |
276 | | - group_size = self.weight_config[n]["group_size"] |
277 | | - scheme = self.weight_config[n]["scheme"] |
| 278 | + num_bits = self.weight_config[name]["bits"] |
| 279 | + group_size = self.weight_config[name]["group_size"] |
| 280 | + scheme = self.weight_config[name]["scheme"] |
| 281 | + group_dim = self.weight_config[name].get("group_dim", 1) |
| 282 | + # transpose the weight when exactly one of these holds: group_dim is 0, or the module is a `transformers.Conv1D` |
| 283 | + if is_transformers_imported(): |
| 284 | + transpose = (group_dim == 0) ^ (isinstance(m, transformers.Conv1D)) |
| 285 | + else: # pragma: no cover |
| 286 | + transpose = group_dim == 0 |
| 287 | + if transpose: # pragma: no cover |
| 288 | + weight = m.weight.detach().T.contiguous() |
| 289 | + else: |
| 290 | + weight = m.weight.detach() |
278 | 291 | if isinstance(m, torch.nn.Linear): # pragma: no cover |
279 | | - quant_tensor(m.weight.data, num_bits=num_bits, group_size=group_size, scheme=scheme) |
| 292 | + int_weight, scale, zp = quant_tensor( |
| 293 | + weight.data, |
| 294 | + num_bits=num_bits, |
| 295 | + group_size=group_size, |
| 296 | + scheme=scheme, |
| 297 | + return_int=True, |
| 298 | + ) |
| 299 | + int_weight = int_weight.t_().contiguous() if transpose else int_weight |
| 300 | + scale = scale.t_().contiguous() if transpose else scale |
| 301 | + zp = zp.t_().contiguous() if transpose and zp is not None else zp |
| 302 | + if isinstance(m, torch.nn.Linear): |
| 303 | + in_features = m.in_features |
| 304 | + out_features = m.out_features |
| 305 | + elif is_transformers_imported() and isinstance(m, transformers.Conv1D): |
| 306 | + in_features = m.weight.shape[0] |
| 307 | + out_features = m.weight.shape[1] |
| 308 | + int_weight = int_weight.t_().contiguous() |
| 309 | + scale = scale.t_().contiguous() |
| 310 | + zp = zp.t_().contiguous() if zp is not None else zp |
| 311 | + from .modules import WeightOnlyLinear |
| 312 | + |
| 313 | + new_module = WeightOnlyLinear( |
| 314 | + in_features, |
| 315 | + out_features, |
| 316 | + bits=num_bits, |
| 317 | + group_size=group_size, |
| 318 | + zp=zp is not None, |
| 319 | + bias=m.bias is not None, |
| 320 | + use_optimum_format=use_optimum_format, |
| 321 | + device=device, |
| 322 | + ) |
| 323 | + new_module.pack(int_weight, scale, zp, m.bias) |
| 324 | + if name == "": |
| 325 | + return new_module |
| 326 | + else: |
| 327 | + set_module(model, name, new_module) |
| 328 | + # Move modules back to the model device layer-by-layer |
| 329 | + m.to(model_device) |
| 330 | + new_module.to(model_device) |
| 331 | + self.model = model |
280 | 332 |
|
281 | 333 | def save(self, save_scale_file="", save_state_dict_file=""): |
282 | 334 | """ |
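
The XOR in the transpose decision above is the part that is easiest to misread: the weight is transposed before quantization when exactly one of `group_dim == 0` or `isinstance(m, transformers.Conv1D)` holds, presumably because `transformers.Conv1D` stores its weight transposed relative to `torch.nn.Linear`, so the two conditions cancel when both apply. A minimal sketch of just that decision, with a hypothetical `is_conv1d` flag standing in for the `isinstance` check:

```python
def decide_transpose(group_dim: int, is_conv1d: bool) -> bool:
    # Either condition alone calls for a transpose; when both hold they
    # cancel out, hence the XOR used in the diff above.
    return (group_dim == 0) ^ is_conv1d

# nn.Linear weight with the default group_dim=1: no transpose.
assert decide_transpose(group_dim=1, is_conv1d=False) is False
# nn.Linear weight with groups along dim 0: transpose before quantizing.
assert decide_transpose(group_dim=0, is_conv1d=False) is True
# Conv1D weight with group_dim=0: the two transposes cancel.
assert decide_transpose(group_dim=0, is_conv1d=True) is False
```

After quantization, the same flag drives the `.t_().contiguous()` calls that flip `int_weight`, `scale`, and `zp` back to the layer's original layout before packing.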
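
For readers who have not seen `quant_tensor(..., return_int=True)` before, the triple it returns — integer weight, per-group scale, and optional zero point — is the usual per-group affine quantization result. The sketch below is a self-contained toy version of that idea, not the library's implementation (the actual group layout, rounding, and `scheme` handling in `quant_tensor` may differ):

```python
import torch

def toy_group_quant(weight: torch.Tensor, num_bits: int = 4, group_size: int = 32):
    """Illustrative asymmetric per-group quantization along the last dim."""
    out_features, in_features = weight.shape
    w = weight.reshape(out_features, in_features // group_size, group_size)
    w_min = w.amin(dim=-1, keepdim=True)
    w_max = w.amax(dim=-1, keepdim=True)
    scale = (w_max - w_min).clamp(min=1e-5) / (2**num_bits - 1)
    zp = torch.round(-w_min / scale)
    int_w = torch.clamp(torch.round(w / scale) + zp, 0, 2**num_bits - 1)
    # Dequantized approximation: (int_w - zp) * scale ~= w
    return int_w.reshape_as(weight), scale.squeeze(-1), zp.squeeze(-1)

w = torch.randn(8, 64)
int_w, scale, zp = toy_group_quant(w)
w_hat = ((int_w.reshape(8, -1, 32) - zp.unsqueeze(-1)) * scale.unsqueeze(-1)).reshape_as(w)
print(torch.mean((w - w_hat) ** 2))  # small reconstruction error
```

Under a symmetric `scheme` the zero point is typically omitted, which is why every use of `zp` in the diff is guarded with `zp is not None` and `WeightOnlyLinear` receives `zp=zp is not None`.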
@@ -328,6 +380,6 @@ def convert(self, model, *args: Any, **kwargs: Any): |
328 | 380 | setattr(self._quantizer, attr, getattr(model, self._quantizer._PREPARE_ATTRS_PREFIX + attr, None)) |
329 | 381 | self._quantizer.model = model |
330 | 382 | self._quantizer.transform() |
331 | | - self._quantizer.quantize() |
| 383 | + self._quantizer.quantize(**kwargs) |
332 | 384 | logger.info("TEQ quantizing done.") |
333 | 385 | return self._quantizer.model |
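
The final change simply forwards `convert`'s keyword arguments to `quantize`, which is what makes the `use_optimum_format` flag reachable from the public entry point. An illustrative call, assuming `teq_quantizer` is an instance of the quantizer class this diff modifies and `prepared_model` is the output of its prepare/calibration step (both names are placeholders, not part of this diff):

```python
# use_optimum_format is read inside quantize() via
# kwargs.get("use_optimum_format", True); passing it to convert() now reaches it.
q_model = teq_quantizer.convert(prepared_model, use_optimum_format=True)

# Each configured layer is replaced by a WeightOnlyLinear module holding the
# packed int_weight, scale, and (optional) zp produced above.
```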