@@ -94,11 +94,17 @@ def dequantize(
     if scale.ndim == 0:
         args = QuantizationArgs(strategy=QuantizationStrategy.TENSOR)
     elif scale.ndim == 2:
-        args = QuantizationArgs(strategy=QuantizationStrategy.CHANNEL)
-    elif scale.ndim == 3:
-        group_size = int(x_q.shape[1] / scale.shape[1])
-        args = QuantizationArgs(
-            strategy=QuantizationStrategy.GROUP, group_size=group_size
+        if scale.shape[1] == 1:
+            args = QuantizationArgs(strategy=QuantizationStrategy.CHANNEL)
+        else:
+            group_size = int(x_q.shape[1] / scale.shape[1])
+            args = QuantizationArgs(
+                strategy=QuantizationStrategy.GROUP, group_size=group_size
+            )
+    else:
+        raise ValueError(
+            f"Could not infer a quantization strategy from scale with {scale.ndim} "
+            "dimensions. Expected 0 or 2 dimensions."
         )
     return _process_quantization(
         x=x_q,
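The new branch infers the strategy purely from the shape of `scale` relative to `x_q`. A minimal sketch of the shapes involved (illustrative values only, not taken from the library):

```python
import torch

# Illustrative shapes assumed by the inference above:
#   TENSOR : scale is a 0-dim scalar
#   CHANNEL: scale has shape (num_rows, 1)
#   GROUP  : scale has shape (num_rows, num_groups)
x_q = torch.randint(-128, 127, (64, 512), dtype=torch.int8)

tensor_scale = torch.tensor(0.02)   # ndim == 0                -> TENSOR
channel_scale = torch.rand(64, 1)   # ndim == 2, shape[1] == 1 -> CHANNEL
group_scale = torch.rand(64, 4)     # ndim == 2, shape[1] > 1  -> GROUP

# group_size is recovered from the ratio of quantized columns to scale groups
group_size = int(x_q.shape[1] / group_scale.shape[1])  # 512 / 4 == 128
```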
@@ -155,14 +161,12 @@ def _process_quantization(
     group_size = args.group_size
 
     if args.strategy == QuantizationStrategy.GROUP:
-
-        if do_dequantize:
-            # if dequantizing the output should match the original weight dtype,
-            # which is the same as the scale's
+        if do_dequantize and not do_quantize:
+            # if dequantizing a quantized type infer the output type from the scale
             output = torch.zeros_like(x, dtype=scale.dtype)
         else:
-            # outputting a quantized output, use the dtype passed in as a kwarg if its
-            # specified, otherwise default to the input type
+            # use the dtype passed in as a kwarg if its specified, otherwise default
+            # to the input type
             output_dtype = dtype if dtype is not None else x.dtype
             if output_dtype is FP8_DTYPE:
                 # zeros_like doesn't support fp8 types directly, workaround
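On the non-dequantize path, `output_dtype` falls back to the input dtype, and fp8 outputs need special handling because `zeros_like` cannot allocate fp8 tensors directly on all torch versions. A hedged sketch of one possible workaround; `FP8_DTYPE` is assumed here to map to `torch.float8_e4m3fn`, and the helper name is hypothetical:

```python
import torch

FP8_DTYPE = torch.float8_e4m3fn  # assumption: the module constant refers to this dtype

def _allocate_output(x: torch.Tensor, output_dtype: torch.dtype) -> torch.Tensor:
    # Hypothetical helper mirroring the allocation described above
    if output_dtype is FP8_DTYPE:
        # allocate in a widely supported dtype first, then cast to fp8
        return torch.zeros_like(x, dtype=torch.float32).to(FP8_DTYPE)
    return torch.zeros_like(x, dtype=output_dtype)
```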