Commit 9909726

Enable ZP Support for Machete (#20268)
Signed-off-by: czhu-cohere <[email protected]>
1 parent 22e9d42 commit 9909726

File tree

3 files changed: +19 −5 lines changed


benchmarks/kernels/benchmark_machete.py

Lines changed: 2 additions & 0 deletions

@@ -234,8 +234,10 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
 
     fn = lambda: ops.gptq_marlin_gemm(
         a=bt.a,
+        c=None,
         b_q_weight=w_q,
         b_scales=w_s,
+        global_scale=None,
         b_zeros=w_zp,
         g_idx=g_idx,
         perm=sort_indices,

tests/kernels/quantization/test_machete_mm.py

Lines changed: 1 addition & 1 deletion

@@ -139,7 +139,7 @@ def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor):
 
 def group_size_valid(shape: tuple[int, int, int],
                      group_size: Optional[int]) -> bool:
-    return group_size is None or group_size == -1 or group_size % shape[2] == 0
+    return group_size is None or group_size == -1 or shape[2] % group_size == 0
 
 
 def machete_quantize_and_pack(atype: torch.dtype,
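The fix reverses the direction of the divisibility test: the reduction dimension K (shape[2], assuming the test's shape tuples are (M, N, K)) must split evenly into groups of group_size, not the group size into K. A minimal standalone sketch of the corrected predicate:

from typing import Optional

def group_size_valid(shape: tuple[int, int, int],
                     group_size: Optional[int]) -> bool:
    # None / -1 mean per-channel scales (no grouping); otherwise K must be
    # divisible by group_size so every quantization group is full.
    return group_size is None or group_size == -1 or shape[2] % group_size == 0

# (M, N, K) = (16, 4096, 4096) with 128-wide groups: 4096 % 128 == 0 -> valid.
assert group_size_valid((16, 4096, 4096), 128)
# The old check computed 128 % 4096 (which is != 0) and wrongly rejected this case.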

vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py

Lines changed: 16 additions & 4 deletions
@@ -33,8 +33,6 @@ def can_implement(cls,
             return False, "Act reordering currently not supported by Machete, "\
                           "when the input features are partitioned across "\
                           "devices"
-        if c.zero_points:
-            return False, "Zero points currently not supported by Machete"
 
         if c.weight_type not in query_machete_supported_quant_types(
                 c.zero_points):
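With the blanket rejection removed, zero-point configurations are now gated only by the existing query_machete_supported_quant_types(c.zero_points) check that follows it. For context, that helper distinguishes asymmetric types (explicit zero-points) from the biased symmetric types; the sketch below is a rough paraphrase of vLLM's machete utils, not the verbatim source:

from vllm.scalar_type import scalar_types

def query_machete_supported_quant_types(zero_points: bool):
    if zero_points:
        # asymmetric quantization: plain unsigned types, zero-point supplied
        return [scalar_types.uint4, scalar_types.uint8]
    else:
        # symmetric quantization: types with a built-in bias (e.g. uint4b8 stores q - 8)
        return [scalar_types.uint4b8, scalar_types.uint8b128]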
@@ -53,6 +51,7 @@ def can_implement(cls,
     # note assumes that
     #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
     #  `weight_scale` is: {input_dim = 0, output_dim = 1}
+    #  `weight_zp` is: {input_dim = 0, output_dim = 1, packed_dim = 1}
     def process_weights_after_loading(self, layer: torch.nn.Module):
         c = self.config
 
@@ -90,16 +89,29 @@ def transform_w_s(x):
             x.data = x.data.contiguous()
             return x
 
+        def transform_w_zp(x):
+            assert isinstance(x, BasevLLMParameter)
+            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=1)
+            x_unpacked = unpack_quantized_values_into_int32(x.data,
+                                                            c.weight_type,
+                                                            packed_dim=1)
+            w_s = getattr(layer, self.w_s_name).data
+            # pre-apply scales to zero-points
+            x.data = (-1.0 * w_s * (x_unpacked.to(w_s.dtype))).contiguous()
+            return x
+
         # Repack weights and scales for Machete
         self._transform_param(layer, self.w_q_name, transform_w_q)
         self._transform_param(layer, self.w_s_name, transform_w_s)
+        if c.zero_points:
+            self._transform_param(layer, self.w_zp_name, transform_w_zp)
 
     def apply_weights(self,
                       layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         c = self.config
-        w_q, w_s, _, _ = self._get_weight_params(layer)
+        w_q, w_s, w_zp, _ = self._get_weight_params(layer)
 
         x_2d = x.reshape(-1, x.shape[-1])
         out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
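transform_w_zp folds the scales into the zero-points ahead of time: dequantization is w = s * (q - z) = s * q + (-s * z), so storing -s * z lets the kernel apply zero-points as a purely additive term after scaling. A quick numeric check of that algebra in plain PyTorch (illustrative shapes, weights left unpacked for clarity):

import torch

N, K, group_size = 8, 64, 16
q = torch.randint(0, 16, (K, N)).float()                # unpacked uint4 weights
z = torch.randint(0, 16, (K // group_size, N)).float()  # per-group zero-points
s = torch.rand(K // group_size, N)                      # per-group scales

expand = lambda t: t.repeat_interleave(group_size, dim=0)  # broadcast groups over K

w_ref = expand(s) * (q - expand(z))       # reference dequantization: s * (q - z)

zp_prescaled = -1.0 * s * z               # what transform_w_zp stores
w_machete = expand(s) * q + expand(zp_prescaled)  # kernel view: s * q + (-s * z)

torch.testing.assert_close(w_ref, w_machete)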
@@ -110,7 +122,7 @@ def apply_weights(self,
         output = ops.machete_mm(a=x_2d,
                                 b_q=w_q,
                                 b_type=c.weight_type,
-                                b_group_zeros=None,
+                                b_group_zeros=w_zp,
                                 b_group_scales=w_s,
                                 b_group_size=c.group_size)
