
Commit ac10ffd

Author: Sara Adkins
Merge branch 'main' into sa/fp8
2 parents: 944e27f + 14b1db1

File tree: 10 files changed (+833, −6 lines)

src/compressed_tensors/compressors/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -17,7 +17,8 @@
 from .base import Compressor
 from .dense import DenseCompressor
 from .helpers import load_compressed, save_compressed, save_compressed_model
-from .model_compressor import ModelCompressor
+from .marlin_24 import Marlin24Compressor
+from .model_compressor import ModelCompressor, map_modules_to_quant_args
 from .naive_quantized import (
     FloatQuantizationCompressor,
     IntQuantizationCompressor,
src/compressed_tensors/compressors/marlin_24.py (new file)

Lines changed: 250 additions & 0 deletions

@@ -0,0 +1,250 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Dict, Generator, Tuple

import numpy as np
import torch
from compressed_tensors.compressors import Compressor
from compressed_tensors.compressors.utils import (
    get_permutations_24,
    sparse_semi_structured_from_dense_cutlass,
    tensor_follows_mask_structure,
)
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
from compressed_tensors.quantization.lifecycle.forward import quantize
from compressed_tensors.utils import is_quantization_param, merge_names
from torch import Tensor
from tqdm import tqdm


_LOGGER: logging.Logger = logging.getLogger(__name__)

@Compressor.register(name=CompressionFormat.marlin_24.value)
class Marlin24Compressor(Compressor):
    """
    Compresses a quantized model with 2:4 sparsity structure for inference with the
    Marlin24 kernel. Decompression is not implemented for this compressor.
    """

    COMPRESSION_PARAM_NAMES = ["weight_packed", "scale_packed", "meta"]

    @staticmethod
    def validate_quant_compatability(
        model_quant_args: Dict[str, QuantizationArgs]
    ) -> bool:
        """
        Checks if every quantized module in the model is compatible with Marlin24
        compression. Quantization must use a channel or group strategy, and group
        quantization must use a group_size of 128. Only symmetric quantization is
        supported.

        :param model_quant_args: dictionary mapping module names to their
            quantization configuration
        :return: True if all modules are compatible with Marlin24 compression, raises
            a ValueError otherwise
        """
        for name, quant_args in model_quant_args.items():
            strategy = quant_args.strategy
            group_size = quant_args.group_size
            symmetric = quant_args.symmetric
            if (
                strategy is not QuantizationStrategy.GROUP
                and strategy is not QuantizationStrategy.CHANNEL
            ):
                raise ValueError(
                    f"Marlin24 Compressor is only valid for group and channel "
                    f"quantization strategies, got {strategy} in {name}"
                )

            if group_size is not None and group_size != 128:
                raise ValueError(
                    f"Marlin24 Compressor is only valid for group size 128, "
                    f"got {group_size} in {name}"
                )

            if not symmetric:
                raise ValueError(
                    f"Marlin24 Compressor is only valid for symmetric quantization, "
                    f"got symmetric={symmetric} in {name}"
                )

        return True

    @staticmethod
    def validate_sparsity_structure(name: str, weight: Tensor) -> bool:
        """
        Checks if a tensor fits the required 2:4 sparsity structure

        :param name: name of the tensor to check
        :param weight: tensor to check for sparsity structure
        :return: True if all rows match the 2:4 sparsity structure, raises
            ValueError otherwise
        """

        if not tensor_follows_mask_structure(weight):
            raise ValueError(
                "Marlin24 Compressor is only compatible with weights that have "
                f"a 2:4 sparsity structure. Found segments in {name} "
                "that do not match the expected structure."
            )

        return True

    def compress(
        self,
        model_state: Dict[str, Tensor],
        model_quant_args: Dict[str, QuantizationArgs],
        **kwargs,
    ) -> Dict[str, Tensor]:
        """
        Compresses a quantized state_dict with 2:4 sparsity structure for inference
        with the Marlin24 kernel

        :param model_state: state dict of the uncompressed model
        :param model_quant_args: quantization args for each quantized weight, needed
            by the quantize function to calculate bit depth
        :return: compressed state dict
        """
        self.validate_quant_compatability(model_quant_args)

        compressed_dict = {}
        weight_suffix = ".weight"
        _LOGGER.debug(
            f"Compressing model with {len(model_state)} parameterized layers..."
        )

        for name, value in tqdm(model_state.items(), desc="Compressing model"):
            if name.endswith(weight_suffix):
                prefix = name[: -(len(weight_suffix))]
                scale = model_state.get(merge_names(prefix, "weight_scale"), None)
                zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
                if scale is not None:  # weight is quantized, compress it

                    # Marlin24 kernel requires float16 inputs
                    scale = scale.to(torch.float16)
                    value = value.to(torch.float16)

                    # quantize weight, keeping it as a float16 for now
                    quant_args = model_quant_args[prefix]
                    value = quantize(
                        x=value, scale=scale, zero_point=zp, args=quant_args
                    )

                    # compress based on sparsity structure
                    self.validate_sparsity_structure(prefix, value)
                    value, meta = compress_weight_24(value)
                    meta = meta.cpu()

                    # Marlin24 kernel expects the input dim first
                    value = value.t().contiguous().cpu()
                    scale = scale.t().contiguous().cpu()
                    og_weight_shape = value.shape

                    # Marlin24 kernel expects unsigned values, shift the zero-point
                    value += (1 << quant_args.num_bits) // 2

                    # pack quantized weight and scale
                    value = pack_weight_24(value, quant_args)
                    packed_scale = pack_scales_24(scale, quant_args, og_weight_shape)
                    meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)

                    # save compressed values
                    compressed_dict[merge_names(prefix, "scale_packed")] = packed_scale
                    compressed_dict[merge_names(prefix, "weight_packed")] = value
                    compressed_dict[merge_names(prefix, "meta")] = meta
                    continue

            if not is_quantization_param(name):
                # export unquantized parameters without modification
                compressed_dict[name] = value.to("cpu")

        return compressed_dict

    def decompress(
        self, path_to_model_or_tensors: str, device: str = "cpu"
    ) -> Generator[Tuple[str, Tensor], None, None]:
        raise NotImplementedError(
            "Decompression is not implemented for the Marlin24 Compressor."
        )

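The unsigned shift inside compress() above is easiest to follow with concrete numbers. A small illustrative check (not part of this commit) for the 4-bit case:

import torch

num_bits = 4
shift = (1 << num_bits) // 2  # 8 for 4-bit values
q = torch.tensor([-8.0, -1.0, 0.0, 7.0], dtype=torch.float16)  # symmetric signed range
print(q + shift)  # tensor([ 0.,  7.,  8., 15.]) -- unsigned, ready for packing
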
def compress_weight_24(weight: Tensor):
    # convert a dense 2:4-sparse tensor into its compressed values plus the
    # metadata tensor encoding the positions of the kept elements
    weight = weight.contiguous()
    w_comp, meta = sparse_semi_structured_from_dense_cutlass(weight)
    w_comp = w_comp.contiguous()
    return w_comp, meta

def marlin_permute_weights(q_w, size_k, size_n, perm, tile):
    assert q_w.shape == (size_k, size_n)
    assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
    assert size_n % tile == 0, f"size_n = {size_n}, tile = {tile}"

    # Permute weights to 16x64 marlin tiles
    q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
    q_w = q_w.permute((0, 2, 1, 3))
    q_w = q_w.reshape((size_k // tile, size_n * tile))

    q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)

    return q_w

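The tile reshuffle above can be traced in isolation. This illustrative snippet (not part of the commit) follows the shapes for a 64x128 input, omitting the final index permutation:

import torch

size_k, size_n, tile = 64, 128, 16  # toy sizes, both divisible by the tile
q_w = torch.arange(size_k * size_n).reshape(size_k, size_n)

# group into 16x16 tiles, then flatten each strip of tiles row-major
q_w = q_w.reshape(size_k // tile, tile, size_n // tile, tile)
q_w = q_w.permute(0, 2, 1, 3)
q_w = q_w.reshape(size_k // tile, size_n * tile)
print(q_w.shape)  # torch.Size([4, 2048])
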
def pack_weight_24(
    weight: Tensor,
    quantization_args: QuantizationArgs,
    tile: int = 16,
):
    size_k = weight.shape[0]
    size_n = weight.shape[1]
    num_bits = quantization_args.num_bits
    pack_factor = 32 // num_bits

    # Reshuffle to marlin_24 format
    perm, _, _ = get_permutations_24(num_bits)
    q_w = marlin_permute_weights(weight, size_k, size_n, perm, tile)

    q_w = q_w.cpu().numpy().astype(np.uint32)

    q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32)
    for i in range(pack_factor):
        q_packed |= q_w[:, i::pack_factor] << num_bits * i

    q_packed = torch.from_numpy(q_packed.astype(np.int32))

    return q_packed

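The packing loop above can be checked standalone. A small illustrative example (toy values, not from the commit) packing eight unsigned 4-bit values into each uint32, lowest nibble first:

import numpy as np

num_bits = 4
pack_factor = 32 // num_bits  # 8 values per uint32

q_w = np.arange(16, dtype=np.uint32).reshape(2, 8)
q_packed = np.zeros((2, 1), dtype=np.uint32)
for i in range(pack_factor):
    q_packed |= q_w[:, i::pack_factor] << num_bits * i

assert q_packed[0, 0] == 0x76543210  # values 0..7: value j lands in nibble j
assert q_packed[1, 0] == 0xFEDCBA98  # values 8..15
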
def pack_scales_24(scales, quantization_args, w_shape):
    size_k = w_shape[0]
    size_n = w_shape[1]
    num_bits = quantization_args.num_bits

    _, scale_perm_2_4, scale_perm_single_2_4 = get_permutations_24(num_bits)

    if (
        quantization_args.strategy is QuantizationStrategy.GROUP
        and quantization_args.group_size < size_k
    ):
        scales = scales.reshape((-1, len(scale_perm_2_4)))[:, scale_perm_2_4]
    else:  # channelwise
        scales = scales.reshape((-1, len(scale_perm_single_2_4)))[
            :, scale_perm_single_2_4
        ]
    scales = scales.reshape((-1, size_n)).contiguous()

    return scales
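
For orientation, a minimal, untested sketch (not part of this commit) of driving the new compressor directly on a toy state dict; the module name "layer", the shapes, and the QuantizationArgs settings are illustrative assumptions:

import torch
from compressed_tensors.compressors import Marlin24Compressor
from compressed_tensors.quantization import QuantizationArgs

weight = torch.randn(128, 128, dtype=torch.float16)
mask = torch.tensor([1, 1, 0, 0], dtype=torch.float16).repeat(128, 32)
state_dict = {
    "layer.weight": weight * mask,  # two zeros in every chunk of four
    "layer.weight_scale": torch.full((128, 1), 0.02, dtype=torch.float16),
    "layer.weight_zero_point": torch.zeros(128, 1, dtype=torch.int8),
}
quant_args = {"layer": QuantizationArgs(num_bits=4, symmetric=True, strategy="channel")}

compressed = Marlin24Compressor().compress(state_dict, model_quant_args=quant_args)
print(sorted(compressed))  # ['layer.meta', 'layer.scale_packed', 'layer.weight_packed']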

src/compressed_tensors/compressors/model_compressor.py

Lines changed: 3 additions & 3 deletions
@@ -48,7 +48,7 @@
 from transformers.file_utils import CONFIG_NAME


-__all__ = ["ModelCompressor"]
+__all__ = ["ModelCompressor", "map_modules_to_quant_args"]

 _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -193,7 +193,7 @@ def compress(
         state_dict = model.state_dict()

         compressed_state_dict = state_dict
-        quantized_modules_to_args = _get_weight_arg_mappings(model)
+        quantized_modules_to_args = map_modules_to_quant_args(model)
         if self.quantization_compressor is not None:
             compressed_state_dict = self.quantization_compressor.compress(
                 state_dict, model_quant_args=quantized_modules_to_args
@@ -277,7 +277,7 @@ def _replace_weights(self, dense_weight_generator, model):
             data_old.data = data_new.data


-def _get_weight_arg_mappings(model: Module) -> Dict:
+def map_modules_to_quant_args(model: Module) -> Dict:
     quantized_modules_to_args = {}
     for name, submodule in iter_named_leaf_modules(model):
         if is_module_quantized(submodule):
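
Since map_modules_to_quant_args is now exported, it can be called outside of ModelCompressor. A hedged sketch (not from this commit; `model` and `config` are assumed to already exist, and the model is assumed to have been quantized via the library's apply_quantization_config):

from compressed_tensors.compressors import map_modules_to_quant_args
from compressed_tensors.quantization import apply_quantization_config

apply_quantization_config(model, config)  # attach quantization schemes to modules
for module_name, args in map_modules_to_quant_args(model).items():
    print(module_name, args.num_bits, args.strategy)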
src/compressed_tensors/compressors/utils/__init__.py (new file)

Lines changed: 19 additions & 0 deletions

@@ -0,0 +1,19 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa

from .helpers import *
from .permutations_24 import *
from .semi_structured_conversions import *
src/compressed_tensors/compressors/utils/helpers.py (new file)

Lines changed: 43 additions & 0 deletions

@@ -0,0 +1,43 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch


__all__ = ["tensor_follows_mask_structure"]


def tensor_follows_mask_structure(tensor, mask: str = "2:4") -> bool:
    """
    :param tensor: tensor to check
    :param mask: mask structure to check for, in the format "n:m"
    :return: True if the tensor follows the mask structure, raises a ValueError
        otherwise. Note that some weights can incidentally be zero, so we check
        for at least n zeros in each chunk of size m
    """

    n, m = tuple(map(int, mask.split(":")))
    # Reshape the tensor into chunks of size m
    tensor = tensor.view(-1, m)

    # Count the number of zeros in each chunk
    zero_counts = (tensor == 0).sum(dim=1)

    # Check that each chunk has at least n zeros; a greater-than-or-equal check
    # is needed since some weights can incidentally be zero
    if not torch.all(zero_counts >= n).item():
        raise ValueError(f"Tensor does not match the {mask} mask structure")

    return True
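
A quick illustrative check of the helper (toy values, not from the commit); note that on failure it raises rather than returning False:

import torch
from compressed_tensors.compressors.utils import tensor_follows_mask_structure

ok = torch.tensor([[0.0, 0.0, 1.5, -2.0], [3.0, 0.0, 0.0, 0.0]])
print(tensor_follows_mask_structure(ok))  # True: each chunk of 4 has >= 2 zeros

try:
    tensor_follows_mask_structure(torch.ones(1, 4))
except ValueError:
    print("not 2:4 sparse")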
