Activation Ordering #97

Merged
merged 84 commits on Aug 26, 2024
Commits (84 total; the diff below shows changes from 72 of them)
cfbe5e9
propagate g_idx to scale, zp
horheynm Jun 24, 2024
27ee8d9
intialize empty g_idx
horheynm Jun 25, 2024
4918140
get rid of g_idx in forward
horheynm Jun 26, 2024
88908da
g_idx in forward
horheynm Jul 2, 2024
c884a3e
g_idx in _process_quantization
horheynm Jul 3, 2024
460dfbd
remove unnec comments
horheynm Jul 8, 2024
a19788a
initialize param to -1, g_idx in forward call
horheynm Jul 9, 2024
9607bdb
comments
horheynm Jul 12, 2024
00b8772
add ability to disable forward pass override during quantization
Jul 18, 2024
3ae6698
add pathway for updating weight quant params from curent weight
Jul 18, 2024
d472287
add g_idx to observer
horheynm Jul 18, 2024
6b13bd2
fix on load
horheynm Jul 25, 2024
c1eb3b4
fix prepopulated observer params
horheynm Jul 26, 2024
ae2c158
draft
horheynm Jul 27, 2024
8069335
draft
horheynm Jul 31, 2024
0b2e9ab
draft
horheynm Aug 1, 2024
93749ef
two bug fixes: pass g_idx during compress and decompress, deep copy w…
kylesayrs Aug 13, 2024
bd92a10
remove perm from , add and remove comments
kylesayrs Aug 13, 2024
d51a8b6
implement `refresh_layer_weight_quant_params`
kylesayrs Aug 13, 2024
5025fb6
update_layer_weight_quant_params reuse
kylesayrs Aug 14, 2024
bb35198
simplify grouped forward logic
kylesayrs Aug 14, 2024
ddc015a
rename group_id to group_index
kylesayrs Aug 14, 2024
03a2ce9
3 lines to 1 line
kylesayrs Aug 14, 2024
33968a1
clarify defaulting logic
kylesayrs Aug 14, 2024
ecf731e
initialize with g_idx if applicable
kylesayrs Aug 14, 2024
4730ec3
update comment
kylesayrs Aug 14, 2024
0aeacfc
add back import
kylesayrs Aug 14, 2024
2e6d6e1
implement quantization cases of varying speeds
kylesayrs Aug 14, 2024
0f804aa
correct typo
kylesayrs Aug 14, 2024
4ef1c04
add g_idx and different quantization cases
kylesayrs Aug 14, 2024
8797eb9
rearrange todos
kylesayrs Aug 14, 2024
d326914
comments
kylesayrs Aug 14, 2024
a8e2006
better spacing
kylesayrs Aug 14, 2024
18fe91a
apply style
kylesayrs Aug 14, 2024
88ed2cf
consolodate comments
kylesayrs Aug 14, 2024
d257106
add document and validation
kylesayrs Aug 16, 2024
c75bdc1
apply style
kylesayrs Aug 16, 2024
6556cd5
handle case where g_idx is None
kylesayrs Aug 19, 2024
bfdb7f5
use permutation
kylesayrs Aug 20, 2024
23c9f77
use permutation
kylesayrs Aug 20, 2024
bb9d6d7
apply style
kylesayrs Aug 20, 2024
a3adf88
save for future PR
kylesayrs Aug 20, 2024
0558004
apply style
kylesayrs Aug 20, 2024
cf3e4d7
remove unneeded global
kylesayrs Aug 20, 2024
e3b1082
remove dangling comment
kylesayrs Aug 20, 2024
0a1b366
Remove extra line
kylesayrs Aug 20, 2024
588f151
detect experimental types using try except and caching
kylesayrs Aug 20, 2024
ead5497
add note about activation ordering
kylesayrs Aug 20, 2024
fc6faf9
add dtype test
kylesayrs Aug 20, 2024
59d8c97
use sets and helper function
kylesayrs Aug 20, 2024
029473b
fix bugs
kylesayrs Aug 20, 2024
11af88e
apply style
kylesayrs Aug 20, 2024
b9fada2
add quantize dequantize tests, apply style
kylesayrs Aug 20, 2024
d627368
Merge branch 'g-idx-quantization' into act-order
kylesayrs Aug 20, 2024
09b2dad
remove repeat arguments
kylesayrs Aug 20, 2024
acc170b
Merge remote-tracking branch 'origin' into act-order
kylesayrs Aug 21, 2024
c47ef41
better validation logic
kylesayrs Aug 21, 2024
6330cfa
break up tests
kylesayrs Aug 21, 2024
0b0d545
fix validate_group
kylesayrs Aug 21, 2024
5887688
move safe_permute to utils
kylesayrs Aug 21, 2024
09b198d
clean up observer, add tests
kylesayrs Aug 21, 2024
2a96ec1
better spacing
kylesayrs Aug 21, 2024
bffc144
apply style
kylesayrs Aug 21, 2024
793a610
remove merge dreggs
kylesayrs Aug 21, 2024
f1b7b3a
initial commit
kylesayrs Aug 21, 2024
d0e214d
add reset implementations
kylesayrs Aug 21, 2024
42acc6b
fix tests import
kylesayrs Aug 21, 2024
51f13cb
fix tests import
kylesayrs Aug 21, 2024
12799b7
Merge branch 'ksayers/move-safe-permute' into act-order
kylesayrs Aug 21, 2024
30e5955
None is okay
kylesayrs Aug 21, 2024
45a3d6f
remove space
kylesayrs Aug 21, 2024
b42e04c
Merge remote-tracking branch 'origin/main' into act-order
kylesayrs Aug 22, 2024
21d07c3
use correct dtypes for scale and floating point
kylesayrs Aug 22, 2024
89b41c8
initialize to None and empty
kylesayrs Aug 22, 2024
1ea14ce
do not set attribute if not used
kylesayrs Aug 22, 2024
788eb41
initialize with none to standardize on optional quant_params
kylesayrs Aug 22, 2024
8b8c666
load g_idx for naive_quantized compressor
kylesayrs Aug 22, 2024
647f2e5
apply style
kylesayrs Aug 22, 2024
0f61c8a
add depreciation todo
kylesayrs Aug 22, 2024
50bc5e3
use default of -1s
kylesayrs Aug 22, 2024
a77331e
remove comment, do not register parameter if not used
kylesayrs Aug 23, 2024
9777bdf
add tests
kylesayrs Aug 23, 2024
1b869fa
original logic
kylesayrs Aug 23, 2024
7475358
adjust comment
kylesayrs Aug 23, 2024
9 changes: 7 additions & 2 deletions src/compressed_tensors/compressors/pack_quantized.py
@@ -44,6 +44,7 @@ class PackedQuantizationCompressor(Compressor):
"weight_packed",
"weight_scale",
"weight_zero_point",
"weight_g_idx",
"weight_shape",
]

@@ -72,6 +73,7 @@ def compress(
prefix = name[: -(len(weight_suffix))]
scale = model_state.get(merge_names(prefix, "weight_scale"), None)
zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
g_idx = model_state.get(merge_names(prefix, "weight_g_idx"), None)
shape = torch.tensor(value.shape)
if scale is not None and zp is not None:
# weight is quantized, compress it
@@ -82,6 +84,7 @@
x=value,
scale=scale,
zero_point=zp,
g_idx=g_idx,
args=quant_args,
dtype=torch.int8,
)
@@ -128,16 +131,18 @@ def decompress(
weight_data[param_name] = f.get_tensor(full_name)

if "weight_scale" in weight_data:
zero_point = weight_data.get("weight_zero_point", None)
scale = weight_data["weight_scale"]
weight = weight_data["weight_packed"]
scale = weight_data["weight_scale"]
zero_point = weight_data.get("weight_zero_point", None)
g_idx = weight_data.get("weight_g_idx", None)
num_bits = weight_data["num_bits"]
original_shape = torch.Size(weight_data["weight_shape"])
unpacked = unpack_from_int32(weight, num_bits, original_shape)
decompressed = dequantize(
x_q=unpacked,
scale=scale,
zero_point=zero_point,
g_idx=g_idx,
)
yield merge_names(weight_name, "weight"), decompressed

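The compressor changes above thread `g_idx` from the state dict into `quantize`/`dequantize`. As a rough illustration of what group-wise dequantization with an optional `g_idx` does, here is a minimal, self-contained sketch; `dequantize_grouped` and its signature are hypothetical stand-ins, not the library's actual call:

```python
import torch

def dequantize_grouped(x_q: torch.Tensor,
                       scale: torch.Tensor,
                       zero_point: torch.Tensor,
                       g_idx: torch.Tensor = None,
                       group_size: int = 128) -> torch.Tensor:
    # x_q: [rows, columns]; scale / zero_point: [rows, num_groups]
    rows, columns = x_q.shape
    if g_idx is None or (g_idx == -1).all():
        # default ordering: column c belongs to group c // group_size
        g_idx = torch.arange(columns) // group_size
    g_idx = g_idx.long()
    # expand per-group parameters back to per-column parameters
    col_scale = scale[:, g_idx]        # [rows, columns]
    col_zp = zero_point[:, g_idx]      # [rows, columns]
    return (x_q.to(col_scale.dtype) - col_zp) * col_scale

# toy example: 2 rows, 8 columns, group_size=4 -> 2 groups
x_q = torch.randint(-8, 8, (2, 8))
scale = torch.tensor([[0.10, 0.20], [0.30, 0.40]])
zero_point = torch.zeros(2, 2)
weight = dequantize_grouped(x_q, scale, zero_point, group_size=4)
```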
14 changes: 9 additions & 5 deletions src/compressed_tensors/quantization/lifecycle/apply.py
@@ -279,17 +279,21 @@ def _load_quant_args_from_state_dict(
"""
scale_name = f"{base_name}_scale"
zp_name = f"{base_name}_zero_point"
g_idx_name = f"{base_name}_g_idx"

state_dict_scale = state_dict.get(f"{module_name}.{scale_name}", None)
state_dict_zp = state_dict.get(f"{module_name}.{zp_name}", None)
state_dict_zp = state_dict.get(
f"{module_name}.{zp_name}", torch.zeros_like(state_dict_scale, device="cpu")
)
state_dict_g_idx = state_dict.get(
f"{module_name}.{g_idx_name}",
torch.full_like(state_dict_scale, -1, device="cpu"),
)

if state_dict_scale is not None:
# module is quantized
update_parameter_data(module, state_dict_scale, scale_name)
if state_dict_zp is None:
# fill in zero point for symmetric quantization
state_dict_zp = torch.zeros_like(state_dict_scale, device="cpu")
update_parameter_data(module, state_dict_zp, zp_name)
update_parameter_data(module, state_dict_g_idx, g_idx_name)


def _scheme_from_targets(
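The loader above now fills in defaults when a checkpoint omits optional quantization tensors. A hedged sketch of that defaulting pattern, with a hypothetical helper name, assuming symmetric checkpoints may drop the zero point and non-activation-ordered checkpoints drop `g_idx`:

```python
import torch

def load_weight_quant_params(state_dict: dict, module_name: str) -> dict:
    scale = state_dict.get(f"{module_name}.weight_scale")
    if scale is None:
        return {}  # module is not quantized, nothing to load

    # missing zero point -> symmetric quantization, default to zeros
    zero_point = state_dict.get(
        f"{module_name}.weight_zero_point", torch.zeros_like(scale, device="cpu")
    )
    # missing g_idx -> no activation ordering, mark as unused with -1
    g_idx = state_dict.get(
        f"{module_name}.weight_g_idx", torch.full_like(scale, -1, device="cpu")
    )
    return {
        "weight_scale": scale,
        "weight_zero_point": zero_point,
        "weight_g_idx": g_idx,
    }
```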
3 changes: 2 additions & 1 deletion src/compressed_tensors/quantization/lifecycle/calibration.py
@@ -54,13 +54,14 @@ def set_module_for_calibration(module: Module, quantize_weights_upfront: bool =
if quantize_weights_upfront and module.quantization_scheme.weights is not None:
# set weight scale and zero_point up front, calibration data doesn't affect it
observer = module.weight_observer
g_idx = getattr(module, "weight_g_idx", None)

offloaded = False
if is_module_offloaded(module):
module._hf_hook.pre_forward(module)
offloaded = True

scale, zero_point = observer(module.weight)
scale, zero_point = observer(module.weight, g_idx=g_idx)
update_parameter_data(module, scale, "weight_scale")
update_parameter_data(module, zero_point, "weight_zero_point")

4 changes: 2 additions & 2 deletions src/compressed_tensors/quantization/lifecycle/forward.py
@@ -331,7 +331,7 @@ def maybe_calibrate_or_quantize(
if args.dynamic:
# dynamic quantization - get scale and zero point directly from observer
observer = getattr(module, f"{base_name}_observer")
scale, zero_point = observer(value)
scale, zero_point = observer(value, g_idx=g_idx)
else:
# static quantization - get previous scale and zero point from layer
scale = getattr(module, f"{base_name}_scale")
@@ -344,7 +344,7 @@
# calibration mode - get new quant params from observer
observer = getattr(module, f"{base_name}_observer")

updated_scale, updated_zero_point = observer(value)
updated_scale, updated_zero_point = observer(value, g_idx=g_idx)

# update scale and zero point
update_parameter_data(module, updated_scale, f"{base_name}_scale")
8 changes: 8 additions & 0 deletions src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -162,3 +162,11 @@ def _initialize_scale_zero_point_observer(
requires_grad=False,
)
module.register_parameter(f"{base_name}_zero_point", init_zero_point)

if quantization_args.actorder:
_, column_size = module.weight.shape
init_g_idx = Parameter(
torch.full((column_size,), -1, dtype=torch.int32, device=device),
requires_grad=False,
)
module.register_parameter(f"{base_name}_g_idx", init_g_idx)
59 changes: 45 additions & 14 deletions src/compressed_tensors/quantization/observers/base.py
@@ -13,6 +13,7 @@
# limitations under the License.

import logging
from math import ceil
from typing import Any, Iterable, Optional, Tuple, Union

import torch
@@ -21,6 +22,7 @@
QuantizationStrategy,
)
from compressed_tensors.registry.registry import RegistryMixin
from compressed_tensors.utils import safe_permute
from torch import FloatTensor, IntTensor, Tensor
from torch.nn import Module

@@ -46,15 +48,18 @@ def __init__(self, quantization_args: QuantizationArgs):
self._num_observed_tokens = None

@torch.no_grad()
def forward(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
def forward(
self, observed: Tensor, g_idx: Optional[Tensor] = None
) -> Tuple[FloatTensor, IntTensor]:
"""
maps directly to get_qparams
:param observed: optional observed tensor to calculate quantization parameters
from
:param observed: optional observed tensor from which to calculate
quantization parameters
:param g_idx: optional mapping from column index to group index
:return: tuple of scale and zero point based on last observed value
"""
self.record_observed_tokens(observed)
return self.get_qparams(observed=observed)
return self.get_qparams(observed=observed, g_idx=g_idx)

def calculate_qparams(
self,
@@ -77,7 +82,9 @@ def post_calculate_qparams(self) -> None:
...

def get_qparams(
self, observed: Optional[Tensor] = None
self,
observed: Optional[Tensor] = None,
g_idx: Optional[Tensor] = None,
) -> Tuple[FloatTensor, IntTensor]:
"""
Convenience function to wrap overwritten calculate_qparams
@@ -86,6 +93,7 @@ def get_qparams(

:param observed: optional observed tensor to calculate quantization parameters
from
:param g_idx: optional mapping from column index to group index
:return: tuple of scale and zero point based on last observed value
"""
if observed is not None:
@@ -97,20 +105,41 @@
self._scale, self._zero_point = self.calculate_qparams(observed)

elif self.quantization_args.strategy == QuantizationStrategy.GROUP:
rows = observed.shape[0]
columns = observed.shape[1]
scales, zero_points = [], []
group_idxs = range(0, columns, self.quantization_args.group_size)
for group_id, group_idx in enumerate(group_idxs):
num_groups = int(ceil(columns / group_size))
self._scale = torch.empty(
(rows, num_groups), dtype=torch.float, device=observed.device
)
self._zero_point = torch.empty(
(rows, num_groups), dtype=torch.float, device=observed.device
)

# support column-order (default) quantization as well as other orderings
# such as activation ordering. Below checks if g_idx has initialized
is_column_order = g_idx is None or -1 in g_idx
if is_column_order:
group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
else:
group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
group_sizes = group_sizes[torch.argsort(group_indices)]

perm = torch.argsort(g_idx)
observed = safe_permute(observed, perm, dim=1)

# TODO: experiment with vectorizing for loop for performance
end = 0
for group_index, group_count in enumerate(group_sizes):
start = end
end = start + group_count
scale, zero_point = self.get_qparams_along_dim(
observed[:, group_idx : (group_idx + group_size)],
observed[:, start:end],
0,
tensor_id=group_id,
tensor_id=group_index,
)
scales.append(scale)
zero_points.append(zero_point)

self._scale = torch.cat(scales, dim=1, out=self._scale)
self._zero_point = torch.cat(zero_points, dim=1, out=self._zero_point)
self._scale[:, group_index] = scale.squeeze(1)
self._zero_point[:, group_index] = zero_point.squeeze(1)

elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
# assume observed is transposed, because its the output, hence use dim 0
@@ -132,6 +161,8 @@ def get_qparams_along_dim(
dim: Union[int, Iterable[int]],
tensor_id: Optional[Any] = None,
):
if isinstance(dim, int):
dim = [dim]
dim = set(dim)

reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim)
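The grouping logic above is easiest to follow in isolation. Below is a condensed, self-contained sketch (hypothetical helper name, with a simple symmetric min-max rule standing in for the observer classes) of how `g_idx` changes the per-group scale computation:

```python
import torch

def grouped_qparams_sketch(observed: torch.Tensor,
                           group_size: int,
                           g_idx: torch.Tensor = None,
                           num_bits: int = 4):
    """Per-group scales and zero points with optional activation ordering."""
    rows, columns = observed.shape
    num_groups = -(-columns // group_size)  # ceil division

    if g_idx is None or (g_idx == -1).any():
        # plain column order: every group is `group_size` contiguous columns
        group_sizes = torch.full((num_groups,), group_size, dtype=torch.int64)
    else:
        # activation ordering: group membership comes from g_idx, so permute
        # columns first so that each group's members become contiguous
        group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
        group_sizes = group_sizes[torch.argsort(group_indices)]
        observed = observed[:, torch.argsort(g_idx)]

    q_max = 2 ** (num_bits - 1) - 1
    scale = torch.empty((rows, num_groups), dtype=observed.dtype)
    end = 0
    for group_index, group_count in enumerate(group_sizes):
        start, end = end, end + int(group_count)
        # symmetric min-max: scale from the largest magnitude in each group
        scale[:, group_index] = observed[:, start:end].abs().amax(dim=1) / q_max
    zero_point = torch.zeros_like(scale, dtype=torch.int32)
    return scale, zero_point

# columns assigned to groups non-contiguously, as activation ordering produces
x = torch.randn(4, 8)
g_idx = torch.tensor([1, 0, 1, 0, 0, 1, 0, 1])
scale, zero_point = grouped_qparams_sketch(x, group_size=4, g_idx=g_idx)
print(scale.shape)  # torch.Size([4, 2])
```

The real implementation delegates the per-group reduction to `get_qparams_along_dim` and writes into preallocated `self._scale` / `self._zero_point` buffers; the permutation and the variable group sizes are the key differences from the previous fixed-stride loop.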
66 changes: 47 additions & 19 deletions src/compressed_tensors/quantization/quant_args.py
@@ -16,7 +16,7 @@
from typing import Any, Dict, Optional

import torch
from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, Field, field_validator, model_validator


__all__ = [
@@ -68,6 +68,8 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
ranges will be observed with every sample. Defaults to False for static
quantization. Note that enabling dynamic quantization will change the default
observer to a memoryless one
:param actorder: whether to apply group quantization in decreasing order of
activation. Defaults to False for arbitrary ordering
"""

num_bits: int = 8
@@ -77,6 +79,7 @@
strategy: Optional[QuantizationStrategy] = None
block_structure: Optional[str] = None
dynamic: bool = False
actorder: bool = False
observer: str = Field(
default="minmax",
description=(
@@ -105,33 +108,58 @@ def get_observer(self):

return Observer.load_from_registry(self.observer, quantization_args=self)

@validator("strategy", pre=True, always=True)
def validate_strategy(cls, value, values):
group_size = values.get("group_size")
@field_validator("group_size", mode="before")
def validate_group(cls, value) -> int:
if value is None:
return value

# use group_size to determinine strategy if not given explicity
if group_size is not None and value is None:
if group_size > 0:
return QuantizationStrategy.GROUP
if value < -1:
raise ValueError(
f"Invalid group size {value}. Use group_size > 0 for "
"strategy='group' and group_size = -1 for 'channel'"
)

elif group_size == -1:
return QuantizationStrategy.CHANNEL
return value

@model_validator(mode="before")
def validate_strategy(values) -> Dict[str, Any]:
model_fields = QuantizationArgs.model_fields
strategy = values.get("strategy", model_fields["strategy"].default)
group_size = values.get("group_size", model_fields["group_size"].default)
actorder = values.get("actorder", model_fields["actorder"].default)

if strategy is not None:
strategy = QuantizationStrategy(strategy.lower())

else:
# use group_size to determinine strategy if not given explicity
if group_size is None:
strategy = QuantizationStrategy.TENSOR
elif group_size > 0:
strategy = QuantizationStrategy.GROUP
elif group_size == -1:
strategy = QuantizationStrategy.CHANNEL
else:
raise ValueError(
f"group_size={group_size} with strategy {value} is invald. "
"group_size > 0 for strategy='group' and "
"group_size = -1 for 'channel'"
f"Invalid group size {group_size}. Use group_size > 0 for "
"strategy='group' and group_size = -1 for 'channel'"
)

if value == QuantizationStrategy.GROUP:
if group_size is None:
raise ValueError(f"strategy {value} requires group_size to be set.")
if strategy == QuantizationStrategy.GROUP:
if group_size is None or group_size <= 0:
raise ValueError(
f"strategy {strategy} requires group_size to be "
"set to a positive value"
)

if value is None:
return QuantizationStrategy.TENSOR
if actorder and strategy != QuantizationStrategy.GROUP:
raise ValueError(
"Group quantization must be specified in order to apply "
"activation ordering"
)

return value
values["strategy"] = strategy
return values

def pytorch_dtype(self) -> torch.dtype:
if self.type == QuantizationType.FLOAT:
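A few usage examples in case the validator interplay is unclear (assuming `QuantizationArgs` is exported from the package's public quantization namespace):

```python
from compressed_tensors.quantization import QuantizationArgs

# strategy is inferred from group_size when not given explicitly
args = QuantizationArgs(num_bits=4, group_size=128, actorder=True)
print(args.strategy)  # "group"

# group_size == -1 still maps to channel-wise quantization
print(QuantizationArgs(num_bits=8, group_size=-1).strategy)  # "channel"

# actorder without group quantization is rejected
try:
    QuantizationArgs(num_bits=4, strategy="channel", actorder=True)
except ValueError as err:
    print(err)  # Group quantization must be specified ...
```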
9 changes: 6 additions & 3 deletions src/compressed_tensors/utils/offload.py
@@ -89,7 +89,7 @@ def update_parameter_data(

:param module: layer containing the parameter to update
:param new_param_data: tensor to update parameter with
:param param_name:
:param param_name: name of layer parameter to update
"""
device = next(module.parameters()).device

@@ -99,8 +99,11 @@
offloaded = True

parameter = getattr(module, param_name, None)
dtype = parameter.dtype
parameter.data = new_param_data.to(device).to(dtype)
if parameter is not None:
dtype = parameter.dtype
parameter.data = new_param_data.to(device).to(dtype)
else:
setattr(module, param_name, new_param_data.to(device))

if offloaded:
prefix_dict = module._hf_hook.weights_map.dataset
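A small sketch of the new fallback (toy module, illustrative parameter name): previously a missing parameter meant dereferencing `None` for its dtype, whereas now the tensor is simply attached to the module.

```python
import torch
from torch.nn import Linear

module = Linear(8, 4, bias=False)
new_data = torch.ones(4, 1)

parameter = getattr(module, "weight_scale", None)  # not registered on this module
if parameter is not None:
    parameter.data = new_data.to(parameter.device).to(parameter.dtype)
else:
    # fallback path: attach the tensor instead of raising AttributeError
    setattr(module, "weight_scale", new_data)
```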
2 changes: 2 additions & 0 deletions src/compressed_tensors/utils/safetensors_load.py
@@ -234,5 +234,7 @@ def is_quantization_param(name: str) -> bool:
return True
if name.endswith("zero_point"):
return True
if name.endswith("g_idx"):
return True

return False
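For illustration, `g_idx` tensors are now treated as quantization metadata alongside scales and zero points (hypothetical parameter names, assuming the function is importable from the module shown above):

```python
from compressed_tensors.utils.safetensors_load import is_quantization_param

for name in ("model.layers.0.mlp.weight", "model.layers.0.mlp.weight_g_idx"):
    print(name, is_quantization_param(name))
# model.layers.0.mlp.weight False
# model.layers.0.mlp.weight_g_idx True
```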