
Commit 91b15d2

Add warning for non-divisible group quantization (#1401)
## Purpose ##

* Test discrepancies between initialized parameters and values calculated by observers
* Reveal a potential issue with how qparams are initialized (neuralmagic/compressed-tensors#308)
* Add a warning for when users attempt to quantize groups that aren't perfectly divisible

## Prerequisites ##

* #1431

## Changes ##

* Added `test_observers_update` in `tests/llmcompressor/modifiers/calibration/test_observers.py`
* Added a warning for attempts to quantize indivisible groups:

```
Attempting to quantize a module weight whose columns (3420) are not divisible by group_size (128). This scheme is not supported by vLLM, please consider adjusting the group_size for modules with this number of columns
```

## Testing ##

* The test fails without the compressed-tensors (CT) changes, but succeeds with them

---------

Signed-off-by: Kyle Sayers <[email protected]>
1 parent 7c7af39 commit 91b15d2
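
For a concrete sense of when the new warning fires, here is a small arithmetic sketch using the 3420-column / group_size=128 example from the warning text (illustrative only, not code from this commit):

```python
from math import ceil

# Example values from the warning text above
columns, group_size = 3420, 128

num_groups = ceil(columns / group_size)     # ceil(26.71...) == 27 groups
print(num_groups * group_size)              # 3456, not 3420
print(num_groups * group_size != columns)   # True -> the warning is logged
```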

2 files changed: +70 -0 lines changed

src/llmcompressor/observers/base.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -104,6 +104,15 @@ def get_qparams(
             rows = observed.shape[0]
             columns = observed.shape[1]
             num_groups = int(ceil(columns / group_size))
+            if num_groups * group_size != columns:
+                logger.bind(log_once=True).warning(
+                    "Attempting to quantize a module weight whose columns "
+                    f"({columns}) are not divisible by group_size ({group_size}). "
+                    "This scheme is not supported by vLLM, please consider "
+                    "adjusting the group_size for modules with this number of "
+                    "columns",
+                )
+
             self._scale = torch.empty(
                 (rows, num_groups), dtype=observed.dtype, device=observed.device
             )
```
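
Because `num_groups` is rounded up with `ceil`, the scale tensor still gets one column per group even when the division is inexact; the final group simply covers fewer than `group_size` weight columns. A minimal sketch of the resulting shapes (values are illustrative, not taken from the diff):

```python
import torch
from math import ceil

rows, columns, group_size = 1792, 3420, 128  # illustrative non-divisible case
num_groups = ceil(columns / group_size)      # 27, rounded up

scale = torch.empty((rows, num_groups))      # mirrors self._scale: (1792, 27)
last_group_cols = columns - (num_groups - 1) * group_size
print(scale.shape, last_group_cols)          # torch.Size([1792, 27]) 92
```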
tests/llmcompressor/modifiers/calibration/test_observers.py

Lines changed: 61 additions & 0 deletions
```diff
@@ -0,0 +1,61 @@
+import pytest
+import torch
+from compressed_tensors.quantization import (
+    QuantizationArgs,
+    QuantizationScheme,
+    initialize_module_for_quantization,
+)
+
+from llmcompressor.modifiers.quantization.calibration import initialize_observer
+
+
+@pytest.mark.parametrize(
+    "shape,group_size,actorder",
+    [
+        ((1, 1), None, False),
+        ((1, 1), 128, False),
+        ((1, 1), 128, True),
+        ((64, 64), None, False),
+        ((64, 64), 128, False),
+        ((64, 64), 128, True),
+        ((1792, 4096), None, False),
+        ((1792, 4096), 128, False),
+        ((1792, 4096), 128, True),
+        ((3420, 64), None, False),
+        ((3420, 64), 128, False),
+        ((3420, 64), 128, True),
+    ],
+)
+def test_observers_update(shape, group_size, actorder):
+    module = torch.nn.Linear(*shape)
+    scheme = QuantizationScheme(
+        targets=["Linear"],
+        weights=QuantizationArgs(group_size=group_size, actorder=actorder),
+        input_activations=QuantizationArgs(),
+        output_activations=QuantizationArgs(),
+    )
+
+    input = torch.empty(module.in_features, dtype=module.weight.dtype)
+    output = torch.empty(module.out_features, dtype=module.weight.dtype)
+
+    initialize_module_for_quantization(module, scheme)
+    initialize_observer(module, "weight")
+    initialize_observer(module, "input")
+    initialize_observer(module, "output")
+
+    for location, value in (
+        ("weight", module.weight),
+        ("input", input),
+        ("output", output),
+    ):
+        observer = getattr(module, f"{location}_observer")
+        g_idx = getattr(module, "g_idx", None)
+        updated_scale, updated_zero_point = observer(value, g_idx=g_idx)
+
+        assert_alike(updated_scale, getattr(module, f"{location}_scale"))
+        assert_alike(updated_zero_point, getattr(module, f"{location}_zero_point"))
+
+
+def assert_alike(a, b):
+    assert a.dtype == b.dtype
+    assert a.shape == b.shape
```
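
Assuming a standard pytest setup, the new test can be run on its own with `pytest tests/llmcompressor/modifiers/calibration/test_observers.py -k test_observers_update`. For the grouped cases, `assert_alike` effectively checks that the observer output matches the `(rows, num_groups)` layout allocated in `base.py`; a hypothetical spot check for one parametrized case (names and values assumed for illustration):

```python
from math import ceil

# shape=(1792, 4096) means torch.nn.Linear(in_features=1792, out_features=4096),
# so the weight tensor is (4096, 1792): rows=4096, columns=1792
rows, columns, group_size = 4096, 1792, 128

num_groups = ceil(columns / group_size)    # 14; evenly divisible, so no warning
expected_scale_shape = (rows, num_groups)
print(expected_scale_shape)                # (4096, 14)
```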
