Motivation
- To design and implement a better quantization part of MMRazor together with the community.
- To collect more requirements and suggestions before releasing quantization, through this RFC (Request for Comments).
Overview
MMRazor quantization will be an algorithm platform, not just a provider of basic quantization function APIs. We hope it will help you in the following ways:
- Compress and deploy your models faster.
- Produce better models with our quantization algorithms.
- Implement novel quantization algorithms more easily.
Goals
- Support implementing mainstream QAT and PTQ algorithms, such as LSQ, AdaRound and so on.
- Support a complete working pipeline from quantization to deployment. You can deploy quantized models on multiple backends with MMDeploy.
- Adapt to OpenMMLab 2.0, so that upstream OpenMMLab repositories can be supported uniformly without extra code.
- Be easy to use. You can quantize your model just by modifying the config and running a script, rather than modifying the source model.
Algorithms
We plan to support the following quantization algorithms in the future. Welcome to propose your requirements.
QAT
- LSQ
- LSQ+
- IAO
......
PTQ
- AdaRound
- BRECQ
- QDrop
......
Main features
We list the main features to be supported below. Comments are welcome.
- Quantization types: QAT and PTQ (static/dynamic)
- Quantization bits: 1 ~ 32. Note: 1 bit here means ordinary 1-bit quantization, not binarization.
- Quantization methods (uniform quantization), see the sketch after this list:
  - per_tensor / per_channel
  - symmetric / asymmetric
  - FP_scale / PoT_scale (power of two)
- Multiple backends:
  - TensorRT
  - SNPE
  - ncnn
  - .....
Due to limited manpower, some of the algorithms and features above will be implemented over the next several versions; PRs that speed up development are welcome.
Most features will be included in the first release, except dynamic quantization and support for more backends. The quantization algorithms will be released in batches over the next two versions.
Release plan
We will release our first version in December 2022 if everything goes well.
Design and implementation
We will implement our design by extending PyTorch's basic quantization function APIs and torch.fx. Some PyTorch modules will be inherited, and some new modules will be created.
User-friendly config
We use Qscheme to convert the user-friendly config into API-oriented parameters. A demo config is as follows.
```python
_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py']

model = dict(
    _delete_=True,
    type='mmrazor.GeneralQuant',
    architecture=_base_.model,
    quantizer=dict(
        type='mmrazor.CustomQuantizer',
        is_qat=False,
        # `skipped_methods` makes the model traceable automatically by
        # skipping these untraceable methods.
        skipped_methods=[
            'mmcls.models.heads.ClsHead._get_loss',
            'mmcls.models.heads.ClsHead._get_predictions'
        ],
        qconfig=dict(
            qtype='affine',
            w_observer=dict(type='mmrazor.MSEObserver'),
            a_observer=dict(type='mmrazor.EMAMSEObserver'),
            w_fake_quant=dict(type='mmrazor.AdaRoundFakeQuantize'),
            a_fake_quant=dict(type='mmrazor.FakeQuantize'),
            w_qscheme=dict(
                bit=2,
                is_symmetry=False,
                is_per_channel=True,
                is_pot_scale=False),
            a_qscheme=dict(
                bit=4,
                is_symmetry=False,
                is_per_channel=False,
                is_pot_scale=False))))
```
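For reference, a config like this is normally consumed by the entry scripts described in the Usage section below. The following is a minimal sketch of building the algorithm by hand with the MMEngine registry; the config path is hypothetical.

```python
from mmengine.config import Config

from mmrazor.registry import MODELS

# Hypothetical path; in practice tools/train.py and tools/test.py do this for you.
cfg = Config.fromfile('configs/quantization/ptq/demo_ptq_resnet18_8xb32_in1k.py')
# The Qscheme-style fields in `qconfig` are converted into API-oriented
# observer / fake-quantize parameters when the algorithm is built.
algorithm = MODELS.build(cfg.model)
```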
Usage
The entry points of the quantization algorithms are the same as those of the other model compression algorithms:
- QAT: tools/train.py
- PTQ: tools/test.py
The entry point for deploying a quantized model is mmdeploy/tools/deploy.py, so you can run the following commands to go through the whole pipeline from quantization to deployment.
```bash
python mmrazor/tools/train.py   # or test.py for PTQ
python mmdeploy/tools/deploy.py
```
For more details about the above commands, please refer to the quantization documentation, which will be released later.
Core modules
Observer
In forward, observers update the statistics of the observed Tensor. They should also provide a calculate_qparams function that computes the quantization parameters from the collected statistics.
```python
from typing import Tuple

import torch
from torch.ao.quantization.observer import UniformQuantizationObserverBase


class BaseObserver(UniformQuantizationObserverBase):

    min_val: torch.Tensor
    max_val: torch.Tensor

    def __init__(self,
                 dtype=torch.quint8,
                 qscheme=torch.per_tensor_affine,
                 reduce_range=False,
                 quant_min=None,
                 quant_max=None,
                 ch_axis=-1,
                 is_pot_scale=False,
                 factory_kwargs=None,
                 eps=torch.finfo(torch.float32).eps) -> None:
        super().__init__(dtype, qscheme, reduce_range, quant_min, quant_max,
                         factory_kwargs, eps)
        factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
        self.register_buffer('min_val',
                             torch.tensor(float('inf'), **factory_kwargs))
        self.register_buffer('max_val',
                             torch.tensor(float('-inf'), **factory_kwargs))
        self.ch_axis = ch_axis
        self.is_pot_scale = is_pot_scale

    @torch.jit.export
    def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Calculates the quantization parameters."""
        pass

    @torch.jit.export
    def extra_repr(self):
        pass

    @torch.jit.export
    def reset_min_max_vals(self):
        """Resets the min/max values."""
        pass
```
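As a concrete illustration of this interface (not the actual MMRazor implementation), a min-max style observer built on the skeleton above could look like the sketch below. `ToyMinMaxObserver` is a hypothetical name, and the sketch assumes the `_calculate_qparams` helper inherited from PyTorch's UniformQuantizationObserverBase.

```python
import torch

from mmrazor.registry import MODELS


@MODELS.register_module()
class ToyMinMaxObserver(BaseObserver):
    """Example observer: track running min/max and derive qparams from them."""

    def forward(self, x_orig):
        # Update the statistics of the observed tensor (per-tensor only here).
        x = x_orig.detach().to(self.min_val.dtype)
        self.min_val.copy_(torch.min(self.min_val, x.min()))
        self.max_val.copy_(torch.max(self.max_val, x.max()))
        return x_orig

    @torch.jit.export
    def calculate_qparams(self):
        # Reuse the qparams computation provided by the PyTorch base class.
        scale, zero_point = self._calculate_qparams(self.min_val, self.max_val)
        if self.is_pot_scale:
            # Snap the scale to the nearest power of two.
            scale = 2.0 ** torch.round(torch.log2(scale))
        return scale, zero_point
```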
FakeQuantize
In forward, fake quantize modules update the statistics of the observed Tensor and fake quantize the input. They should also provide a calculate_qparams function that computes the quantization parameters from the collected statistics. In a fake quantize module, you can also implement an algorithm's special operations.
```python
import torch
from torch.ao.quantization import FakeQuantizeBase
# Helper predicates; in recent PyTorch they live alongside FakeQuantize.
from torch.ao.quantization.fake_quantize import (_is_float_qparams,
                                                 _is_per_channel,
                                                 _is_per_tensor,
                                                 _is_symmetric_quant)

from mmrazor.registry import MODELS


@MODELS.register_module()
class FakeQuantize(FakeQuantizeBase):

    scale: torch.Tensor
    zero_point: torch.Tensor

    def __init__(self, observer, **observer_kwargs):
        super().__init__()
        self.activation_post_process = observer(**observer_kwargs)
        self.quant_min = self.activation_post_process.quant_min
        self.quant_max = self.activation_post_process.quant_max
        if _is_float_qparams(self.activation_post_process.qscheme):
            zero_point_dtype = torch.float
        else:
            zero_point_dtype = torch.int
        self.register_buffer('scale', torch.tensor([1.0], dtype=torch.float))
        self.register_buffer('zero_point',
                             torch.tensor([0], dtype=zero_point_dtype))
        self.dtype = self.activation_post_process.dtype
        self.qscheme = self.activation_post_process.qscheme
        self.ch_axis = self.activation_post_process.ch_axis \
            if hasattr(self.activation_post_process, 'ch_axis') else -1
        assert _is_per_channel(self.qscheme) or \
            _is_per_tensor(self.qscheme), \
            'Only per channel and per tensor quantization are supported in ' \
            'fake quantize, got qscheme: ' + str(self.qscheme)
        self.is_per_channel = _is_per_channel(self.qscheme)

        bitrange = torch.tensor(self.quant_max - self.quant_min + 1).double()
        self.bitwidth = int(torch.log2(bitrange).item())
        self.is_pot_scale = self.activation_post_process.is_pot_scale
        self.is_symmetric_quant = _is_symmetric_quant(self.qscheme)

    @torch.jit.export
    def calculate_qparams(self):
        return self.activation_post_process.calculate_qparams()

    def forward(self, X):
        if self.observer_enabled[0] == 1:
            pass
        if self.fake_quant_enabled[0] == 1:
            pass
        return X

    @torch.jit.export
    def extra_repr(self):
        pass

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        pass

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        pass
```
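For illustration, the two branches in forward typically do something like the sketch below. It mirrors PyTorch's built-in fake quantize rather than MMRazor's exact code, and `ExampleFakeQuantize` is just a hypothetical name; real algorithms such as LSQ or AdaRound customize this step.

```python
import torch

from mmrazor.registry import MODELS


@MODELS.register_module()
class ExampleFakeQuantize(FakeQuantize):
    """Illustrative forward, built on the `FakeQuantize` class above."""

    def forward(self, X):
        if self.observer_enabled[0] == 1:
            # Update the observer statistics and refresh scale / zero_point.
            self.activation_post_process(X.detach())
            scale, zero_point = self.calculate_qparams()
            if self.scale.shape != scale.shape:
                self.scale.resize_(scale.shape)
                self.zero_point.resize_(zero_point.shape)
            self.scale.copy_(scale)
            self.zero_point.copy_(zero_point.to(self.zero_point.dtype))
        if self.fake_quant_enabled[0] == 1:
            # Simulate quantization error while staying in the float graph.
            if self.is_per_channel:
                X = torch.fake_quantize_per_channel_affine(
                    X, self.scale, self.zero_point, self.ch_axis,
                    self.quant_min, self.quant_max)
            else:
                X = torch.fake_quantize_per_tensor_affine(
                    X, self.scale, self.zero_point,
                    self.quant_min, self.quant_max)
        return X
```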
Quantizer
Quantizers implement some core quantization function APIs for the algorithms, such as qconfig_convert, prepare, convert_model, fuse_model and so on. What is more, different quantizers can deal with different deployment backends, so the target backend can be chosen in the config.
```python
from mmengine.model import BaseModule

from mmrazor.registry import MODELS

# `DefalutQconfigs` and the `check_is_valid_*` helpers are defined in the WIP
# quantize branch and omitted here.


@MODELS.register_module()
class CustomQuantizer(BaseModule):

    def __init__(self,
                 qconfig=DefalutQconfigs['default'],
                 is_qat=True,
                 skipped_methods=None,
                 prepare_custom_config_dict=None,
                 convert_custom_config_dict=None,
                 equalization_qconfig_dict=None,
                 _remove_qconfig=True,
                 init_cfg=None):
        super().__init__(init_cfg)
        if self.check_qconfig(qconfig):
            qconfig = self.qconfig_convert(qconfig)
            self.qconfig_dict = {'': qconfig}
        else:
            raise ValueError('qconfig is incorrect!')

        if prepare_custom_config_dict is None:
            self.prepare_custom_config_dict = {}
        else:
            self.prepare_custom_config_dict = prepare_custom_config_dict
        if convert_custom_config_dict is None:
            self.convert_custom_config_dict = {}
        else:
            self.convert_custom_config_dict = convert_custom_config_dict
        if equalization_qconfig_dict is None:
            self.equalization_qconfig_dict = {}
        else:
            self.equalization_qconfig_dict = equalization_qconfig_dict

        check_is_valid_qconfig_dict(self.qconfig_dict)
        check_is_valid_prepare_custom_config_dict(
            self.prepare_custom_config_dict)
        check_is_valid_convert_custom_config_dict(
            self.convert_custom_config_dict)
        check_is_valid_qconfig_dict(self.equalization_qconfig_dict)

        self.is_qat = is_qat
        self.skipped_methods = skipped_methods
        self._remove_qconfig = _remove_qconfig
        self.tracer = self.build_tracer()

    def prepare(self, model, graph_module):
        pass

    def convert(self, graph_module):
        pass

    def qconfig_convert(self, qconfig):
        pass

    def build_tracer(self):
        pass

    def fuse_model(self, graph_module):
        pass

    # ... more methods omitted
```
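To give a feel for what qconfig_convert produces, here is a runnable sketch using PyTorch's standard torch.ao building blocks instead of MMRazor's observers; it only illustrates the shape of the converted result and the `{'': qconfig}` mapping used above, not MMRazor's exact code.

```python
import torch
from torch.ao.quantization import (FakeQuantize, MovingAverageMinMaxObserver,
                                   MovingAveragePerChannelMinMaxObserver,
                                   QConfig)

# Roughly what the user-friendly qconfig could be converted into: factories for
# weight / activation fake-quantize modules, keyed by module scope ('' = whole model).
qconfig = QConfig(
    weight=FakeQuantize.with_args(
        observer=MovingAveragePerChannelMinMaxObserver,
        dtype=torch.qint8,
        qscheme=torch.per_channel_affine),
    activation=FakeQuantize.with_args(
        observer=MovingAverageMinMaxObserver,
        dtype=torch.quint8,
        qscheme=torch.per_tensor_affine))
qconfig_dict = {'': qconfig}
print(qconfig_dict)
```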
Algorithm
Algorithms provide some core APIs for the quantization loops to implement quantization pipelines, such as calib_step, prepare, convert and so on. Algorithms also maintain the traced graphs and forward with these graphs.
Loops
The quantization loops inherit MMEngine's TrainLoop and TestLoop and add some core quantization steps, such as calibrate, prepare and convert. There are also some special steps for certain quantization algorithms, such as subgraph reconstruction.
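A rough sketch of what a PTQ loop could look like on top of MMEngine is shown below. The name `PTQLoop`, the methods `prepare`, `calibrate_step` and `convert` on the model, and the exact hand-off are assumptions rather than the released API; only the MMEngine classes are real.

```python
from mmengine.registry import LOOPS
from mmengine.runner import TestLoop


@LOOPS.register_module()
class PTQLoop(TestLoop):
    """Hypothetical PTQ loop: calibrate, convert, then run the normal test."""

    def run(self):
        model = self.runner.model
        # `prepare`, `calibrate_step` and `convert` stand in for the core
        # quantization steps described above; their real names may differ.
        model.prepare()
        for data_batch in self.dataloader:
            model.calibrate_step(data_batch)
        model.convert()
        # Fall back to the standard evaluation loop on the quantized model.
        return super().run()
```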
How to trace the model automatically
Because torch.fx has its own limitations, a model's forward cannot be traced when it contains certain special cases, such as dynamic judgments on tensor values.
For tracing the model automatically, we customize a CustomTracer and an UntracedMethodRegistry. UntracedMethodRegistry can be used as a decorator so that decorated methods are skipped by CustomTracer. What is more, the methods to be skipped can be configured in our configs. Please refer to the chapter User-friendly config above to learn about its usage.
So the solution is as follows.
- Collect the untraceable code into a function or method so that the rest of the pipeline stays traceable. In OpenMMLab 2.0, we refactored some model interfaces to preliminarily adapt to torch.fx.
- Specify these methods to be skipped in our configs, as illustrated by the sketch below.
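To see why skipping is needed, here is a minimal, self-contained illustration of a forward with dynamic control flow that torch.fx cannot trace. `ToyHead` and `_get_loss` are hypothetical; the actual skipping is done by CustomTracer according to the `skipped_methods` field in the config.

```python
import torch
import torch.nn as nn
from torch.fx.proxy import TraceError


class ToyHead(nn.Module):

    def forward(self, feats):
        score = feats.mean(dim=-1)
        return self._get_loss(score)

    def _get_loss(self, score):
        # Dynamic judgment on tensor values: torch.fx cannot trace through
        # this, so such code is collected into one method that can be listed
        # in `skipped_methods` and treated as a leaf by CustomTracer.
        if score.max() > 0:
            score = score / score.max()
        return (1 - score).relu().mean()


try:
    torch.fx.symbolic_trace(ToyHead())
except TraceError as e:
    print(f'torch.fx fails as expected: {e}')
```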
WIP code
For more details about the implementation, please refer to the branch: https://github.com/open-mmlab/mmrazor/tree/quantize
Note:
The quantize branch is under development; the code may change at any time.