@@ -48,18 +48,33 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                                       requires_grad=False)
         # INPUT SCALE
         if self.is_static_input_scheme:
-            layer.input_scale = Parameter(layer.input_scale.max(),
-                                          requires_grad=False)
-            if not self.input_symmetric:
-                layer.input_zero_point = Parameter(layer.input_zero_point,
-                                                   requires_grad=False)
+            if self.input_symmetric:
+                layer.input_scale = Parameter(layer.input_scale.max(),
+                                              requires_grad=False)
             else:
-                layer.input_zero_point = None
+                raise NotImplementedError(
+                    "static input asymmetric quantization not supported yet")
+                # reconstruct the ranges
+                int8_traits = torch.iinfo(torch.int8)
+                range_max = (layer.input_scale *
+                             (int8_traits.max - layer.input_zero_point)).max()
+                range_min = (layer.input_scale *
+                             (int8_traits.min - layer.input_zero_point)).min()
+
+                scale = (range_max - range_min) / (int8_traits.max -
+                                                   int8_traits.min)
+                layer.input_scale = Parameter(scale, requires_grad=False)
+
+                azp = int8_traits.min - range_min / scale
+                layer.input_zero_point = Parameter(azp, requires_grad=False)
+
         else:
             layer.input_scale = None
             layer.input_zero_point = None

         if not self.input_symmetric:
+            # azp_adj is the AZP adjustment term, used to account for weights.
+            # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
             layer.azp_adj = layer.weight.sum(dim=0,
                                              keepdim=True,
                                              dtype=torch.int32)
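For context on the arithmetic added in the asymmetric branch: each loaded (scale, zero_point) pair describes a real-valued input range, and the new code collapses those ranges into a single per-tensor scale and asymmetric zero point (AZP). The sketch below reproduces that calculation with made-up tensors; it reuses the attribute names from the diff as ordinary locals and is illustrative only, not vLLM code.

```python
# Hedged sketch (made-up tensors, not vLLM code): reproduce the range
# reconstruction from the asymmetric branch above with plain PyTorch.
import torch

int8_traits = torch.iinfo(torch.int8)

# Pretend these were loaded from a compressed-tensors checkpoint: static
# per-channel input scales with asymmetric zero points.
input_scale = torch.tensor([0.020, 0.050, 0.030])
input_zero_point = torch.tensor([10, -4, 7], dtype=torch.int32)

# Each (scale, zero_point) pair represents the real-valued interval
# [scale * (qmin - zp), scale * (qmax - zp)]; take the widest envelope.
range_max = (input_scale * (int8_traits.max - input_zero_point)).max()
range_min = (input_scale * (int8_traits.min - input_zero_point)).min()

# Collapse to a single per-tensor scale and AZP, mirroring the additions
# in the diff above.
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
azp = int8_traits.min - range_min / scale

print(scale.item(), azp.item())  # one scale and one AZP covering all channels
```

The azp_adj column sums added at the end of the hunk feed the CUTLASS epilogue described in csrc/quantization/cutlass_w8a8/Epilogues.md; a numeric check of that identity follows the next hunk.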
@@ -108,7 +123,7 @@ def create_weights(self, layer: torch.nn.Module,
         if not self.input_symmetric:
             raise NotImplementedError(
                 "static input asymmetric quantization not supported yet")
-            input_zero_point = Parameter(torch.zeros(1, dtype=torch.int8))
+            input_zero_point = Parameter(torch.zeros(1, dtype=torch.int32))
             layer.register_parameter("input_zero_point", input_zero_point)

     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
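The dtype change in this hunk allocates the static input zero point as int32 rather than int8. As an illustrative, self-contained check (synthetic data, not vLLM kernel code), the snippet below quantizes activations with one static scale and an int32 AZP and verifies the identity (X_q - azp) @ W == X_q @ W - azp * W.sum(dim=0), which is what the azp_adj term supplies to the epilogue.

```python
# Hedged sketch (synthetic data, not vLLM kernel code): quantize activations
# with one static scale and an int32 AZP, then check the identity that lets
# the epilogue use azp_adj = W.sum(dim=0) instead of the full zero-point math.
import torch

torch.manual_seed(0)
int8_traits = torch.iinfo(torch.int8)

x = torch.randn(4, 16)                                   # float activations
w = torch.randint(-128, 128, (16, 8), dtype=torch.int8)  # int8 weight

# Static asymmetric quantization of x: one scale, one int32 zero point (AZP).
scale = (x.max() - x.min()) / (int8_traits.max - int8_traits.min)
azp = (int8_traits.min - x.min() / scale).round().to(torch.int32)
x_q = ((x / scale).round() + azp).clamp(int8_traits.min,
                                        int8_traits.max).to(torch.int32)

# Reference: subtract the zero point from the activations, then GEMM.
ref = (x_q - azp).to(torch.int64) @ w.to(torch.int64)

# Kernel-style: GEMM on the raw quantized activations, then correct in the
# epilogue with azp * column-sums of the weight (the azp_adj term above).
azp_adj = w.sum(dim=0, keepdim=True, dtype=torch.int32)
out = x_q.to(torch.int64) @ w.to(torch.int64) - azp.to(torch.int64) * azp_adj

assert torch.equal(ref, out)  # the two integer accumulators match exactly
```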