-
-
Notifications
You must be signed in to change notification settings - Fork 10.4k
[Kernel] AQ AZP 4/4: Integrate asymmetric quantization to linear method #7271
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
d55e28b
4607a75
7799643
782c536
d5f64e1
c5cdfe3
9290d4e
00ba7ff
f6e5978
51e2316
1c7ed61
9eee370
1ffac4d
4a024cd
33c4269
e39b859
369f766
0dab92a
5aeeba3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ | |
import torch | ||
from torch.nn import Parameter | ||
|
||
from vllm.logger import init_logger | ||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( | ||
CompressedTensorsScheme) | ||
from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( | ||
|
@@ -14,12 +15,16 @@ | |
ModelWeightParameter, | ||
PerTensorScaleParameter) | ||
|
||
logger = init_logger(__name__) | ||
|
||
|
||
class CompressedTensorsW8A8Int8(CompressedTensorsScheme): | ||
|
||
def __init__(self, strategy: str, is_static_input_scheme: bool): | ||
def __init__(self, strategy: str, is_static_input_scheme: bool, | ||
input_symmetric: bool): | ||
self.strategy = strategy | ||
self.is_static_input_scheme = is_static_input_scheme | ||
self.input_symmetric = input_symmetric | ||
|
||
@classmethod | ||
def get_min_capability(cls) -> int: | ||
|
@@ -46,10 +51,43 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: | |
requires_grad=False) | ||
# INPUT SCALE | ||
if self.is_static_input_scheme: | ||
layer.input_scale = Parameter(layer.input_scale.max(), | ||
requires_grad=False) | ||
if self.input_symmetric: | ||
layer.input_scale = Parameter(layer.input_scale.max(), | ||
requires_grad=False) | ||
else: | ||
# Static asymmetric quantization has not been tested yet. | ||
# Kernel and ops support exists and is tested, it's just the | ||
# following integration code that is untested. | ||
logger.warning( | ||
"Static asymmetric quantization currently untested") | ||
ProExpertProg marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
||
# reconstruct the ranges | ||
int8_traits = torch.iinfo(torch.int8) | ||
range_max = (layer.input_scale * | ||
(int8_traits.max - layer.input_zero_point)).max() | ||
range_min = (layer.input_scale * | ||
(int8_traits.min - layer.input_zero_point)).min() | ||
|
||
scale = (range_max - range_min) / (int8_traits.max - | ||
int8_traits.min) | ||
layer.input_scale = Parameter(scale, requires_grad=False) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think we should add an accuracy test to make sure this works as expected. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Where should we add the accuracy test? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is now complete with the compressed tensors quantization tests - thanks! |
||
|
||
azp = int8_traits.min - range_min / scale | ||
layer.input_zero_point = Parameter(azp, requires_grad=False) | ||
|
||
else: | ||
layer.input_scale = None | ||
layer.input_zero_point = None | ||
|
||
if not self.input_symmetric: | ||
ProExpertProg marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# azp_adj is the AZP adjustment term, used to account for weights. | ||
# For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md | ||
ProExpertProg marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
# https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md | ||
layer.azp_adj = layer.weight.sum(dim=0, | ||
mgoin marked this conversation as resolved.
Show resolved
Hide resolved
|
||
keepdim=True, | ||
dtype=torch.int32) | ||
else: | ||
layer.azp_adj = None | ||
|
||
def create_weights(self, layer: torch.nn.Module, | ||
output_partition_sizes: List[int], | ||
|
@@ -90,11 +128,19 @@ def create_weights(self, layer: torch.nn.Module, | |
weight_loader=weight_loader) | ||
layer.register_parameter("input_scale", input_scale) | ||
|
||
if not self.input_symmetric: | ||
ProExpertProg marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# Static asymmetric quantization has not been tested yet | ||
logger.warning( | ||
"Static asymmetric quantization currently untested") | ||
ProExpertProg marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
input_zero_point = Parameter(torch.zeros(1, dtype=torch.int32)) | ||
ProExpertProg marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
layer.register_parameter("input_zero_point", input_zero_point) | ||
|
||
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, | ||
bias: Optional[torch.Tensor]) -> torch.Tensor: | ||
|
||
return apply_int8_linear(input=x, | ||
weight=layer.weight, | ||
weight_scale=layer.weight_scale, | ||
input_scale=layer.input_scale, | ||
input_zero_point=layer.input_zero_point, | ||
azp_adj=layer.azp_adj, | ||
bias=bias) |
Uh oh!
There was an error while loading. Please reload this page.