@@ -33,8 +33,6 @@ def can_implement(cls,
33
33
return False , "Act reordering currently not supported by Machete, " \
34
34
"when the input features are partitioned across " \
35
35
"devices"
36
- if c .zero_points :
37
- return False , "Zero points currently not supported by Machete"
38
36
39
37
if c .weight_type not in query_machete_supported_quant_types (
40
38
c .zero_points ):
@@ -53,6 +51,7 @@ def can_implement(cls,
53
51
# NOTE: assumes the following parameter layouts:
54
52
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
55
53
# `weight_scale` is: {input_dim = 0, output_dim = 1}
54
+ # `weight_zp` is: {input_dim = 0, output_dim = 1, packed_dim = 1}
56
55
def process_weights_after_loading (self , layer : torch .nn .Module ):
57
56
c = self .config
58
57
@@ -90,16 +89,29 @@ def transform_w_s(x):
90
89
x .data = x .data .contiguous ()
91
90
return x
92
91
92
def transform_w_zp(x):
    """Repack the zero-point parameter ``x`` in place and return it.

    The parameter layout is permuted to
    ``{input_dim = 0, output_dim = 1, packed_dim = 1}``, the packed values
    are expanded along dim 1 with ``unpack_quantized_values_into_int32``,
    and ``x.data`` is then replaced by ``-1.0 * scale * zero_point``
    computed in the dtype of the layer's scale tensor.

    Relies on ``c``, ``layer`` and ``self`` from the enclosing scope.
    """
    assert isinstance(x, BasevLLMParameter)
    permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=1)

    zero_points = unpack_quantized_values_into_int32(
        x.data, c.weight_type, packed_dim=1)
    scales = getattr(layer, self.w_s_name).data

    # pre-apply scales to zero-points
    scaled_zero_points = -1.0 * scales * (zero_points.to(scales.dtype))
    x.data = scaled_zero_points.contiguous()
    return x
102
+
93
103
# Repack weights and scales for Machete
94
104
self ._transform_param (layer , self .w_q_name , transform_w_q )
95
105
self ._transform_param (layer , self .w_s_name , transform_w_s )
106
+ if c .zero_points :
107
+ self ._transform_param (layer , self .w_zp_name , transform_w_zp )
96
108
97
109
def apply_weights (self ,
98
110
layer : torch .nn .Module ,
99
111
x : torch .Tensor ,
100
112
bias : Optional [torch .Tensor ] = None ) -> torch .Tensor :
101
113
c = self .config
102
- w_q , w_s , _ , _ = self ._get_weight_params (layer )
114
+ w_q , w_s , w_zp , _ = self ._get_weight_params (layer )
103
115
104
116
x_2d = x .reshape (- 1 , x .shape [- 1 ])
105
117
out_shape = x .shape [:- 1 ] + (c .partition_weight_shape [1 ], )
@@ -110,7 +122,7 @@ def apply_weights(self,
110
122
output = ops .machete_mm (a = x_2d ,
111
123
b_q = w_q ,
112
124
b_type = c .weight_type ,
113
- b_group_zeros = None ,
125
+ b_group_zeros = w_zp ,
114
126
b_group_scales = w_s ,
115
127
b_group_size = c .group_size )
116
128
0 commit comments