
Commit d914b73

Sara Adkins, horheynm, and bfineran
authored
Group and Channelwise Compression Support (vllm-project#52)
* group size
* add logic in base observer
* group size full lifecycle run
* before vectorize the for loop
* comments, todo add channelwise
* chan wise impl
* comments
* fix channel wise
* comments, validators
* fix typo
* tensor return error fix
* fix sparseml-side of code and add per channel
* pyndatic defaults
* token wise quant
* Update src/compressed_tensors/quantization/quant_args.py
  Co-authored-by: Benjamin Fineran <[email protected]>
* comments'
* update dim
* shape consistency
* Update src/compressed_tensors/quantization/lifecycle/forward.py
  Co-authored-by: Benjamin Fineran <[email protected]>
* comments
* pass test_quant_args
* fix channelwise
* new tests, some fail
* WIP
* group compression
* fix output type on decompress
* fix channelwise
* revert
* more tests

---------

Co-authored-by: George Ohashi <[email protected]>
Co-authored-by: Benjamin Fineran <[email protected]>
1 parent 611f152 commit d914b73

File tree

8 files changed, +422 -51 lines changed

src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 145 additions & 37 deletions
@@ -26,7 +26,13 @@
 from torch.nn import Module
 
 
-__all__ = ["wrap_module_forward_quantized", "maybe_calibrate_or_quantize"]
+__all__ = [
+    "quantize",
+    "dequantize",
+    "fake_quantize",
+    "wrap_module_forward_quantized",
+    "maybe_calibrate_or_quantize",
+]
 
 
 @torch.no_grad()
@@ -37,29 +43,66 @@ def quantize(
     args: QuantizationArgs,
     dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
-    bit_range = 2**args.num_bits
-    q_max = torch.tensor(bit_range / 2 - 1, device=x.device)
-    q_min = torch.tensor(-bit_range / 2, device=x.device)
+    """
+    Quantize the input tensor x using the QuantizationStrategy specified in args.
+    Quantization can be done per tensor, channel, token or group. For group
+    quantization, the group_size must be divisible by the column size. The input scale
+    and zero_points are reshaped to support vectorization (Assumes 1 is the
+    channel dimension)
 
-    quantized_value = torch.clamp(
-        torch.round(x / scale + zero_point),
-        q_min,
-        q_max,
+    :param x: Input tensor
+    :param scale: scale tensor
+    :param zero_point: zero point tensor
+    :param args: quantization args dictating how to quantize x
+    :param dtype: optional dtype to cast the quantized output to
+    :return: fake quantized tensor
+    """
+    return _process_quantization(
+        x=x,
+        scale=scale,
+        zero_point=zero_point,
+        args=args,
+        dtype=dtype,
+        do_quantize=True,
+        do_dequantize=False,
     )
 
-    if dtype is not None:
-        quantized_value = quantized_value.to(dtype)
-
-    return quantized_value
-
 
 @torch.no_grad()
 def dequantize(
     x_q: torch.Tensor,
     scale: torch.Tensor,
     zero_point: torch.Tensor,
+    args: QuantizationArgs = None,
 ) -> torch.Tensor:
-    return (x_q - zero_point) * scale
+    """
+    Dequantize a quantized input tensor x_q based on the strategy specified in args. If
+    args is not provided, the strategy will be inferred.
+
+    :param x: quantized input tensor
+    :param scale: scale tensor
+    :param zero_point: zero point tensor
+    :param args: quantization args used to quantize x_q
+    :return: dequantized float tensor
+    """
+    if args is None:
+        if scale.ndim == 0:
+            args = QuantizationArgs(strategy=QuantizationStrategy.TENSOR)
+        elif scale.ndim == 2:
+            args = QuantizationArgs(strategy=QuantizationStrategy.CHANNEL)
+        elif scale.ndim == 3:
+            group_size = int(x_q.shape[1] / scale.shape[1])
+            args = QuantizationArgs(
+                strategy=QuantizationStrategy.GROUP, group_size=group_size
+            )
+    return _process_quantization(
+        x=x_q,
+        scale=scale,
+        zero_point=zero_point,
+        args=args,
+        do_quantize=False,
+        do_dequantize=True,
+    )
 
 
 @torch.no_grad()
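Note (editor's illustration, not part of the commit): a minimal sketch of how the new quantize/dequantize entry points are meant to be called for group quantization. The shapes mirror the tests further down, where a (512, 1024) weight with group_size=128 carries a (512, 8, 1) scale; the exact values here are assumptions.

    import torch
    from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
    from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize

    x = torch.randn(512, 1024)
    args = QuantizationArgs(strategy=QuantizationStrategy.GROUP, group_size=128)

    # one (scale, zero_point) pair per (row, group); trailing dim of 1 supports broadcasting
    scale = torch.rand(512, 1024 // 128, 1) * 0.01
    zero_point = torch.zeros(512, 1024 // 128, 1)

    x_q = quantize(x, scale, zero_point, args, dtype=torch.int8)  # integer output
    x_dq = dequantize(x_q, scale, zero_point)  # args omitted: strategy inferred from scale.ndim == 3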
@@ -70,26 +113,51 @@ def fake_quantize(
     args: QuantizationArgs,
 ) -> torch.Tensor:
     """
-    Fake quantize the input tensor x depending on the group_size.
-    if group_size is greater than 0, then q/dq by groups. The groups
-    must be divisible by the column size
-    if group_size is -1, then channel wise q/dq. THe input scale and
-    zero_points are reshaped to support vectorization (Assumes 1 is
-    the channel dimension)
+    Fake quantize the input tensor x by quantizing then dequantizing with
+    the QuantizationStrategy specified in args. Quantization can be done per tensor,
+    channel, token or group. For group quantization, the group_size must be divisible
+    by the column size. The input scale and zero_points are reshaped to support
+    vectorization (Assumes 1 is the channel dimension)
 
     :param x: Input tensor
     :param scale: scale tensor
     :param zero_point: zero point tensor
-    :param args: quantization args that contain group_size info
+    :param args: quantization args dictating how to quantize x
     :return: fake quantized tensor
-
     """
+    return _process_quantization(
+        x=x,
+        scale=scale,
+        zero_point=zero_point,
+        args=args,
+        do_quantize=True,
+        do_dequantize=True,
+    )
+
+
+@torch.no_grad()
+def _process_quantization(
+    x: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    args: QuantizationArgs,
+    dtype: Optional[torch.dtype] = None,
+    do_quantize: bool = True,
+    do_dequantize: bool = True,
+) -> torch.Tensor:
+    bit_range = 2**args.num_bits
+    q_max = torch.tensor(bit_range / 2 - 1, device=x.device)
+    q_min = torch.tensor(-bit_range / 2, device=x.device)
     group_size = args.group_size
 
     # group
     if args.strategy == QuantizationStrategy.GROUP:
 
-        DQ = torch.zeros_like(x)
+        if do_dequantize:  # if dequantizing the output should be a fp type
+            output = torch.zeros_like(x, dtype=scale.dtype)
+        else:
+            output_dtype = dtype if dtype is not None else x.dtype
+            output = torch.zeros_like(x, dtype=output_dtype)
 
         # TODO: vectorize the for loop
         # TODO: fix genetric assumption about the tensor size for computing group
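Note on the output buffer logic above: when dequantization is part of the call, the output is allocated in scale.dtype so the result stays floating point; for a quantize-only call the caller-supplied dtype is used, falling back to the dtype of x when none is given.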
@@ -115,18 +183,24 @@ def fake_quantize(
             zp = zero_point[:, i].view(-1, 1)
 
             idx = i * group_size
-            Q = quantize(x[:, idx : (idx + group_size)], sc, zp, args)
-            DQ[:, idx : (idx + group_size)] = dequantize(Q, sc, zp)
+            if do_quantize:
+                output[:, idx : (idx + group_size)] = _quantize(
+                    x[:, idx : (idx + group_size)], sc, zp, q_min, q_max, dtype=dtype
+                )
+            if do_dequantize:
+                input = (
+                    output[:, idx : (idx + group_size)]
+                    if do_quantize
+                    else x[:, idx : (idx + group_size)]
+                )
+                output[:, idx : (idx + group_size)] = _dequantize(input, sc, zp)
 
     # channel-wise
     elif args.strategy == QuantizationStrategy.CHANNEL:  # group_size == -1
-        # before: scale shape = [channel_size]
-        # after: scale shape = [1, channel_size]
-        scale = scale.unsqueeze(0)
-        zero_point = zero_point.unsqueeze(0)
-
-        Q = quantize(x, scale, zero_point, args)
-        DQ = dequantize(Q, scale, zero_point)
+        if do_quantize:
+            output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype)
+        if do_dequantize:
+            output = _dequantize(output if do_quantize else x, scale, zero_point)
 
     # per-token
     elif args.strategy == QuantizationStrategy.TOKEN:
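Editor's sketch (assumed shapes, not from the commit) of the per-group slicing used in the loop above: group i covers group_size consecutive columns and is scaled by the i-th scale/zero-point column.

    import torch

    x = torch.randn(512, 1024)
    group_size = 128                          # must evenly divide the number of columns
    num_groups = x.shape[1] // group_size     # 8
    scale = torch.rand(512, num_groups, 1) * 0.01

    for i in range(num_groups):
        idx = i * group_size
        group = x[:, idx : idx + group_size]  # (512, 128) block of weights
        sc = scale[:, i].view(-1, 1)          # (512, 1), broadcasts over the block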
@@ -138,14 +212,18 @@ def fake_quantize(
         scale = scale.unsqueeze(1)
         zero_point = zero_point.unsqueeze(1)
 
-        Q = quantize(x, scale, zero_point, args)
-        DQ = dequantize(Q, scale, zero_point)
+        if do_quantize:
+            output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype)
+        if do_dequantize:
+            output = _dequantize(output if do_quantize else x, scale, zero_point)
 
     else:
-        Q = quantize(x, scale, zero_point, args)
-        DQ = dequantize(Q, scale, zero_point)
+        if do_quantize:
+            output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype)
+        if do_dequantize:
+            output = _dequantize(output if do_quantize else x, scale, zero_point)
 
-    return DQ
+    return output
 
 
 def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
@@ -223,3 +301,33 @@ def maybe_calibrate_or_quantize(
         scale.data = updated_scale.to(device)
         zero_point.data = updated_zero_point.to(device)
     return fake_quantize(value, scale, zero_point, args)
+
+
+@torch.no_grad()
+def _quantize(
+    x: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    q_min: torch.Tensor,
+    q_max: torch.Tensor,
+    dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+    quantized_value = torch.clamp(
+        torch.round(x / scale + zero_point),
+        q_min,
+        q_max,
+    )
+
+    if dtype is not None:
+        quantized_value = quantized_value.to(dtype)
+
+    return quantized_value
+
+
+@torch.no_grad()
+def _dequantize(
+    x_q: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+) -> torch.Tensor:
+    return (x_q - zero_point) * scale
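A worked 8-bit example of the clamp/round arithmetic in _quantize and _dequantize above (illustrative values only):

    import torch

    bit_range = 2 ** 8                        # num_bits = 8
    q_max = torch.tensor(bit_range / 2 - 1)   # 127
    q_min = torch.tensor(-bit_range / 2)      # -128

    x = torch.tensor(0.5)
    scale, zero_point = torch.tensor(0.01), torch.tensor(0.0)
    q = torch.clamp(torch.round(x / scale + zero_point), q_min, q_max)  # tensor(50.)
    x_dq = (q - zero_point) * scale                                     # tensor(0.5000)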

tests/test_compressors/__init__.py

Whitespace-only changes.
File renamed without changes.

tests/test_int_quant.py renamed to tests/test_compressors/test_int_quant.py

Lines changed: 55 additions & 14 deletions
@@ -14,20 +14,25 @@
 
 import shutil
 
+import pytest
 import torch
 from compressed_tensors import IntQuantizationCompressor
 from compressed_tensors.quantization import (
     QuantizationArgs,
     QuantizationConfig,
     QuantizationScheme,
+    QuantizationStrategy,
 )
 from compressed_tensors.quantization.lifecycle.forward import fake_quantize
 from safetensors.torch import save_file
 
 
-def get_dummy_quant_config():
+def get_dummy_quant_config(strategy, group_size=None):
     config_groups = {
-        "group_1": QuantizationScheme(targets=["Linear"], weights=QuantizationArgs()),
+        "group_1": QuantizationScheme(
+            targets=["Linear"],
+            weights=QuantizationArgs(strategy=strategy, group_size=group_size),
+        ),
     }
     ignore = ["lm_head"]
     quant_config = QuantizationConfig(
@@ -38,13 +43,31 @@ def get_dummy_quant_config():
     return quant_config
 
 
-def test_quant_format():
+@pytest.mark.parametrize(
+    "strategy,group_size,sc,zp",
+    [
+        [QuantizationStrategy.TENSOR, None, 0.01, 0],
+        [
+            QuantizationStrategy.GROUP,
+            128,
+            torch.rand((512, 8, 1)) * 0.01,
+            torch.zeros((512, 8, 1), dtype=torch.int8),
+        ],
+        [
+            QuantizationStrategy.CHANNEL,
+            128,
+            torch.rand((512, 1)) * 0.01,
+            torch.zeros((512, 1), dtype=torch.int8),
+        ],
+    ],
+)
+def test_quant_format(strategy, group_size, sc, zp):
     dense_state_dict = {
         "dummy.weight": torch.rand((512, 1024)),
-        "dummy.weight_scale": torch.tensor(0.01, dtype=torch.float32),
-        "dummy.weight_zero_point": torch.tensor(0, dtype=torch.int32),
+        "dummy.weight_scale": torch.tensor(sc, dtype=torch.float32),
+        "dummy.weight_zero_point": torch.tensor(zp, dtype=torch.int32),
     }
-    quant_config = get_dummy_quant_config()
+    quant_config = get_dummy_quant_config(strategy=strategy, group_size=group_size)
 
     compressor = IntQuantizationCompressor(config=quant_config)
     quantized_modules_to_args = {"dummy": quant_config.config_groups["group_1"].weights}
@@ -61,16 +84,34 @@ def test_quant_format():
     assert compressed_state_dict["dummy.weight_zero_point"].dtype == torch.int32
 
 
-def test_reload_match(tmp_path):
+@pytest.mark.parametrize(
+    "strategy,group_size,sc,zp",
+    [
+        [QuantizationStrategy.TENSOR, None, 0.01, 0],
+        [
+            QuantizationStrategy.GROUP,
+            128,
+            torch.rand((300, 8, 1)) * 0.01,
+            torch.zeros((300, 8, 1), dtype=torch.int8),
+        ],
+        [
+            QuantizationStrategy.CHANNEL,
+            128,
+            torch.rand((300, 1)) * 0.01,
+            torch.zeros((300, 1), dtype=torch.int8),
+        ],
+    ],
+)
+def test_reload_match(strategy, group_size, sc, zp, tmp_path):
     dense_state_dict = {
-        "dummy.weight": torch.rand((511, 350)),
-        "dummy.weight_scale": torch.tensor(0.01, dtype=torch.float32),
-        "dummy.weight_zero_point": torch.tensor(0, dtype=torch.int32),
-        "dummy2.weight": torch.rand((128, 280)),
-        "dummy2.weight_scale": torch.tensor(0.02, dtype=torch.float32),
-        "dummy2.weight_zero_point": torch.tensor(15, dtype=torch.int32),
+        "dummy.weight": torch.rand((300, 1024)),
+        "dummy.weight_scale": torch.tensor(sc, dtype=torch.float32),
+        "dummy.weight_zero_point": torch.tensor(zp, dtype=torch.int32),
+        "dummy2.weight": torch.rand((300, 1024)),
+        "dummy2.weight_scale": torch.tensor(sc, dtype=torch.float32),
+        "dummy2.weight_zero_point": torch.tensor(zp, dtype=torch.int32),
     }
-    quant_config = get_dummy_quant_config()
+    quant_config = get_dummy_quant_config(strategy=strategy, group_size=group_size)
 
     compressor = IntQuantizationCompressor(config=quant_config)
     quantized_modules_to_args = {
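The parametrized scale/zero_point shapes above follow from the strategies (editor's summary): for a weight with 1024 columns and group_size=128 there are 1024 / 128 = 8 groups, so the group case uses (rows, 8, 1) tensors; the channel case keeps one value per output channel, hence (rows, 1); the tensor case passes a single scalar scale and zero point.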
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

0 commit comments