17
17
from typing import Optional
18
18
19
19
import torch
20
+ from compressed_tensors .quantization .lifecycle .helpers import safe_permute
20
21
from compressed_tensors .quantization .observers .helpers import calculate_range
21
22
from compressed_tensors .quantization .quant_args import (
22
23
QuantizationArgs ,
@@ -45,6 +46,7 @@ def quantize(
45
46
zero_point : torch .Tensor ,
46
47
args : QuantizationArgs ,
47
48
dtype : Optional [torch .dtype ] = None ,
49
+ g_idx : Optional [torch .Tensor ] = None ,
48
50
) -> torch .Tensor :
49
51
"""
50
52
Quantize the input tensor x using the QuantizationStrategy specified in args.
@@ -58,6 +60,7 @@ def quantize(
58
60
:param zero_point: zero point tensor
59
61
:param args: quantization args dictating how to quantize x
60
62
:param dtype: optional dtype to cast the quantized output to
63
+ :param g_idx: optional mapping from column index to group index
61
64
:return: fake quantized tensor
62
65
"""
63
66
# ensure all tensors are on the same device
@@ -76,6 +79,7 @@ def quantize(
76
79
dtype = dtype ,
77
80
do_quantize = True ,
78
81
do_dequantize = False ,
82
+ g_idx = g_idx ,
79
83
)
80
84
81
85
@@ -86,6 +90,7 @@ def dequantize(
86
90
zero_point : torch .Tensor = None ,
87
91
args : QuantizationArgs = None ,
88
92
dtype : Optional [torch .dtype ] = None ,
93
+ g_idx : Optional [torch .Tensor ] = None ,
89
94
) -> torch .Tensor :
90
95
"""
91
96
Dequantize a quantized input tensor x_q based on the strategy specified in args. If
@@ -96,6 +101,7 @@ def dequantize(
96
101
:param zero_point: zero point tensor
97
102
:param args: quantization args used to quantize x_q
98
103
:param dtype: optional dtype to cast the dequantized output to
104
+ :param g_idx: optional mapping from column index to group index
99
105
:return: dequantized float tensor
100
106
"""
101
107
if args is None :
@@ -126,6 +132,7 @@ def dequantize(
126
132
do_quantize = False ,
127
133
do_dequantize = True ,
128
134
dtype = dtype ,
135
+ g_idx = g_idx ,
129
136
)
130
137
131
138
@@ -135,6 +142,7 @@ def fake_quantize(
135
142
scale : torch .Tensor ,
136
143
zero_point : torch .Tensor ,
137
144
args : QuantizationArgs ,
145
+ g_idx : Optional [torch .Tensor ] = None ,
138
146
) -> torch .Tensor :
139
147
"""
140
148
Fake quantize the input tensor x by quantizing then dequantizing with
@@ -147,6 +155,7 @@ def fake_quantize(
147
155
:param scale: scale tensor
148
156
:param zero_point: zero point tensor
149
157
:param args: quantization args dictating how to quantize x
158
+ :param g_idx: optional mapping from column index to group index
150
159
:return: fake quantized tensor
151
160
"""
152
161
return _process_quantization (
@@ -156,6 +165,7 @@ def fake_quantize(
156
165
args = args ,
157
166
do_quantize = True ,
158
167
do_dequantize = True ,
168
+ g_idx = g_idx ,
159
169
)
160
170
161
171
@@ -164,21 +174,19 @@ def _process_quantization(
164
174
x : torch .Tensor ,
165
175
scale : torch .Tensor ,
166
176
zero_point : torch .Tensor ,
177
+ g_idx : Optional [torch .Tensor ],
167
178
args : QuantizationArgs ,
168
179
dtype : Optional [torch .dtype ] = None ,
169
180
do_quantize : bool = True ,
170
181
do_dequantize : bool = True ,
171
182
) -> torch .Tensor :
172
-
173
183
q_min , q_max = calculate_range (args , x .device )
174
184
group_size = args .group_size
175
185
176
186
if args .strategy == QuantizationStrategy .GROUP :
177
187
output_dtype = dtype if dtype is not None else x .dtype
178
188
output = torch .zeros_like (x ).to (output_dtype )
179
-
180
- # TODO: vectorize the for loop
181
- # TODO: fix genetric assumption about the tensor size for computing group
189
+ columns = output .shape [1 ]
182
190
183
191
# TODO: make validation step for inputs
184
192
@@ -187,37 +195,52 @@ def _process_quantization(
187
195
scale = scale .unsqueeze (1 )
188
196
zero_point = zero_point .unsqueeze (1 ) if zero_point is not None else None
189
197
190
- columns = x .shape [1 ]
191
198
if columns >= group_size :
192
199
if columns % group_size != 0 :
193
200
raise ValueError (
194
- "tesnor column shape must be divisble "
201
+ "tensor column shape must be divisible "
195
202
f"by the given group_size { group_size } "
196
203
)
197
- for i in range (ceil (columns / group_size )):
198
- # scale.shape should be [nchan, ndim]
199
- # sc.shape should be [nchan, 1] after unsqueeze
200
- sc = scale [:, i ].view (- 1 , 1 )
201
- zp = zero_point [:, i ].view (- 1 , 1 ) if zero_point is not None else None
202
204
203
- idx = i * group_size
205
+ # support column-order (default) quantization as well as other orderings
206
+ # such as activation ordering. Below checks if g_idx has been initialized
207
+ is_column_order = g_idx is None or - 1 in g_idx
208
+ if is_column_order :
209
+ num_groups = int (ceil (columns / group_size ))
210
+ group_sizes = torch .full ((num_groups ,), group_size , dtype = torch .int )
211
+
212
+ else :
213
+ group_indices , group_sizes = torch .unique (g_idx , return_counts = True )
214
+ group_sizes = group_sizes [torch .argsort (group_indices )]
215
+
216
+ perm = torch .argsort (g_idx )
217
+ x = safe_permute (x , perm , dim = 1 )
218
+
219
+ # TODO: experiment with vectorizing for loop for performance
220
+ end = 0
221
+ for index , group_count in enumerate (group_sizes ):
222
+ sc = scale [:, index ].view (- 1 , 1 )
223
+ zp = zero_point [:, index ].view (- 1 , 1 ) if zero_point is not None else None
224
+
225
+ start = end
226
+ end = start + group_count
204
227
if do_quantize :
205
- output [:, idx : ( idx + group_size ) ] = _quantize (
206
- x [:, idx : ( idx + group_size ) ],
228
+ output [:, start : end ] = _quantize (
229
+ x [:, start : end ],
207
230
sc ,
208
231
zp ,
209
232
q_min ,
210
233
q_max ,
211
234
args ,
212
235
dtype = dtype ,
213
236
)
237
+
214
238
if do_dequantize :
215
- input = (
216
- output [:, idx : (idx + group_size )]
217
- if do_quantize
218
- else x [:, idx : (idx + group_size )]
219
- )
220
- output [:, idx : (idx + group_size )] = _dequantize (input , sc , zp )
239
+ input = output [:, start :end ] if do_quantize else x [:, start :end ]
240
+ output [:, start :end ] = _dequantize (input , sc , zp )
241
+
242
+ if not is_column_order :
243
+ output = safe_permute (output , torch .argsort (perm ), dim = 1 )
221
244
222
245
else : # covers channel, token and tensor strategies
223
246
if do_quantize :
@@ -304,6 +327,8 @@ def maybe_calibrate_or_quantize(
304
327
# skip quantization
305
328
return value
306
329
330
+ g_idx = getattr (module , "weight_g_idx" , None )
331
+
307
332
if args .dynamic :
308
333
# dynamic quantization - get scale and zero point directly from observer
309
334
observer = getattr (module , f"{ base_name } _observer" )
@@ -326,7 +351,7 @@ def maybe_calibrate_or_quantize(
326
351
update_parameter_data (module , updated_scale , f"{ base_name } _scale" )
327
352
update_parameter_data (module , updated_zero_point , f"{ base_name } _zero_point" )
328
353
329
- return fake_quantize (value , scale , zero_point , args )
354
+ return fake_quantize (value , scale , zero_point , args , g_idx = g_idx )
330
355
331
356
332
357
@torch .no_grad ()
0 commit comments