from torch.nn import Module


-__all__ = ["wrap_module_forward_quantized", "maybe_calibrate_or_quantize"]
+__all__ = [
+    "quantize",
+    "dequantize",
+    "fake_quantize",
+    "wrap_module_forward_quantized",
+    "maybe_calibrate_or_quantize",
+]


@torch.no_grad()
@@ -37,29 +43,66 @@ def quantize(
    args: QuantizationArgs,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
-    bit_range = 2**args.num_bits
-    q_max = torch.tensor(bit_range / 2 - 1, device=x.device)
-    q_min = torch.tensor(-bit_range / 2, device=x.device)
-
-    quantized_value = torch.clamp(
-        torch.round(x / scale + zero_point),
-        q_min,
-        q_max,
+    """
+    Quantize the input tensor x using the QuantizationStrategy specified in args.
+    Quantization can be done per tensor, channel, token or group. For group
+    quantization, the group_size must be divisible by the column size. The input scale
+    and zero_points are reshaped to support vectorization (Assumes 1 is the
+    channel dimension)
+
+    :param x: Input tensor
+    :param scale: scale tensor
+    :param zero_point: zero point tensor
+    :param args: quantization args dictating how to quantize x
+    :param dtype: optional dtype to cast the quantized output to
+    :return: fake quantized tensor
+    """
+    return _process_quantization(
+        x=x,
+        scale=scale,
+        zero_point=zero_point,
+        args=args,
+        dtype=dtype,
+        do_quantize=True,
+        do_dequantize=False,
    )

-    if dtype is not None:
-        quantized_value = quantized_value.to(dtype)
-
-    return quantized_value
-


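For intuition, the round trip that quantize and dequantize implement is plain affine quantization; a minimal standalone sketch (the scale and zero_point values here are made up for illustration, not code from this PR):

import torch

x = torch.tensor([[0.12, -0.48, 1.03]])
scale = torch.tensor(0.05)      # hypothetical per-tensor scale
zero_point = torch.tensor(0.0)  # hypothetical zero point
q_min, q_max = -128, 127        # 8-bit signed range: 2**num_bits values split around zero

x_q = torch.clamp(torch.round(x / scale + zero_point), q_min, q_max)  # quantize
x_dq = (x_q - zero_point) * scale  # dequantize; recovers x up to scale/2 rounding error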
@torch.no_grad()
def dequantize(
    x_q: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
+    args: QuantizationArgs = None,
) -> torch.Tensor:
-    return (x_q - zero_point) * scale
+    """
+    Dequantize a quantized input tensor x_q based on the strategy specified in args.
+    If args is not provided, the strategy will be inferred.
+
+    :param x_q: quantized input tensor
+    :param scale: scale tensor
+    :param zero_point: zero point tensor
+    :param args: quantization args used to quantize x_q
+    :return: dequantized float tensor
+    """
+    if args is None:
+        if scale.ndim == 0:
+            args = QuantizationArgs(strategy=QuantizationStrategy.TENSOR)
+        elif scale.ndim == 2:
+            args = QuantizationArgs(strategy=QuantizationStrategy.CHANNEL)
+        elif scale.ndim == 3:
+            group_size = int(x_q.shape[1] / scale.shape[1])
+            args = QuantizationArgs(
+                strategy=QuantizationStrategy.GROUP, group_size=group_size
+            )
+    return _process_quantization(
+        x=x_q,
+        scale=scale,
+        zero_point=zero_point,
+        args=args,
+        do_quantize=False,
+        do_dequantize=True,
+    )
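The inference above keys off the dimensionality of scale; a small sketch of the mapping, with assumed (illustrative) tensor shapes:

import torch

x_q = torch.zeros(4, 64)              # made-up quantized tensor: 4 rows, 64 columns
per_tensor_scale = torch.tensor(0.1)  # ndim == 0 -> TENSOR strategy
channel_scale = torch.ones(1, 64)     # ndim == 2 -> CHANNEL strategy
group_scale = torch.ones(4, 8, 1)     # ndim == 3 -> GROUP strategy (shape assumed)

group_size = int(x_q.shape[1] / group_scale.shape[1])  # 64 columns / 8 groups == 8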


@torch.no_grad()
@@ -70,26 +113,51 @@ def fake_quantize(
    args: QuantizationArgs,
) -> torch.Tensor:
    """
-    Fake quantize the input tensor x depending on the group_size.
-    if group_size is greater than 0, then q/dq by groups. The groups
-    must be divisible by the column size
-    if group_size is -1, then channel wise q/dq. The input scale and
-    zero_points are reshaped to support vectorization (Assumes 1 is
-    the channel dimension)
+    Fake quantize the input tensor x by quantizing then dequantizing with
+    the QuantizationStrategy specified in args. Quantization can be done per tensor,
+    channel, token or group. For group quantization, the group_size must be divisible
+    by the column size. The input scale and zero_points are reshaped to support
+    vectorization (Assumes 1 is the channel dimension)

    :param x: Input tensor
    :param scale: scale tensor
    :param zero_point: zero point tensor
-    :param args: quantization args that contain group_size info
+    :param args: quantization args dictating how to quantize x
    :return: fake quantized tensor
-
    """
+    return _process_quantization(
+        x=x,
+        scale=scale,
+        zero_point=zero_point,
+        args=args,
+        do_quantize=True,
+        do_dequantize=True,
+    )
+
+
+@torch.no_grad()
+def _process_quantization(
+    x: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    args: QuantizationArgs,
+    dtype: Optional[torch.dtype] = None,
+    do_quantize: bool = True,
+    do_dequantize: bool = True,
+) -> torch.Tensor:
+    bit_range = 2**args.num_bits
+    q_max = torch.tensor(bit_range / 2 - 1, device=x.device)
+    q_min = torch.tensor(-bit_range / 2, device=x.device)
    group_size = args.group_size

    # group
    if args.strategy == QuantizationStrategy.GROUP:

-        DQ = torch.zeros_like(x)
+        if do_dequantize:  # if dequantizing, the output should be a float type
+            output = torch.zeros_like(x, dtype=scale.dtype)
+        else:
+            output_dtype = dtype if dtype is not None else x.dtype
+            output = torch.zeros_like(x, dtype=output_dtype)

        # TODO: vectorize the for loop
        # TODO: fix generic assumption about the tensor size for computing group
@@ -115,18 +183,24 @@ def fake_quantize(
            zp = zero_point[:, i].view(-1, 1)

            idx = i * group_size
-            Q = quantize(x[:, idx : (idx + group_size)], sc, zp, args)
-            DQ[:, idx : (idx + group_size)] = dequantize(Q, sc, zp)
+            if do_quantize:
+                output[:, idx : (idx + group_size)] = _quantize(
+                    x[:, idx : (idx + group_size)], sc, zp, q_min, q_max, dtype=dtype
+                )
+            if do_dequantize:
+                input = (
+                    output[:, idx : (idx + group_size)]
+                    if do_quantize
+                    else x[:, idx : (idx + group_size)]
+                )
+                output[:, idx : (idx + group_size)] = _dequantize(input, sc, zp)

    # channel-wise
    elif args.strategy == QuantizationStrategy.CHANNEL:  # group_size == -1
-        # before: scale shape = [channel_size]
-        # after: scale shape = [1, channel_size]
-        scale = scale.unsqueeze(0)
-        zero_point = zero_point.unsqueeze(0)
-
-        Q = quantize(x, scale, zero_point, args)
-        DQ = dequantize(Q, scale, zero_point)
+        if do_quantize:
+            output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype)
+        if do_dequantize:
+            output = _dequantize(output if do_quantize else x, scale, zero_point)

    # per-token
    elif args.strategy == QuantizationStrategy.TOKEN:
@@ -138,14 +212,18 @@ def fake_quantize(
        scale = scale.unsqueeze(1)
        zero_point = zero_point.unsqueeze(1)

-        Q = quantize(x, scale, zero_point, args)
-        DQ = dequantize(Q, scale, zero_point)
+        if do_quantize:
+            output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype)
+        if do_dequantize:
+            output = _dequantize(output if do_quantize else x, scale, zero_point)

    else:
-        Q = quantize(x, scale, zero_point, args)
-        DQ = dequantize(Q, scale, zero_point)
+        if do_quantize:
+            output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype)
+        if do_dequantize:
+            output = _dequantize(output if do_quantize else x, scale, zero_point)

-    return DQ
+    return output


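The group branch above visits the columns in group_size-wide slices; a toy version of just that indexing, with assumed shapes:

import torch

x = torch.randn(2, 8)
group_size = 4  # must evenly divide the number of columns

for i in range(x.shape[1] // group_size):
    idx = i * group_size
    group = x[:, idx : (idx + group_size)]  # columns [idx, idx + group_size), shape [2, 4]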
def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
@@ -223,3 +301,33 @@ def maybe_calibrate_or_quantize(
    scale.data = updated_scale.to(device)
    zero_point.data = updated_zero_point.to(device)
    return fake_quantize(value, scale, zero_point, args)
+
+
+@torch.no_grad()
+def _quantize(
+    x: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    q_min: torch.Tensor,
+    q_max: torch.Tensor,
+    dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+    quantized_value = torch.clamp(
+        torch.round(x / scale + zero_point),
+        q_min,
+        q_max,
+    )
+
+    if dtype is not None:
+        quantized_value = quantized_value.to(dtype)
+
+    return quantized_value
+
+
+@torch.no_grad()
+def _dequantize(
+    x_q: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+) -> torch.Tensor:
+    return (x_q - zero_point) * scale
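As a sanity check on the refactor, the new private helpers compose back into fake quantization exactly as fake_quantize wires them up; a hedged, self-contained sketch (symmetric 8-bit, made-up values):

import torch

def _quantize(x, scale, zero_point, q_min, q_max):
    return torch.clamp(torch.round(x / scale + zero_point), q_min, q_max)

def _dequantize(x_q, scale, zero_point):
    return (x_q - zero_point) * scale

x = torch.randn(3, 3)
scale, zero_point = torch.tensor(0.1), torch.tensor(0.0)
q_min, q_max = torch.tensor(-128.0), torch.tensor(127.0)

# fake quantize == quantize then dequantize
x_fq = _dequantize(_quantize(x, scale, zero_point, q_min, q_max), scale, zero_point)
assert torch.allclose(x_fq, x, atol=0.051)  # error bounded by scale / 2 for in-range values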