
Commit 18dad65

Author: Sara Adkins
Merge branch 'main' into sa/fp8
2 parents ac10ffd + 5896adc

11 files changed: +101 -29 lines

src/compressed_tensors/compressors/marlin_24.py

Lines changed: 2 additions & 2 deletions

@@ -62,8 +62,8 @@ def validate_quant_compatability(
         group_size = quant_args.group_size
         symmetric = quant_args.symmetric
         if (
-            strategy is not QuantizationStrategy.GROUP
-            and strategy is not QuantizationStrategy.CHANNEL
+            strategy is not QuantizationStrategy.GROUP.value
+            and strategy is not QuantizationStrategy.CHANNEL.value
         ):
             raise ValueError(
                 f"Marlin24 Compressor is only valid for group and channel "

src/compressed_tensors/compressors/model_compressor.py

Lines changed: 39 additions & 4 deletions

@@ -18,7 +18,7 @@
 import os
 import re
 from copy import deepcopy
-from typing import Dict, Optional, Union
+from typing import Any, Dict, Optional, Union
 
 import torch
 import transformers
@@ -91,20 +91,41 @@ def from_pretrained(
         """
         config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
         compression_config = getattr(config, COMPRESSION_CONFIG_NAME, None)
+        return cls.from_compression_config(compression_config)
+
+    @classmethod
+    def from_compression_config(cls, compression_config: Dict[str, Any]):
+        """
+        :param compression_config: compression/quantization config dictionary
+            found under key "quantization_config" in HF model config
+        :return: compressor for the extracted configs
+        """
         if compression_config is None:
             return None
 
+        try:
+            from transformers.utils.quantization_config import CompressedTensorsConfig
+
+            if isinstance(compression_config, CompressedTensorsConfig):
+                compression_config = compression_config.to_dict()
+        except ImportError:
+            pass
+
         sparsity_config = cls.parse_sparsity_config(compression_config)
         quantization_config = cls.parse_quantization_config(compression_config)
         if sparsity_config is None and quantization_config is None:
             return None
 
-        if sparsity_config is not None:
+        if sparsity_config is not None and not isinstance(
+            sparsity_config, SparsityCompressionConfig
+        ):
             format = sparsity_config.get("format")
             sparsity_config = SparsityCompressionConfig.load_from_registry(
                 format, **sparsity_config
             )
-        if quantization_config is not None:
+        if quantization_config is not None and not isinstance(
+            quantization_config, QuantizationConfig
+        ):
             quantization_config = QuantizationConfig.parse_obj(quantization_config)
 
         return cls(
@@ -149,15 +170,29 @@ def from_pretrained_model(
     def parse_sparsity_config(compression_config: Dict) -> Union[Dict, None]:
         if compression_config is None:
             return None
+        if SPARSITY_CONFIG_NAME not in compression_config:
+            return None
+        if hasattr(compression_config, SPARSITY_CONFIG_NAME):
+            # for loaded HFQuantizer config
+            return getattr(compression_config, SPARSITY_CONFIG_NAME)
+
+        # SparseAutoModel format
         return compression_config.get(SPARSITY_CONFIG_NAME, None)
 
     @staticmethod
     def parse_quantization_config(compression_config: Dict) -> Union[Dict, None]:
+        if compression_config is None:
+            return None
+
+        if hasattr(compression_config, QUANTIZATION_CONFIG_NAME):
+            # for loaded HFQuantizer config
+            return getattr(compression_config, QUANTIZATION_CONFIG_NAME)
+
+        # SparseAutoModel format
         quantization_config = deepcopy(compression_config)
         quantization_config.pop(SPARSITY_CONFIG_NAME, None)
         if len(quantization_config) == 0:
            quantization_config = None
-
         return quantization_config
 
     def __init__(
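A self-contained sketch (mock objects, not the real classes) of the dual-format parsing introduced above: the compression config may now arrive as a plain dict (SparseAutoModel format) or as a loaded HFQuantizer config object that exposes the sub-config as an attribute:

SPARSITY_CONFIG_NAME = "sparsity_config"


class FakeHFQuantizerConfig:
    # stand-in for a loaded transformers CompressedTensorsConfig
    sparsity_config = {"format": "sparse-bitmask"}


def parse_sparsity_config(compression_config):
    if compression_config is None:
        return None
    if hasattr(compression_config, SPARSITY_CONFIG_NAME):
        return getattr(compression_config, SPARSITY_CONFIG_NAME)  # attribute path
    return compression_config.get(SPARSITY_CONFIG_NAME, None)  # dict path


print(parse_sparsity_config(FakeHFQuantizerConfig()))  # {'format': 'sparse-bitmask'}
print(parse_sparsity_config({"sparsity_config": {"format": "dense"}}))  # {'format': 'dense'}
print(parse_sparsity_config(None))  # None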

src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 12 additions & 9 deletions

@@ -123,11 +123,14 @@ def apply_quantization_config(model: Module, config: QuantizationConfig):
         if target is not None:
             # target matched - add layer and scheme to target list
             submodule.quantization_scheme = target_to_scheme[target]
-    if set(config.ignore) - set(ignored_submodules):
-        _LOGGER.warning(
-            "Some layers that were to be ignored were "
-            f"not found in the model: {set(config.ignore) - set(ignored_submodules)}"
-        )
+
+    if config.ignore is not None and ignored_submodules is not None:
+        if set(config.ignore) - set(ignored_submodules):
+            _LOGGER.warning(
+                "Some layers that were to be ignored were "
+                "not found in the model: "
+                f"{set(config.ignore) - set(ignored_submodules)}"
+            )
     # apply current quantization status across all targeted layers
     apply_quantization_status(model, config.quantization_status)
 
@@ -146,7 +149,6 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
 
     if current_status < status >= QuantizationStatus.CALIBRATION > current_status:
         model.apply(set_module_for_calibration)
-
     if current_status < status >= QuantizationStatus.FROZEN > current_status:
         model.apply(freeze_module_quantization)
 
@@ -160,9 +162,10 @@ def find_first_name_or_class_match(
     # first element of targets that matches the given name
     # if no name matches returns first target that matches the class name
     # returns None otherwise
-    return _find_first_match(name, targets) or _find_first_match(
-        module.__class__.__name__, targets, check_contains
-    )
+    if isinstance(targets, Iterable):
+        return _find_first_match(name, targets) or _find_first_match(
+            module.__class__.__name__, targets, check_contains
+        )
 
 
 def _find_first_match(
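The new None checks avoid a TypeError from set(None) when a config has no ignore list or no ignored submodules were collected. A small illustration with made-up layer names:

config_ignore = ["lm_head", "model.embed_tokens"]  # layers the config asked to skip
ignored_submodules = ["lm_head"]  # layers actually found and skipped in the model

# set(None) would raise TypeError, hence the None checks in the diff above
if config_ignore is not None and ignored_submodules is not None:
    missing = set(config_ignore) - set(ignored_submodules)
    if missing:
        print(f"Some layers that were to be ignored were not found in the model: {missing}")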

src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 1 addition & 1 deletion

@@ -98,7 +98,7 @@ def dequantize(
     :return: dequantized float tensor
     """
     if args is None:
-        if scale.ndim == 0:
+        if scale.ndim == 0 or scale.ndim == 1:
             args = QuantizationArgs(strategy=QuantizationStrategy.TENSOR)
         elif scale.ndim == 2:
             if scale.shape[1] == 1:
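With this change a 1-D scale also takes the per-tensor path, matching the (1,)-shaped qparams that calculate_qparams in helpers.py now returns (see below). A sketch of the ndim dispatch (the helper name is illustrative, not library API):

import torch


def infer_strategy(scale: torch.Tensor) -> str:
    # mirrors the ndim dispatch in dequantize() above
    if scale.ndim == 0 or scale.ndim == 1:
        return "tensor"
    if scale.ndim == 2:
        return "channel" if scale.shape[1] == 1 else "group"
    raise ValueError(f"unsupported scale shape {tuple(scale.shape)}")


print(infer_strategy(torch.tensor(0.1)))    # tensor (0-dim, old behavior)
print(infer_strategy(torch.tensor([0.1])))  # tensor (1-dim, newly accepted)
print(infer_strategy(torch.ones(64, 1)))    # channel
print(infer_strategy(torch.ones(64, 2)))    # group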

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 31 additions & 5 deletions

@@ -20,7 +20,10 @@
 from compressed_tensors.quantization.lifecycle.forward import (
     wrap_module_forward_quantized,
 )
-from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_args import (
+    QuantizationArgs,
+    QuantizationStrategy,
+)
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from torch.nn import Module, Parameter
@@ -58,7 +61,12 @@ def initialize_module_for_quantization(
         _initialize_scale_zero_point_observer(module, "input", scheme.input_activations)
     if scheme.weights is not None:
         if hasattr(module, "weight"):
-            _initialize_scale_zero_point_observer(module, "weight", scheme.weights)
+            weight_shape = None
+            if isinstance(module, torch.nn.Linear):
+                weight_shape = module.weight.shape
+            _initialize_scale_zero_point_observer(
+                module, "weight", scheme.weights, weight_shape=weight_shape
+            )
         else:
             _LOGGER.warning(
                 f"module type {type(module)} targeted for weight quantization but "
@@ -78,7 +86,10 @@ def initialize_module_for_quantization(
 
 
 def _initialize_scale_zero_point_observer(
-    module: Module, base_name: str, quantization_args: QuantizationArgs
+    module: Module,
+    base_name: str,
+    quantization_args: QuantizationArgs,
+    weight_shape: Optional[torch.Size] = None,
 ):
     # initialize observer module and attach as submodule
     observer = quantization_args.get_observer()
@@ -89,14 +100,29 @@ def _initialize_scale_zero_point_observer(
 
     device = next(module.parameters()).device
 
+    # infer expected scale/zero point shape
+    expected_shape = 1  # per tensor
+
+    if base_name == "weight" and weight_shape is not None:
+        if quantization_args.strategy == QuantizationStrategy.CHANNEL:
+            # (output_channels, 1)
+            expected_shape = (weight_shape[0], 1)
+        elif quantization_args.strategy == QuantizationStrategy.GROUP:
+            expected_shape = (
+                weight_shape[0],
+                weight_shape[1] // quantization_args.group_size,
+            )
+
     # initializes empty scale and zero point parameters for the module
     init_scale = Parameter(
-        torch.empty(0, dtype=module.weight.dtype, device=device), requires_grad=False
+        torch.empty(expected_shape, dtype=module.weight.dtype, device=device),
+        requires_grad=False,
     )
     module.register_parameter(f"{base_name}_scale", init_scale)
 
     zp_dtype = quantization_args.pytorch_dtype()
     init_zero_point = Parameter(
-        torch.empty(0, device=device, dtype=zp_dtype), requires_grad=False
+        torch.empty(expected_shape, device=device, dtype=int),
+        requires_grad=False,
    )
     module.register_parameter(f"{base_name}_zero_point", init_zero_point)

src/compressed_tensors/quantization/observers/helpers.py

Lines changed: 4 additions & 0 deletions

@@ -59,6 +59,10 @@ def calculate_qparams(
     # match zero-points to quantized type
     zero_points = zero_points.to(zp_dtype)
 
+    if scales.ndim == 0:
+        scales = scales.reshape(1)
+        zero_points = zero_points.reshape(1)
+
     return scales, zero_points
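This keeps per-tensor qparams 1-D so they line up with the (1,)-shaped parameters registered at initialization; a two-line illustration:

import torch

scales = torch.tensor(0.05)  # 0-dim output from a per-tensor observer
if scales.ndim == 0:
    scales = scales.reshape(1)
print(scales.shape)  # torch.Size([1])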

src/compressed_tensors/quantization/quant_args.py

Lines changed: 2 additions & 2 deletions

@@ -51,7 +51,7 @@ class QuantizationStrategy(str, Enum):
     TOKEN = "token"
 
 
-class QuantizationArgs(BaseModel):
+class QuantizationArgs(BaseModel, use_enum_values=True):
     """
     User facing arguments used to define a quantization config for weights or
     activations
@@ -71,7 +71,7 @@ class QuantizationArgs(BaseModel):
     """
 
     num_bits: int = 8
-    type: QuantizationType = QuantizationType.INT
+    type: QuantizationType = QuantizationType.INT.value
     symmetric: bool = True
     group_size: Optional[int] = None
     strategy: Optional[QuantizationStrategy] = None

src/compressed_tensors/quantization/quant_config.py

Lines changed: 4 additions & 0 deletions

@@ -144,6 +144,10 @@ def model_post_init(self, __context):
             targets=targets_or_scheme,
         )
 
+    def to_dict(self):
+        # for compatibility with HFQuantizer
+        return self.dict()
+
     @staticmethod
     def from_pretrained(
         model: Module, format: Optional[str] = None
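A hedged sketch of the shim: HFQuantizer-style callers expect a to_dict() method, which here simply delegates to Pydantic's .dict() serialization (stand-in model, not the real QuantizationConfig):

from pydantic import BaseModel


class TinyConfig(BaseModel):  # stand-in for QuantizationConfig
    num_bits: int = 8

    def to_dict(self):
        # for compatibility with callers that expect a to_dict() method
        return self.dict()


print(TinyConfig().to_dict())  # {'num_bits': 8}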

src/compressed_tensors/version.py

Lines changed: 1 addition & 1 deletion

@@ -19,7 +19,7 @@
 from datetime import date
 
 
-version_base = "0.3.3"
+version_base = "0.4.0"
 is_release = False  # change to True to set the generated version as a release version

tests/test_quantization/lifecycle/test_forward.py

Lines changed: 1 addition & 1 deletion

@@ -57,9 +57,9 @@ def test_maybe_calibrate_or_quantize(create_quantization_scheme, quantization_st
     quantization_args = QuantizationArgs(num_bits=num_bits, symmetric=True)
     layer = Linear(4, 4)
     layer.weight.data *= 100
+    layer.quantization_status = QuantizationStatus(quantization_status)
 
     initialize_module_for_quantization(layer, quantization_scheme)
-    layer.quantization_status = QuantizationStatus(quantization_status)
 
     # only calibration updates the scale and zero-point
     if layer.quantization_status == QuantizationStatus.INITIALIZED:
