
[RFC] MMRazor Quantization Design #347

@humu789

Description


Motivation

  1. To design and implement a better quantization component for MMRazor together with the community.

  2. To collect more requirements and suggestions through this RFC (Request for Comments) before releasing quantization.

Overview

MMRazor quantization will be an algorithm platform, not just a set of basic quantization function APIs. We hope it will help you in the following ways:

  1. Compress and deploy your models faster.

  2. Produce better models with our quantization algorithms.

  3. Implement novel quantization algorithms more easily.

Goals

  1. Support implementing mainstream QAT and PTQ algorithms, such as LSQ, AdaRound and so on.

  2. Support a complete working pipeline from quantization to deployment. You can deploy quantized models on multiple backends with MMDeploy.

  3. Adapt to OpenMMLab 2.0, so that upstream OpenMMLab repositories can be supported uniformly without extra code.

  4. Be easier to use. You can quantize your model just by modifying the config and running a script, rather than modifying your source model.

Algorithms

We plan to support the following quantization algorithms in the future. You are welcome to propose your requirements.

QAT

  1. LSQ

  2. LSQ+

  3. IAO

......

PTQ

  1. Adaround

  2. BRECQ

  3. QDrop

......

Main features

We list the main features to be supported in the future. Comments are welcome.

  1. Quantization types: QAT and PTQ (static/dynamic)

  2. Quantization bits: 1 ~ 32. Note: 1 bit here means ordinary 1-bit quantization, not binarization.

  3. Quantization methods (uniform quantization; see the sketch after this list):

    1. per_tensor / per_channel
    2. symmetric / asymmetric
    3. FP scale / PoT scale (power of two)
  4. Multiple backends:

    1. TensorRT
    2. SNPE
    3. ncnn
    4. .....
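To make the options in item 3 concrete, here is a minimal, illustrative sketch of uniform (fake) quantization with these switches. It is not MMRazor code; the function and argument names are placeholders.

import torch


def uniform_quantize(x, bit=8, symmetric=True, per_channel=False,
                     pot_scale=False, ch_axis=0):
    """Illustrative fake quantization covering the options listed above."""
    if symmetric:
        quant_min, quant_max = -(2 ** (bit - 1)), 2 ** (bit - 1) - 1
    else:
        quant_min, quant_max = 0, 2 ** bit - 1
    # Collect min/max statistics per channel or per tensor.
    if per_channel:
        dims = [d for d in range(x.dim()) if d != ch_axis]
        min_val = x.amin(dim=dims, keepdim=True)
        max_val = x.amax(dim=dims, keepdim=True)
    else:
        min_val, max_val = x.min(), x.max()
    if symmetric:
        scale = torch.max(max_val.abs(), min_val.abs()) / quant_max
        zero_point = torch.zeros_like(scale)
    else:
        scale = (max_val - min_val) / (quant_max - quant_min)
        zero_point = quant_min - torch.round(min_val / scale)
    if pot_scale:
        # Snap the scale to the nearest power of two.
        scale = 2 ** torch.round(torch.log2(scale))
    x_q = torch.clamp(torch.round(x / scale) + zero_point, quant_min, quant_max)
    return (x_q - zero_point) * scale  # fake-quantized (dequantized) output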

Due to limited manpower, some of the algorithms and features listed above will be implemented over the next several versions; PRs to speed up development are welcome.

Most features will be included in the first release, except dynamic quantization and support for more backends. The quantization algorithms will be released in batches over the next two versions.

Release plan

We will release our first version in December 2022 if everything goes well.

Design and Implementation

We will implement our design by extending PyTorch's basic quantization APIs and torch.fx. Some PyTorch modules will be inherited, and some new modules will be created.

User-friendly config

We will use Qscheme to convert the user-friendly config into API-oriented parameters. A demo config is as follows.

_base_ = [
    'mmcls::resnet/resnet18_8xb32_in1k.py'
]

model = dict(
    _delete_=True,
    type='mmrazor.GeneralQuant',
    architecture=_base_.model,
    quantizer=dict(
        type='mmrazor.CustomQuantizer',
        is_qat=False,
        # `skipped_methods` makes the model traceable automatically by skipping
        # these untraceable methods.
        skipped_methods=[
            'mmcls.models.heads.ClsHead._get_loss', 
            'mmcls.models.heads.ClsHead._get_predictions'],
        qconfig=dict(
            qtype='affine',
            w_observer=dict(type='mmrazor.MSEObserver'),
            a_observer=dict(type='mmrazor.EMAMSEObserver'),
            w_fake_quant=dict(type='mmrazor.AdaRoundFakeQuantize'),
            a_fake_quant=dict(type='mmrazor.FakeQuantize'),
            w_qscheme=dict(
                bit=2,
                is_symmetry=False,
                is_per_channel=True,
                is_pot_scale=False,
            ),
            a_qscheme=dict(
                bit=4,
                is_symmetry=False,
                is_per_channel=False,
                is_pot_scale=False),
        )
    )
)
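
For illustration, a qscheme dict like the ones above could be mapped onto torch-native arguments roughly as follows. This helper is only a sketch to explain the conversion, not the actual MMRazor implementation.

import torch


def qscheme_from_cfg(bit, is_symmetry, is_per_channel, is_pot_scale):
    """Illustrative mapping from the user-friendly qscheme fields to torch arguments."""
    if is_per_channel:
        qscheme = torch.per_channel_symmetric if is_symmetry else torch.per_channel_affine
    else:
        qscheme = torch.per_tensor_symmetric if is_symmetry else torch.per_tensor_affine
    # Unsigned integer range for the given bit width; signed ranges work analogously.
    quant_min, quant_max = 0, 2 ** bit - 1
    # `is_pot_scale` would be forwarded to the observer / fake-quantize constructors.
    return dict(qscheme=qscheme, quant_min=quant_min, quant_max=quant_max,
                is_pot_scale=is_pot_scale)


# Example: the activation qscheme from the config above.
print(qscheme_from_cfg(bit=4, is_symmetry=False, is_per_channel=False, is_pot_scale=False))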

Usage

The entry points for quantization algorithms are the same as for other model compression algorithms:

QAT: tools/train.py

PTQ: tools/test.py

The entry point for deploying a quantized model is mmdeploy/tools/deploy.py. So you can run the following commands to go through the whole pipeline from quantization to deployment:

python mmrazor/tools/train.py (test.py)
python mmdeploy/tools/deploy.py

For more details about the above commands, please refer to the quantization document to be released.

Core modules

  1. Observers

In forward, observers update the statistics of the observed tensor. They should also provide a calculate_qparams function that computes the quantization parameters from the collected statistics.

from typing import Tuple

import torch
from torch.ao.quantization.observer import UniformQuantizationObserverBase

class BaseObserver(UniformQuantizationObserverBase):

    min_val: torch.Tensor
    max_val: torch.Tensor

    def __init__(
        self,
        dtype=torch.quint8,
        qscheme=torch.per_tensor_affine,
        reduce_range=False,
        quant_min=None,
        quant_max=None,
        ch_axis=-1,
        is_pot_scale=False,
        factory_kwargs=None,
        eps=torch.finfo(torch.float32).eps) -> None:
        super().__init__(dtype, qscheme, reduce_range, quant_min, quant_max, 
            factory_kwargs, eps)
        factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
        self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs))
        self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs))
        self.ch_axis = ch_axis
        self.is_pot_scale = is_pot_scale

    @torch.jit.export
    def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Calculates the quantization parameters."""
        pass

    @torch.jit.export
    def extra_repr(self):
        pass

    @torch.jit.export
    def reset_min_max_vals(self):
        """Resets the min/max values."""
        pass
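
As an example of how a concrete observer could build on this base, here is a minimal min/max-style sketch. It relies on the _calculate_qparams helper inherited from PyTorch's UniformQuantizationObserverBase; the class name is illustrative and it is not the actual MSEObserver implementation.

import torch

from mmrazor.registry import MODELS


@MODELS.register_module()
class MinMaxObserver(BaseObserver):
    """Minimal sketch: track running min/max and derive scale / zero_point."""

    def forward(self, x_orig):
        x = x_orig.detach().to(self.min_val.dtype)
        self.min_val.copy_(torch.min(self.min_val, x.min()))
        self.max_val.copy_(torch.max(self.max_val, x.max()))
        return x_orig

    @torch.jit.export
    def calculate_qparams(self):
        # Helper inherited from UniformQuantizationObserverBase.
        return self._calculate_qparams(self.min_val, self.max_val)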
  2. FakeQuantizes

In forward, fake quantizes update the statistics of the observed tensor and fake-quantize the input. They should also provide a calculate_qparams function that computes the quantization parameters from the collected statistics.

Algorithm-specific operations (e.g. AdaRound's adaptive rounding) can be implemented in a FakeQuantize subclass.

import torch
from torch.ao.quantization import FakeQuantizeBase
# These private helpers live in torch.ao.quantization.fake_quantize in recent
# PyTorch versions; the import path may differ across versions.
from torch.ao.quantization.fake_quantize import (_is_float_qparams,
                                                 _is_per_channel,
                                                 _is_per_tensor,
                                                 _is_symmetric_quant)

from mmrazor.registry import MODELS

@MODELS.register_module()
class FakeQuantize(FakeQuantizeBase):

    scale: torch.Tensor
    zero_point: torch.Tensor

    def __init__(self, observer, **observer_kwargs):
        super().__init__()
        self.activation_post_process = observer(**observer_kwargs)
        self.quant_min = self.activation_post_process.quant_min
        self.quant_max = self.activation_post_process.quant_max
        if _is_float_qparams(self.activation_post_process.qscheme):
            zero_point_dtype = torch.float
        else:
            zero_point_dtype = torch.int
        self.register_buffer('scale', torch.tensor([1.0], dtype=torch.float))
        self.register_buffer('zero_point', torch.tensor([0], dtype=zero_point_dtype))
        self.dtype = self.activation_post_process.dtype
        self.qscheme = self.activation_post_process.qscheme
        self.ch_axis = self.activation_post_process.ch_axis \
            if hasattr(self.activation_post_process, 'ch_axis') else -1
        assert _is_per_channel(self.qscheme) or \
            _is_per_tensor(self.qscheme), \
            'Only per channel and per tensor quantization are supported in fake quantize' + \
            ' got qscheme: ' + str(self.qscheme)
        self.is_per_channel = _is_per_channel(self.qscheme)
        
        bitrange = torch.tensor(self.quant_max - self.quant_min + 1).double()
        self.bitwidth = int(torch.log2(bitrange).item())
        self.is_pot_scale = self.activation_post_process.is_pot_scale
        self.is_symmetric_quant = _is_symmetric_quant(self.qscheme)

    @torch.jit.export
    def calculate_qparams(self):
        return self.activation_post_process.calculate_qparams()

    def forward(self, X):
        if self.observer_enabled[0] == 1:
            pass

        if self.fake_quant_enabled[0] == 1:
            pass
            
        return X

    @torch.jit.export
    def extra_repr(self):
        pass

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        pass
        
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        pass
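
To make the two enable flags concrete, the forward body typically looks roughly like the following, mirroring PyTorch's stock torch.ao.quantization.FakeQuantize. This is only a reference sketch, not MMRazor's final implementation.

import torch


class ExampleFakeQuantize(FakeQuantize):
    """Sketch of a typical forward; algorithm-specific variants override this."""

    def forward(self, X):
        if self.observer_enabled[0] == 1:
            # Update statistics and refresh scale / zero_point from the observer.
            self.activation_post_process(X.detach())
            _scale, _zero_point = self.calculate_qparams()
            if self.scale.shape != _scale.shape:
                self.scale.resize_(_scale.shape)
                self.zero_point.resize_(_zero_point.shape)
            self.scale.copy_(_scale)
            self.zero_point.copy_(_zero_point)

        if self.fake_quant_enabled[0] == 1:
            # Simulate quantization noise while keeping the tensor in float.
            if self.is_per_channel:
                X = torch.fake_quantize_per_channel_affine(
                    X, self.scale, self.zero_point, self.ch_axis,
                    self.quant_min, self.quant_max)
            else:
                X = torch.fake_quantize_per_tensor_affine(
                    X, self.scale, self.zero_point,
                    self.quant_min, self.quant_max)
        return X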
  3. Quantizers

Quantizers implement the core quantization APIs used by the algorithms, such as qconfig_convert, prepare, convert_model, fuse_model, and so on. Moreover, different quantizers can target different deployment backends, so the backend can be selected simply by configuring the quantizer in the config.

from mmengine.model import BaseModule

from mmrazor.registry import MODELS

# Note: `DefalutQconfigs` and the `check_is_valid_*` helpers used below come from
# MMRazor's quantize branch and PyTorch's FX quantization utilities, respectively.


@MODELS.register_module()
class CustomQuantizer(BaseModule):

    def __init__(self,
                 qconfig=DefalutQconfigs['default'],
                 is_qat=True,
                 skipped_methods=None,
                 prepare_custom_config_dict=None,
                 convert_custom_config_dict=None,
                 equalization_qconfig_dict=None,
                 _remove_qconfig=True,
                 init_cfg=None):
        super().__init__(init_cfg)
        if self.check_qconfig(qconfig):
            qconfig = self.qconfig_convert(qconfig)
            self.qconfig_dict = {"": qconfig}
        else:
            raise ValueError('qconfig is incorrect!')
        
        if prepare_custom_config_dict is None:
            self.prepare_custom_config_dict = {}
        else:
            self.prepare_custom_config_dict = prepare_custom_config_dict
        if convert_custom_config_dict is None:
            self.convert_custom_config_dict = {}
        else:
            self.convert_custom_config_dict = convert_custom_config_dict
        if equalization_qconfig_dict is None:
            self.equalization_qconfig_dict = {}
        else:
            self.equalization_qconfig_dict = equalization_qconfig_dict

        check_is_valid_qconfig_dict(self.qconfig_dict)
        check_is_valid_prepare_custom_config_dict(self.prepare_custom_config_dict)
        check_is_valid_convert_custom_config_dict(self.convert_custom_config_dict)
        check_is_valid_qconfig_dict(self.equalization_qconfig_dict)
        
        self.is_qat = is_qat
        self.skipped_methods = skipped_methods
        self._remove_qconfig = _remove_qconfig
        self.tracer = self.build_tracer()

    def prepare(self, model, graph_module):
        pass

    def convert(self, graph_module):
        pass
        
    def qconfig_convert(self, qconfig):
        pass

    def build_tracer(self):
        pass
    
    def fuse_model(self, graph_module):
        pass
    
    .....
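
For example, switching the deployment backend would then only require swapping the quantizer entry in the config. The quantizer class names below are hypothetical placeholders used for illustration.

# Hypothetical backend-specific quantizers selected purely via the config.
quantizer = dict(type='mmrazor.TensorRTQuantizer')  # placeholder: TensorRT deployment
# quantizer = dict(type='mmrazor.SNPEQuantizer')    # placeholder: SNPE deployment
# quantizer = dict(type='mmrazor.CustomQuantizer')  # generic / academic setting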
    
  4. Algorithms

Algorithms provide core APIs, such as calib_step, prepare, convert and so on, for Quantization Loops to implement quantization pipelines. Algorithms also maintain the traced graphs and run forward with them.
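
A minimal sketch of what the mmrazor.GeneralQuant algorithm referenced in the demo config might look like; the method bodies and exact signatures are assumptions based on this RFC.

from mmengine.model import BaseModel

from mmrazor.registry import MODELS


@MODELS.register_module()
class GeneralQuant(BaseModel):
    """Sketch of a quantization algorithm wrapper (APIs assumed from this RFC)."""

    def __init__(self, architecture, quantizer, data_preprocessor=None, init_cfg=None):
        super().__init__(data_preprocessor, init_cfg)
        self.architecture = MODELS.build(architecture)
        self.quantizer = MODELS.build(quantizer)
        self.graph = None  # traced GraphModule maintained by the algorithm

    def prepare(self):
        # Trace the architecture and insert observers / fake-quantize nodes.
        pass

    def calib_step(self, data):
        # One calibration forward pass with observers enabled.
        pass

    def convert(self):
        # Convert the prepared graph into a deployable quantized model.
        pass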

  5. Quantization Loops

Quantization Loops inherit from MMEngine's TrainLoop and TestLoop and add core quantization steps such as calibrate, prepare and convert. There are also special steps for certain quantization algorithms, such as subgraph reconstruction.
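
For instance, a PTQ loop built on MMEngine's TestLoop could look roughly like the following; the PTQLoop name and the algorithm methods it calls are assumptions based on this RFC.

from mmengine.registry import LOOPS
from mmengine.runner import TestLoop


@LOOPS.register_module()
class PTQLoop(TestLoop):
    """Sketch of a post-training-quantization loop (name assumed)."""

    def run(self):
        self.runner.call_hook('before_test')
        model = self.runner.model
        model.prepare()                      # insert observers / fake-quantize nodes
        for data_batch in self.dataloader:   # calibration pass over the dataset
            model.calib_step(data_batch)
        model.convert()                      # produce the quantized model
        self.runner.call_hook('after_test')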

How to trace the model automatically

Because torch.fx has its own limitations, a model's forward cannot be traced when it contains special cases such as dynamic control flow (e.g. data-dependent branching).

To trace models automatically, we implement a CustomTracer and an UntracedMethodRegistry. UntracedMethodRegistry can be used as a decorator so that decorated methods are skipped by CustomTracer. Moreover, the methods to be skipped can be configured in our configs. Please refer to the chapter User-friendly config for its usage.

So the solution is as follows.

  1. Collect the untraceable code into a function or method (as sketched after this list) so that the rest of the pipeline stays traceable. In OpenMMLab 2.0, some model interfaces were refactored to preliminarily adapt to torch.fx.

  2. Specify these methods to be skipped in our configs.
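
The pattern keeps forward purely tensor-based and moves the rest into skippable methods. The class below is a simplified illustration, not the real mmcls ClsHead.

from torch import nn


class ClsHead(nn.Module):
    """Simplified illustration of the refactoring pattern."""

    def __init__(self, in_channels=512, num_classes=1000):
        super().__init__()
        self.fc = nn.Linear(in_channels, num_classes)

    def forward(self, feats):
        # Traceable: pure tensor ops only.
        return self.fc(feats)

    def _get_loss(self, cls_score, data_samples):
        # Untraceable, data-dependent logic is collected here; CustomTracer skips
        # this method when it is listed in `skipped_methods` in the config.
        ...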

WIP code

For more details about the implementation, please refer to the branch: https://github.com/open-mmlab/mmrazor/tree/quantize

Note:

The quantize branch is under development; the code may change at any time.
