
Commit 611f152

group size speedups + fixes (vllm-project#51)
Parent: 964276d

4 files changed (+27, -13 lines)

src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 2 additions & 3 deletions
@@ -111,9 +111,8 @@ def fake_quantize(
     for i in range(ceil(columns / group_size)):
         # scale.shape should be [nchan, ndim]
         # sc.shape should be [nchan, 1] after unsqueeze
-
-        sc = scale[:, i].unsqueeze(1)
-        zp = zero_point[:, i].unsqueeze(1)
+        sc = scale[:, i].view(-1, 1)
+        zp = zero_point[:, i].view(-1, 1)
 
         idx = i * group_size
         Q = quantize(x[:, idx : (idx + group_size)], sc, zp, args)
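
For context on this hunk: scale[:, i] is a 1-D slice, so unsqueeze(1) and view(-1, 1) produce the same [nchan, 1] column without copying; the swap is behavior-preserving and presumably just picks the cheaper call. A minimal sketch of the group-wise slicing pattern, with made-up shapes:

import torch
from math import ceil

# made-up shapes: 4 channels, 8 columns, quantized in groups of 4
x = torch.randn(4, 8)
group_size = 4
columns = x.shape[1]
scale = torch.rand(4, ceil(columns / group_size))     # [nchan, ngroups]

for i in range(ceil(columns / group_size)):
    sc = scale[:, i].view(-1, 1)                      # [nchan, 1], no copy
    assert torch.equal(sc, scale[:, i].unsqueeze(1))  # identical result
    idx = i * group_size
    group = x[:, idx : (idx + group_size)]            # columns of group i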

src/compressed_tensors/quantization/observers/base.py

Lines changed: 3 additions & 3 deletions
@@ -40,6 +40,7 @@ def __init__(self, quantization_args: QuantizationArgs):
         self._scale = None
         self._zero_point = None
 
+    @torch.no_grad()
     def forward(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
         """
         maps directly to get_qparams
@@ -91,9 +92,8 @@ def get_qparams(
                 )
                 scales.append(scale)
                 zero_points.append(zero_point)
-
-            self._scale = torch.stack(scales, dim=1)
-            self._zero_point = torch.stack(zero_points, dim=1)
+            self._scale = torch.stack(scales, dim=1, out=self._scale)
+            self._zero_point = torch.stack(zero_points, dim=1, out=self._zero_point)
 
         elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
             # assume observed is transposed, because its the output, hence use dim 0
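
Two changes here: @torch.no_grad() keeps the observer's forward pass out of the autograd graph, and passing the previous buffer as out= lets torch.stack write into existing storage instead of allocating a fresh result on every call (out=None on the first call simply allocates). A minimal sketch of the out= reuse, assuming the stacked shape stays constant across calls:

import torch

buf = None
for step in range(3):
    pieces = [torch.rand(4) for _ in range(2)]
    # first pass: out=None allocates; later passes reuse buf's storage
    buf = torch.stack(pieces, dim=1, out=buf)

print(buf.shape)  # torch.Size([4, 2])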

src/compressed_tensors/quantization/observers/helpers.py

Lines changed: 1 addition & 1 deletion
@@ -41,10 +41,10 @@ def calculate_qparams(
     bit_min = -(bit_range + 1) / 2
     bit_max = bit_min + bit_range
     if quantization_args.symmetric:
-        zero_points = torch.tensor(0, device=device).to(torch.int8)
         max_val_pos = torch.max(-min_vals, max_vals)
         scales = max_val_pos / (float(bit_range) / 2)
         scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
+        zero_points = torch.zeros(scales.shape, device=device, dtype=torch.int8)
     else:
         scales = (max_vals - min_vals) / float(bit_range)
         scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
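
The fix: the symmetric branch used to return a single scalar zero point; building zero_points from scales.shape instead gives one zero point per scale entry, so the two broadcast identically in later per-group and per-channel handling. A minimal sketch of the symmetric branch, assuming 4-bit args where bit_range = 2**num_bits - 1:

import torch

num_bits = 4
bit_range = 2**num_bits - 1            # assumed: 15 for 4 bits
min_vals = torch.tensor([-1.0, -0.5])  # e.g. one entry per group
max_vals = torch.tensor([2.0, 0.25])

max_val_pos = torch.max(-min_vals, max_vals)
scales = max_val_pos / (float(bit_range) / 2)
scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
# one int8 zero point per scale entry, rather than a lone scalar
zero_points = torch.zeros(scales.shape, dtype=torch.int8)

print(scales.shape, zero_points.shape)  # both torch.Size([2])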

src/compressed_tensors/quantization/observers/min_max.py

Lines changed: 21 additions & 6 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 from compressed_tensors.quantization.observers.base import Observer
@@ -36,22 +36,33 @@ def __init__(
     ):
         super().__init__(quantization_args=quantization_args)
 
-        self.min_val = float("inf")
-        self.max_val = -float("inf")
+        self.min_val = None
+        self.max_val = None
         self.averaging_constant = averaging_constant
 
-    def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
+    def calculate_qparams(
+        self,
+        observed: Tensor,
+        reduce_dims: Optional[Tuple[int]] = None,
+    ) -> Tuple[FloatTensor, IntTensor]:
         """
         Updates the observed min and max using a moving average smoothed by the
         averaging_constant
 
         :param observed: observed tensor to calculate quantization parameters for
+        :param reduce_dims: optional tuple of dimensions to reduce along,
+            returned scale and zero point will be shaped (1,) along the
+            reduced dimensions
         :return: tuple of scale and zero point derived from the observed tensor
         """
 
-        min_val, max_val = torch.aminmax(observed)
+        if not reduce_dims:
+            min_val, max_val = torch.aminmax(observed)
+        else:
+            min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
+            max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)
 
-        if self.min_val == float("inf") and self.max_val == float("-inf"):
+        if self.min_val is None and self.max_val is None:
             self.min_val = min_val
             self.max_val = max_val
         else:
@@ -63,3 +74,7 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
             )
 
         return calculate_qparams(self.min_val, self.max_val, self.quantization_args)
+
+    def get_qparams_along_dim(self, observed, dim: int):
+        reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
+        return self.calculate_qparams(observed, reduce_dims=reduce_dims)
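
The observer change is shape-driven: min_val and max_val are now tensors (per-channel when reduce_dims is set), so the float("inf") sentinel is replaced by None, where an == comparison against a tensor would be elementwise rather than an initialization check. get_qparams_along_dim reduces over every dimension except the one of interest, keeping reduced dims as size 1 so the qparams broadcast. A minimal sketch with made-up shapes (using keepdim, the canonical spelling of the keepdims alias seen above):

import torch

observed = torch.randn(4, 16)  # e.g. [out_channels, columns]
dim = 0                        # keep stats per output channel

# reduce over every dim except `dim`, as get_qparams_along_dim does
reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
min_val = torch.amin(observed, dim=reduce_dims, keepdim=True)
max_val = torch.amax(observed, dim=reduce_dims, keepdim=True)

print(min_val.shape)  # torch.Size([4, 1])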
