Skip to content

Commit 4607a75

Browse files
committed
PR comments
1 parent d55e28b commit 4607a75

File tree

1 file changed

+22 -7 lines changed

1 file changed

+22
-7
lines changed

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py

Lines changed: 22 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -48,18 +48,33 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
4848
requires_grad=False)
4949
# INPUT SCALE
5050
if self.is_static_input_scheme:
51-
layer.input_scale = Parameter(layer.input_scale.max(),
52-
requires_grad=False)
53-
if not self.input_symmetric:
54-
layer.input_zero_point = Parameter(layer.input_zero_point,
55-
requires_grad=False)
51+
if self.input_symmetric:
52+
layer.input_scale = Parameter(layer.input_scale.max(),
53+
requires_grad=False)
5654
else:
57-
layer.input_zero_point = None
55+
raise NotImplementedError(
56+
"static input asymmetric quantization not supported yet")
57+
# reconstruct the ranges
58+
int8_traits = torch.iinfo(torch.int8)
59+
range_max = (layer.input_scale *
60+
(int8_traits.max - layer.input_zero_point)).max()
61+
range_min = (layer.input_scale *
62+
(int8_traits.min - layer.input_zero_point)).min()
63+
64+
scale = (range_max - range_min) / (int8_traits.max -
65+
int8_traits.min)
66+
layer.input_scale = Parameter(scale, requires_grad=False)
67+
68+
azp = int8_traits.min - range_min / scale
69+
layer.input_zero_point = Parameter(azp, requires_grad=False)
70+
5871
else:
5972
layer.input_scale = None
6073
layer.input_zero_point = None
6174

6275
if not self.input_symmetric:
76+
# azp_adj is the AZP adjustment term, used to account for weights.
77+
# For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
6378
layer.azp_adj = layer.weight.sum(dim=0,
6479
keepdim=True,
6580
dtype=torch.int32)
@@ -108,7 +123,7 @@ def create_weights(self, layer: torch.nn.Module,
108123
if not self.input_symmetric:
109124
raise NotImplementedError(
110125
"static input asymmetric quantization not supported yet")
111-
input_zero_point = Parameter(torch.zeros(1, dtype=torch.int8))
126+
input_zero_point = Parameter(torch.zeros(1, dtype=torch.int32))
112127
layer.register_parameter("input_zero_point", input_zero_point)
113128

114129
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,

0 commit comments

Comments (0)