Commit 35dac8e

Group Index Quantization Support (vllm-project#134)
Precursor to vllm-project#97; adds support for out-of-order group quantization as specified by g_idx.
1 parent d75e1a6 commit 35dac8e
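
For context, the g_idx argument threaded through this diff maps each weight column to the quantization group it belongs to, so columns of the same group need not be contiguous (as in GPTQ-style activation ordering). The sketch below illustrates the idea with hypothetical values (the 8-column layout and group assignments are made up for illustration, not taken from the diff):

import torch

# hypothetical group map for 8 columns and 2 groups: column i belongs to group g_idx[i]
g_idx = torch.tensor([0, 1, 0, 1, 1, 0, 0, 1])

# sorting by group index yields a permutation that makes each group contiguous,
# which is what _process_quantization does below via safe_permute
perm = torch.argsort(g_idx)  # e.g. tensor([0, 2, 5, 6, 1, 3, 4, 7])
_, group_sizes = torch.unique(g_idx, return_counts=True)  # tensor([4, 4])

x = torch.randn(2, 8)
x_contiguous = x[:, perm]  # group-0 columns first, then group-1 columns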

File tree

4 files changed: +270, -22 lines changed


src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 46 additions & 21 deletions
@@ -17,6 +17,7 @@
 from typing import Optional
 
 import torch
+from compressed_tensors.quantization.lifecycle.helpers import safe_permute
 from compressed_tensors.quantization.observers.helpers import calculate_range
 from compressed_tensors.quantization.quant_args import (
     QuantizationArgs,
@@ -45,6 +46,7 @@ def quantize(
     zero_point: torch.Tensor,
     args: QuantizationArgs,
     dtype: Optional[torch.dtype] = None,
+    g_idx: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     Quantize the input tensor x using the QuantizationStrategy specified in args.
@@ -58,6 +60,7 @@ def quantize(
     :param zero_point: zero point tensor
     :param args: quantization args dictating how to quantize x
     :param dtype: optional dtype to cast the quantized output to
+    :param g_idx: optional mapping from column index to group index
     :return: fake quantized tensor
     """
     # ensure all tensors are on the same device
@@ -76,6 +79,7 @@
         dtype=dtype,
         do_quantize=True,
         do_dequantize=False,
+        g_idx=g_idx,
     )
 
 
@@ -86,6 +90,7 @@ def dequantize(
     zero_point: torch.Tensor = None,
     args: QuantizationArgs = None,
     dtype: Optional[torch.dtype] = None,
+    g_idx: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     Dequantize a quantized input tensor x_q based on the strategy specified in args. If
@@ -96,6 +101,7 @@ def dequantize(
     :param zero_point: zero point tensor
     :param args: quantization args used to quantize x_q
     :param dtype: optional dtype to cast the dequantized output to
+    :param g_idx: optional mapping from column index to group index
     :return: dequantized float tensor
     """
     if args is None:
@@ -126,6 +132,7 @@ def dequantize(
         do_quantize=False,
         do_dequantize=True,
         dtype=dtype,
+        g_idx=g_idx,
     )
 
 
@@ -135,6 +142,7 @@ def fake_quantize(
     scale: torch.Tensor,
     zero_point: torch.Tensor,
     args: QuantizationArgs,
+    g_idx: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     Fake quantize the input tensor x by quantizing then dequantizing with
@@ -147,6 +155,7 @@ def fake_quantize(
     :param scale: scale tensor
     :param zero_point: zero point tensor
     :param args: quantization args dictating how to quantize x
+    :param g_idx: optional mapping from column index to group index
     :return: fake quantized tensor
     """
     return _process_quantization(
@@ -156,6 +165,7 @@ def fake_quantize(
         args=args,
         do_quantize=True,
         do_dequantize=True,
+        g_idx=g_idx,
     )
 
 
@@ -164,21 +174,19 @@ def _process_quantization(
     x: torch.Tensor,
     scale: torch.Tensor,
     zero_point: torch.Tensor,
+    g_idx: Optional[torch.Tensor],
     args: QuantizationArgs,
     dtype: Optional[torch.dtype] = None,
     do_quantize: bool = True,
     do_dequantize: bool = True,
 ) -> torch.Tensor:
-
     q_min, q_max = calculate_range(args, x.device)
     group_size = args.group_size
 
     if args.strategy == QuantizationStrategy.GROUP:
         output_dtype = dtype if dtype is not None else x.dtype
         output = torch.zeros_like(x).to(output_dtype)
-
-        # TODO: vectorize the for loop
-        # TODO: fix genetric assumption about the tensor size for computing group
+        columns = output.shape[1]
 
         # TODO: make validation step for inputs
 
@@ -187,37 +195,52 @@ def _process_quantization(
             scale = scale.unsqueeze(1)
             zero_point = zero_point.unsqueeze(1) if zero_point is not None else None
 
-        columns = x.shape[1]
         if columns >= group_size:
             if columns % group_size != 0:
                 raise ValueError(
-                    "tesnor column shape must be divisble "
+                    "tensor column shape must be divisible "
                     f"by the given group_size {group_size}"
                 )
-            for i in range(ceil(columns / group_size)):
-                # scale.shape should be [nchan, ndim]
-                # sc.shape should be [nchan, 1] after unsqueeze
-                sc = scale[:, i].view(-1, 1)
-                zp = zero_point[:, i].view(-1, 1) if zero_point is not None else None
 
-                idx = i * group_size
+        # support column-order (default) quantization as well as other orderings
+        # such as activation ordering. Below checks if g_idx has been initialized
+        is_column_order = g_idx is None or -1 in g_idx
+        if is_column_order:
+            num_groups = int(ceil(columns / group_size))
+            group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
+
+        else:
+            group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
+            group_sizes = group_sizes[torch.argsort(group_indices)]
+
+            perm = torch.argsort(g_idx)
+            x = safe_permute(x, perm, dim=1)
+
+        # TODO: experiment with vectorizing for loop for performance
+        end = 0
+        for index, group_count in enumerate(group_sizes):
+            sc = scale[:, index].view(-1, 1)
+            zp = zero_point[:, index].view(-1, 1) if zero_point is not None else None
+
+            start = end
+            end = start + group_count
             if do_quantize:
-                output[:, idx : (idx + group_size)] = _quantize(
-                    x[:, idx : (idx + group_size)],
+                output[:, start:end] = _quantize(
+                    x[:, start:end],
                     sc,
                     zp,
                     q_min,
                     q_max,
                     args,
                     dtype=dtype,
                 )
+
             if do_dequantize:
-                input = (
-                    output[:, idx : (idx + group_size)]
-                    if do_quantize
-                    else x[:, idx : (idx + group_size)]
-                )
-                output[:, idx : (idx + group_size)] = _dequantize(input, sc, zp)
+                input = output[:, start:end] if do_quantize else x[:, start:end]
+                output[:, start:end] = _dequantize(input, sc, zp)
+
+        if not is_column_order:
+            output = safe_permute(output, torch.argsort(perm), dim=1)
 
     else:  # covers channel, token and tensor strategies
         if do_quantize:
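
Putting the hunk above together: the loop walks contiguous column spans whose widths come from group_sizes, quantizes each span against its own scale and zero point, then un-permutes at the end. Below is a self-contained sketch of that control flow, with symmetric round-to-nearest standing in for the library's _quantize/_dequantize and with made-up q_min/q_max bounds:

from math import ceil

import torch

def grouped_fake_quantize(x, scale, g_idx=None, group_size=4, q_min=-8, q_max=7):
    # derive per-group column counts and an optional permutation, as above
    if g_idx is None:
        num_groups = int(ceil(x.shape[1] / group_size))
        group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
        perm = None
    else:
        indices, group_sizes = torch.unique(g_idx, return_counts=True)
        group_sizes = group_sizes[torch.argsort(indices)]
        perm = torch.argsort(g_idx)
        x = x[:, perm]

    output = torch.zeros_like(x)
    end = 0
    for i, count in enumerate(group_sizes):
        sc = scale[:, i].view(-1, 1)  # one scale column per group
        start, end = end, end + count
        q = torch.clamp(torch.round(x[:, start:end] / sc), q_min, q_max)
        output[:, start:end] = q * sc  # quantize, then dequantize

    if perm is not None:
        output = output[:, torch.argsort(perm)]  # restore original column order
    return output

# usage: scale has shape [rows, num_groups]
x = torch.randn(2, 8)
scale = x.view(2, 2, 4).abs().amax(-1) / 7  # per-group max-abs scales
out = grouped_fake_quantize(x, scale)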
@@ -304,6 +327,8 @@ def maybe_calibrate_or_quantize(
         # skip quantization
         return value
 
+    g_idx = getattr(module, "weight_g_idx", None)
+
     if args.dynamic:
         # dynamic quantization - get scale and zero point directly from observer
         observer = getattr(module, f"{base_name}_observer")
@@ -326,7 +351,7 @@ def maybe_calibrate_or_quantize(
     update_parameter_data(module, updated_scale, f"{base_name}_scale")
     update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")
 
-    return fake_quantize(value, scale, zero_point, args)
+    return fake_quantize(value, scale, zero_point, args, g_idx=g_idx)
 
 
 @torch.no_grad()
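
One wiring note on maybe_calibrate_or_quantize: getattr(module, "weight_g_idx", None) means any module that never had a group map attached silently falls back to column-order grouping, and per the is_column_order check above, a g_idx still filled with the -1 sentinel behaves the same way. A hedged sketch of how a module might carry that attribute (the Linear layer and buffer registration here are illustrative conventions, not part of this diff):

import torch
from torch.nn import Linear

layer = Linear(8, 16)
# hypothetical: an activation-ordering pass attaches the group map under the
# name "weight_g_idx", which the getattr above looks up
layer.register_buffer("weight_g_idx", torch.tensor([0, 1, 0, 1, 1, 0, 0, 1]))

g_idx = getattr(layer, "weight_g_idx", None)  # tensor, or None if never attached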

src/compressed_tensors/quantization/lifecycle/helpers.py

Lines changed: 52 additions & 0 deletions
@@ -16,7 +16,9 @@
 Miscellaneous helpers for the quantization lifecycle
 """
 
+from typing import Set, Tuple
 
+import torch
 from torch.nn import Module
 
 
@@ -51,3 +53,53 @@ def enable_quantization(module: Module):
 
 def disable_quantization(module: Module):
     module.quantization_enabled = False
+
+
+# these datatypes are missing implementations required for standard permutation
+_EXPERIMENTAL_DTYPES: Set[Tuple[torch.dtype, torch.device]] = set()
+
+
+def safe_permute(value: torch.Tensor, perm: torch.Tensor, dim: int = 0) -> torch.Tensor:
+    """
+    Perform out-of-place permutation without using torch.Tensor.index_put_,
+    whose implementation is missing for datatypes such as `torch.float8_e4m3fn`
+
+    :param value: tensor to permute
+    :param perm: permutation map
+    :param dim: dimension along which to apply permutation
+    :return: permuted value
+    """
+    dtype_tuple = (value.dtype, value.device)
+
+    if dtype_tuple in _EXPERIMENTAL_DTYPES:
+        return _fallback_permute(value, perm, dim)
+
+    try:
+        return value[tuple([slice(None)] * dim + [perm])]
+    except RuntimeError:
+        # Mark dtype as experimental if advanced indexing fails
+        _EXPERIMENTAL_DTYPES.add(dtype_tuple)
+        return _fallback_permute(value, perm, dim)
+
+
+def _fallback_permute(
+    value: torch.Tensor, perm: torch.Tensor, dim: int
+) -> torch.Tensor:
+    """
+    Fallback permutation method for experimental dtypes.
+
+    :param value: tensor to permute
+    :param perm: permutation map
+    :param dim: dimension along which to apply permutation
+    :return: permuted value
+    """
+    value_ret = value.clone()  # cannot use zeros_like b/c of missing impl.
+    orig_slices = [slice(None)] * (dim + 1)
+    perm_slices = [slice(None)] * (dim + 1)
+
+    for index, perm_index in enumerate(perm):
+        orig_slices[dim] = index
+        perm_slices[dim] = perm_index
+        value_ret[tuple(orig_slices)] = value[tuple(perm_slices)]
+
+    return value_ret
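
A quick sketch of safe_permute's contract (shapes and dtypes chosen for illustration): applying the argsort of a permutation inverts it, which is exactly how _process_quantization restores column order after the grouped loop. The fallback path is exercised by dtypes like torch.float8_e4m3fn, where advanced indexing may raise depending on torch version:

import torch

from compressed_tensors.quantization.lifecycle.helpers import safe_permute

x = torch.randn(2, 4)
perm = torch.tensor([2, 0, 3, 1])

y = safe_permute(x, perm, dim=1)  # columns reordered
assert torch.equal(safe_permute(y, torch.argsort(perm), dim=1), x)

# experimental-dtype path: falls back to the elementwise copy if indexing fails
x8 = torch.randn(2, 4).to(torch.float8_e4m3fn)
y8 = safe_permute(x8, perm, dim=1)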
