GPTQ Algorithm Cleanup #120

Merged: 13 commits, Aug 28, 2024
4 changes: 1 addition & 3 deletions src/llmcompressor/modifiers/quantization/gptq/base.py
@@ -169,9 +169,7 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
if not self.initialized_structure_:
self.on_initialize_structure(state, **kwargs)
if self.quantization_modifier_:
self.quantization_modifier_.initialize(
state, freeze_quantization=False, **kwargs
)
self.quantization_modifier_.initialize(state, **kwargs)
if not self.quantize:
raise ValueError("To use the GPTQModifier, quantization must be enabled.")

181 changes: 90 additions & 91 deletions src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py
@@ -1,7 +1,12 @@
import time

from compressed_tensors.quantization import QuantizationStrategy
from compressed_tensors.quantization.lifecycle.forward import fake_quantize
from compressed_tensors.quantization.observers import MemorylessObserver

from llmcompressor.modifiers.utils import SPARSITY_THRESHOLD
from llmcompressor.modifiers.utils.compression_wrapper import ModuleCompressionWrapper
from llmcompressor.utils import getattr_chain
from llmcompressor.utils.metric_logging import (
get_GPU_memory_usage,
get_layer_size_bytes,
@@ -21,6 +26,7 @@
from compressed_tensors.utils import (
get_offloaded_device,
is_module_offloaded,
update_parameter_data,
update_prefix_dict,
)
from loguru import logger
@@ -83,6 +89,12 @@ def compress(
:param percdamp: Amount of dampening to apply to H, as a fraction of the
diagonal norm
"""
weight_quant_args = getattr_chain(
self.layer, "quantization_scheme.weights", None
)
if weight_quant_args is None:
logger.debug(f"Skipping unquantized layer {self.name}...")
return

if is_module_offloaded(self.layer):
self.layer._hf_hook.pre_forward(self.layer)
@@ -92,12 +104,14 @@
W = self.layer.weight.data.clone()
from llmcompressor.pytorch.utils.helpers import tensor_sparsity

# standardize shape and dtype
if isinstance(self.layer, nn.Conv2d):
W = W.flatten(1)
if isinstance(self.layer, transformers.Conv1D):
W = W.t()
elif isinstance(self.layer, transformers.Conv1D):
W.transpose_(0, 1)
W = W.float()

# sparsity mask
sparsity = tensor_sparsity(W)
preserve_zeros = sparsity >= SPARSITY_THRESHOLD
W_nz_mask = (
@@ -108,17 +122,13 @@

tick = time.time()

if hasattr(self.layer, "quantization_scheme"):
quant_scheme = self.layer.quantization_scheme
if quant_scheme.weights is not None:
# fetch latest correct scale and ZP relevant for any changes
# such as activation reordering
from compressed_tensors.quantization import (
update_layer_weight_quant_params,
)

update_layer_weight_quant_params(self.layer)
# update quantization parameters for activation ordering
observer = MemorylessObserver(weight_quant_args)
scale, zero_point = observer(W)
update_parameter_data(self.layer, scale, "weight_scale")
update_parameter_data(self.layer, zero_point, "weight_zero_point")

# mask dead hessian values
dead = torch.diag(self.H) == 0
self.H[dead, dead] = 1
W[:, dead] = 0
@@ -152,61 +162,44 @@ def compress(
d = Hinv1[i, i]
q = w.clone()

if hasattr(self.layer, "weight_fake_quant"):
scale = self.layer.weight_fake_quant.scale
zero_point = self.layer.weight_fake_quant.zero_point
dtype = self.layer.weight_fake_quant.dtype
qscheme = self.layer.weight_fake_quant.qscheme
if qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
q = torch.quantize_per_tensor(q, scale, zero_point, dtype)
else:
q = torch.quantize_per_channel(q, scale, zero_point, 0, dtype)
q = torch.dequantize(q)
elif hasattr(self.layer, "quantization_scheme"):
quant_scheme = self.layer.quantization_scheme
if quant_scheme.weights is not None:
scale = self.layer.weight_scale
zero_point = self.layer.weight_zero_point
from compressed_tensors.quantization import QuantizationStrategy
from compressed_tensors.quantization.lifecycle.forward import (
fake_quantize,
)

strategy = quant_scheme.weights.strategy

if strategy == QuantizationStrategy.TENSOR:
q = fake_quantize(
q,
scale,
zero_point,
self.layer.quantization_scheme.weights,
)
elif strategy == QuantizationStrategy.CHANNEL:
# TODO: for channelwise why isn't this just a 1d tensor?
q = fake_quantize(
q,
scale[:, 0],
zero_point[:, 0],
quant_scheme.weights,
)
else: # strategy == QuantizationStrategy.GROUP
# get the group index for the current column
column_idx = i1 + i
input_dim_group = (
column_idx // quant_scheme.weights.group_size
)

# Since we're only applying quantization to a slice, this
# ends up being a channelwise application
altered_qargs = copy(quant_scheme.weights)
altered_qargs.strategy = QuantizationStrategy.CHANNEL
q = fake_quantize(
q,
scale[:, input_dim_group],
zero_point[:, input_dim_group],
altered_qargs,
)
# quantize column
strategy = weight_quant_args.strategy
if strategy == QuantizationStrategy.TENSOR:
q = fake_quantize(
q,
scale,
zero_point,
self.layer.quantization_scheme.weights,
)
elif strategy == QuantizationStrategy.CHANNEL:
q = fake_quantize(
q,
scale[:, 0],
zero_point[:, 0],
weight_quant_args,
)
elif strategy == QuantizationStrategy.GROUP:
# get the group index for the current column
column_idx = i1 + i
input_dim_group = column_idx // weight_quant_args.group_size

# Since we're only applying quantization to a slice, this
# ends up being a channelwise application
altered_qargs = copy(weight_quant_args)
altered_qargs.strategy = QuantizationStrategy.CHANNEL
q = fake_quantize(
q,
scale[:, input_dim_group],
zero_point[:, input_dim_group],
altered_qargs,
)
else:
raise ValueError(
"Quantization strategy is not supported for GPTQ: "
f"{strategy}"
)

# propagate column error
Q1[:, i] = q
Losses1[:, i] = (w - q) ** 2 / d**2

@@ -218,6 +211,7 @@ def compress(
W1[:, i:] -= w1_err
Err1[:, i] = err1

# propagate block error
W[:, i1:i2] = Q1
Losses += torch.sum(Losses1, 1) / 2

@@ -228,29 +222,10 @@ def compress(
W[:, i2:] -= w_err

if "METRIC" in logger._core.levels.keys():
logger.log("METRIC", "time %.2f" % (time.time() - tick))
logger.log("METRIC", "error %.2f" % torch.sum(Losses).item())

gpu_usage = get_GPU_memory_usage()
if len(gpu_usage) > 0:
for i in range(len(gpu_usage)):
perc = gpu_usage[i][0] * 100
total_memory = int(gpu_usage[i][1]) # GB
logger.log(
"METRIC",
(
f"GPU {i} | usage: {perc:.2f}%"
f" | total memory: {total_memory} GB"
),
)

logger.log(
"METRIC",
f"Compressed layer size: {get_layer_size_bytes(self.layer)} MB",
)
self.log_metrics(tick, Losses)

if isinstance(self.layer, transformers.Conv1D):
W = W.t()
W.transpose_(0, 1)
W = W.reshape(final_shape).to(final_dtype)

# This is a bit hacky, but FSDP updates only work if we change the weight in
@@ -263,13 +238,37 @@ def compress(
update_prefix_dict(self.layer, "weight", self.layer.weight.to(device))
self.layer._hf_hook.post_forward(self.layer, None)

del W
del Losses
del diag

def free(self):
"""
Free the Hessian memory after the layer is complete
"""
delattr(self, "H")
super().free()

def log_metrics(self, start_tick: float, losses: torch.Tensor):
"""
Log metrics related to the compression algorithm

:param start_tick: time when the algorithm started
:param losses: loss resulting from the algorithm
"""
logger.log("METRIC", "time %.2f" % (time.time() - start_tick))
logger.log("METRIC", "error %.2f" % torch.sum(losses).item())

gpu_usage = get_GPU_memory_usage()
if len(gpu_usage) > 0:
for i in range(len(gpu_usage)):
perc = gpu_usage[i][0] * 100
total_memory = int(gpu_usage[i][1]) # GB
logger.log(
"METRIC",
(
f"GPU {i} | usage: {perc:.2f}%"
f" | total memory: {total_memory} GB"
),
)

logger.log(
"METRIC",
f"Compressed layer size: {get_layer_size_bytes(self.layer)} MB",
)
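
The GROUP branch above maps each weight column to its quantization group via column_idx // group_size and then applies the matching per-group scale column channelwise to the single column being processed. Below is a standalone sketch of that indexing in plain PyTorch, using a simplified symmetric round-to-nearest quantizer rather than compressed-tensors' fake_quantize; the helper name, toy shapes, and absmax scale computation are illustrative assumptions, not part of this PR.

import torch

def fake_quantize_group_column(W, scales, group_size, i1, i, num_bits=4):
    # quantize-dequantize column (i1 + i) of W using its group's per-row scale
    column_idx = i1 + i
    input_dim_group = column_idx // group_size   # group this column falls into
    scale = scales[:, input_dim_group]           # one scale per output row (channelwise)
    qmax = 2 ** (num_bits - 1) - 1
    q = torch.clamp(torch.round(W[:, column_idx] / scale), -qmax - 1, qmax)
    return q * scale                             # dequantized column

# toy weights: 8 output rows, 16 input columns, groups of 4 columns
W = torch.randn(8, 16)
group_size = 4
# per-(row, group) absmax scales for 4-bit symmetric quantization
scales = W.reshape(8, 16 // group_size, group_size).abs().amax(dim=-1) / 7 + 1e-8
col = fake_quantize_group_column(W, scales, group_size, i1=4, i=2)
assert col.shape == (8,)  # column 6 was quantized with the scales of group 1
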
7 changes: 2 additions & 5 deletions src/llmcompressor/modifiers/quantization/quantization/base.py
@@ -74,9 +74,7 @@ def on_initialize_structure(self, state: State, **kwargs):
self._apply_modifier_to_model(module)
module.apply(freeze_module_quantization)

def on_initialize(
self, state: State, freeze_quantization: bool = True, **kwargs
) -> bool:
def on_initialize(self, state: State, **kwargs) -> bool:
if self.end and self.end != -1:
raise ValueError(
"end_epoch is disabled for QuantizationModifier and can only be set to"
@@ -96,8 +94,7 @@ def on_initialize(
self._check_token_distribution(
module, threshold=kwargs.get("min_tokens_per_module")
)
if freeze_quantization:
module.apply(freeze_module_quantization)
module.apply(freeze_module_quantization)

return True

33 changes: 33 additions & 0 deletions src/llmcompressor/utils/helpers.py
@@ -58,6 +58,7 @@
"parse_kwarg_tuples",
"is_package_available",
"import_from_path",
"getattr_chain",
]


@@ -1008,3 +1009,35 @@ def import_from_path(path: str) -> str:
return getattr(module, class_name)
except AttributeError:
raise AttributeError(f"Cannot find {class_name} in {_path}")


def getattr_chain(obj: Any, chain_str: str, *args, **kwargs) -> Any:
"""
Chain multiple getattr calls, separated by `.`

:param obj: base object whose attributes are being retrieved
:param chain_str: attribute names separated by `.`
:param default: optional default value to return when any attribute in the
chain is missing; if omitted, an AttributeError is raised
"""
if len(args) >= 1:
has_default = True
default = args[0]
elif "default" in kwargs:
has_default = True
default = kwargs["default"]
else:
has_default = False

attr_names = chain_str.split(".")

res = obj
for attr_name in attr_names:
if not hasattr(res, attr_name):
if has_default:
return default
else:
raise AttributeError(f"{res} object has no attribute {attr_name}")
res = getattr(res, attr_name)

return res
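
A minimal usage sketch of the new getattr_chain helper, mirroring how the GPTQ wrapper resolves "quantization_scheme.weights" above; the SimpleNamespace objects are stand-ins for a layer with an attached quantization scheme, not real llm-compressor types.

from types import SimpleNamespace

# stand-in for a quantized layer: layer.quantization_scheme.weights resolves
quantized_layer = SimpleNamespace(
    quantization_scheme=SimpleNamespace(weights=SimpleNamespace(num_bits=8))
)
# stand-in for an ignored layer: no quantization_scheme attribute at all
unquantized_layer = SimpleNamespace()

# full chain resolves, so the nested attribute is returned
weights = getattr_chain(quantized_layer, "quantization_scheme.weights", None)
assert weights is not None and weights.num_bits == 8

# a broken chain returns the provided default instead of raising
assert getattr_chain(unquantized_layer, "quantization_scheme.weights", None) is None

# without a default, a broken chain raises AttributeError
try:
    getattr_chain(unquantized_layer, "quantization_scheme.weights")
except AttributeError:
    pass
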
@@ -0,0 +1,41 @@
from collections import OrderedDict

import torch
from compressed_tensors.quantization.lifecycle.apply import apply_quantization_config
from compressed_tensors.quantization.quant_config import QuantizationConfig
from compressed_tensors.quantization.quant_scheme import preset_name_to_scheme
from loguru import logger

from llmcompressor.modifiers.quantization.gptq.utils.gptq_wrapper import GPTQWrapper


def test_ignore():
model = torch.nn.Sequential(
OrderedDict(
[
("first_layer", torch.nn.Linear(2, 3)),
("second_layer", torch.nn.Linear(3, 5)),
]
)
)

config = QuantizationConfig(
config_groups={"group_0": preset_name_to_scheme("W8A8", targets=["Linear"])},
ignore=["first_layer"],
)
apply_quantization_config(model, config)

messages = []
logger.add(lambda m: messages.append(m))

with torch.no_grad():
first_compressor = GPTQWrapper("first_layer", model.first_layer)
first_compressor.add_batch(torch.ones(2), None)
first_compressor.compress()

first_compressor = GPTQWrapper("second_layer", model.second_layer)
first_compressor.add_batch(torch.ones(3), None)
first_compressor.compress()

assert sum("Skipping unquantized layer first_layer" in m for m in messages) == 1
assert sum("Skipping unquantized layer second_layer" in m for m in messages) == 0