
Commit 0d9001f

Author: Sara Adkins
Commit message: Merge branch 'main' into sa/fp8
Parents: 6981d4e, 01bcb85

File tree

2 files changed: 17 additions & 13 deletions


src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 15 additions & 11 deletions
@@ -94,11 +94,17 @@ def dequantize(
         if scale.ndim == 0:
             args = QuantizationArgs(strategy=QuantizationStrategy.TENSOR)
         elif scale.ndim == 2:
-            args = QuantizationArgs(strategy=QuantizationStrategy.CHANNEL)
-        elif scale.ndim == 3:
-            group_size = int(x_q.shape[1] / scale.shape[1])
-            args = QuantizationArgs(
-                strategy=QuantizationStrategy.GROUP, group_size=group_size
+            if scale.shape[1] == 1:
+                args = QuantizationArgs(strategy=QuantizationStrategy.CHANNEL)
+            else:
+                group_size = int(x_q.shape[1] / scale.shape[1])
+                args = QuantizationArgs(
+                    strategy=QuantizationStrategy.GROUP, group_size=group_size
+                )
+        else:
+            raise ValueError(
+                f"Could not infer a quantization strategy from scale with {scale.ndim} "
+                "dimensions. Expected 0 or 2 dimensions."
             )
     return _process_quantization(
         x=x_q,
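The restructured inference above drops the old 3-D (rows, num_groups, 1) scale convention: a 2-D scale now means CHANNEL when it has a single column and GROUP otherwise, and any other rank raises. A minimal standalone sketch of that logic, using a hypothetical helper name rather than the library's API:

import torch

def infer_strategy(x_q: torch.Tensor, scale: torch.Tensor):
    # Hypothetical helper mirroring the inference in dequantize() above;
    # returns the strategy name and, for GROUP, the inferred group size.
    if scale.ndim == 0:
        return "tensor", None
    elif scale.ndim == 2:
        if scale.shape[1] == 1:
            # one scale per output row -> channel-wise quantization
            return "channel", None
        # several scales per row -> group quantization; each scale covers
        # x_q.shape[1] / scale.shape[1] consecutive columns
        return "group", int(x_q.shape[1] / scale.shape[1])
    raise ValueError(f"expected a 0- or 2-dim scale, got {scale.ndim} dims")

x_q = torch.zeros((300, 1024), dtype=torch.int8)
print(infer_strategy(x_q, torch.rand((300, 1))))  # ('channel', None)
print(infer_strategy(x_q, torch.rand((300, 8))))  # ('group', 128)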
@@ -155,14 +161,12 @@ def _process_quantization(
     group_size = args.group_size

     if args.strategy == QuantizationStrategy.GROUP:
-
-        if do_dequantize:
-            # if dequantizing the output should match the original weight dtype,
-            # which is the same as the scale's
+        if do_dequantize and not do_quantize:
+            # if dequantizing a quantized type infer the output type from the scale
             output = torch.zeros_like(x, dtype=scale.dtype)
         else:
-            # outputting a quantized output, use the dtype passed in as a kwarg if its
-            # specified, otherwise default to the input type
+            # use the dtype passed in as a kwarg if it's specified, otherwise default
+            # to the input type
             output_dtype = dtype if dtype is not None else x.dtype
             if output_dtype is FP8_DTYPE:
                 # zeros_like doesn't support fp8 types directly, workaround
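The second hunk narrows the scale-dtype shortcut to the pure dequantization path (do_dequantize and not do_quantize), so a fused quantize-then-dequantize call falls through to the dtype kwarg or the input dtype instead. A rough sketch of that choice under assumed argument names (the real logic is inlined in _process_quantization):

import torch

def select_group_output_dtype(x, scale, do_quantize, do_dequantize, dtype=None):
    # Sketch only: pick the dtype used to allocate the GROUP-strategy output buffer.
    if do_dequantize and not do_quantize:
        # purely dequantizing an already-quantized tensor: infer from the scale
        return scale.dtype
    # otherwise honor an explicit dtype kwarg, defaulting to the input's dtype
    return dtype if dtype is not None else x.dtype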

tests/test_compressors/test_int_quant.py

Lines changed: 2 additions & 2 deletions
@@ -97,8 +97,8 @@ def test_quant_format(strategy, symmetric, group_size, sc, zp):
         [
             QuantizationStrategy.GROUP,
             128,
-            torch.rand((300, 8, 1)) * 0.01,
-            torch.zeros((300, 8, 1), dtype=torch.int8),
+            torch.rand((300, 8)) * 0.01,
+            torch.zeros((300, 8), dtype=torch.int8),
         ],
         [
             QuantizationStrategy.CHANNEL,
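The test change follows the same 2-D convention: GROUP scales and zero points are now shaped (rows, num_groups) instead of (rows, num_groups, 1). For illustration, assuming a 1024-column weight as in the sketch above (the width is an assumption, not taken from the test file):

import torch

group_scale = torch.rand((300, 8)) * 0.01            # GROUP: 1024 / 8 = 128 columns per scale
group_zp = torch.zeros((300, 8), dtype=torch.int8)
channel_scale = torch.rand((300, 1)) * 0.01          # CHANNEL: one scale per output row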
