GPTQ Algorithm Cleanup #120

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
13 commits merged on Aug 28, 2024
4 changes: 1 addition & 3 deletions src/llmcompressor/modifiers/quantization/gptq/base.py
@@ -169,9 +169,7 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
if not self.initialized_structure_:
self.on_initialize_structure(state, **kwargs)
if self.quantization_modifier_:
self.quantization_modifier_.initialize(
state, freeze_quantization=False, **kwargs
)
self.quantization_modifier_.initialize(state, **kwargs)
if not self.quantize:
raise ValueError("To use the GPTQModifier, quantization must be enabled.")

185 changes: 98 additions & 87 deletions src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py
@@ -1,7 +1,12 @@
import time

from compressed_tensors.quantization import QuantizationStrategy
from compressed_tensors.quantization.lifecycle.forward import fake_quantize
from compressed_tensors.quantization.observers import MemorylessObserver

from llmcompressor.modifiers.utils import SPARSITY_THRESHOLD
from llmcompressor.modifiers.utils.compression_wrapper import ModuleCompressionWrapper
from llmcompressor.utils import getattr_chain
from llmcompressor.utils.metric_logging import (
get_GPU_memory_usage,
get_layer_size_bytes,
@@ -21,6 +26,7 @@
from compressed_tensors.utils import (
get_offloaded_device,
is_module_offloaded,
update_parameter_data,
update_prefix_dict,
)
from loguru import logger
@@ -83,6 +89,13 @@ def compress(
:param percdamp: Amount of dampening to apply to H, as a fraction of the
diagonal norm
"""
weight_quant_args = getattr_chain(
self.layer, "quantization_scheme.weights", None
)
weight_fake_quant = getattr(self.layer, "weight_fake_quant", None)
if weight_quant_args is None and weight_fake_quant is None:
logger.debug("Skipping layer GPTQ quantization...")
return

if is_module_offloaded(self.layer):
self.layer._hf_hook.pre_forward(self.layer)
@@ -92,12 +105,14 @@
W = self.layer.weight.data.clone()
from llmcompressor.pytorch.utils.helpers import tensor_sparsity

# standardize shape and dtype
if isinstance(self.layer, nn.Conv2d):
W = W.flatten(1)
if isinstance(self.layer, transformers.Conv1D):
W = W.t()
elif isinstance(self.layer, transformers.Conv1D):
W.transpose_(0, 1)
W = W.float()

# sparsity mask
sparsity = tensor_sparsity(W)
preserve_zeros = sparsity >= SPARSITY_THRESHOLD
W_nz_mask = (
@@ -108,17 +123,20 @@

tick = time.time()

if hasattr(self.layer, "quantization_scheme"):
quant_scheme = self.layer.quantization_scheme
if quant_scheme.weights is not None:
# fetch latest correct scale and ZP relevant for any changes
# such as activation reordering
from compressed_tensors.quantization import (
update_layer_weight_quant_params,
)

update_layer_weight_quant_params(self.layer)

# compute quantization parameters
if weight_fake_quant is not None:
scale = weight_fake_quant.scale
zero_point = weight_fake_quant.zero_point
dtype = weight_fake_quant.dtype
tensor_scheme = weight_fake_quant.qscheme in [
torch.per_tensor_affine,
torch.per_tensor_symmetric,
]
else: # weight_quant_args is not None
observer = MemorylessObserver(weight_quant_args)
scale, zero_point = observer(W)

# mask dead hessian values
dead = torch.diag(self.H) == 0
self.H[dead, dead] = 1
W[:, dead] = 0
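The dampening controlled by `percdamp` is applied in the lines elided between these hunks. As a rough sketch of that step (the names `self.columns` and `self.dev` are assumptions, not shown in this diff), GPTQ-style dampening adds a fraction of the mean Hessian diagonal before the inverse is computed:

# illustrative sketch only, not part of this PR's changes
damp = percdamp * torch.mean(torch.diag(self.H))  # fraction of the average diagonal
diag_idx = torch.arange(self.columns, device=self.dev)
self.H[diag_idx, diag_idx] += damp  # stabilize H before inversion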
@@ -152,61 +170,46 @@
d = Hinv1[i, i]
q = w.clone()

if hasattr(self.layer, "weight_fake_quant"):
scale = self.layer.weight_fake_quant.scale
zero_point = self.layer.weight_fake_quant.zero_point
dtype = self.layer.weight_fake_quant.dtype
qscheme = self.layer.weight_fake_quant.qscheme
if qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
if weight_fake_quant is not None:
if tensor_scheme:
q = torch.quantize_per_tensor(q, scale, zero_point, dtype)
else:
q = torch.quantize_per_channel(q, scale, zero_point, 0, dtype)
q = torch.dequantize(q)
elif hasattr(self.layer, "quantization_scheme"):
quant_scheme = self.layer.quantization_scheme
if quant_scheme.weights is not None:
scale = self.layer.weight_scale
zero_point = self.layer.weight_zero_point
from compressed_tensors.quantization import QuantizationStrategy
from compressed_tensors.quantization.lifecycle.forward import (
fake_quantize,
else: # weight_quant_args is not None:
strategy = weight_quant_args.strategy
if strategy == QuantizationStrategy.TENSOR:
q = fake_quantize(
q,
scale,
zero_point,
self.layer.quantization_scheme.weights,
)
elif strategy == QuantizationStrategy.CHANNEL:
# TODO: for channelwise why isn't this just a 1d tensor?
q = fake_quantize(
q,
scale[:, 0],
zero_point[:, 0],
weight_quant_args,
)
else: # strategy == QuantizationStrategy.GROUP
# get the group index for the current column
column_idx = i1 + i
input_dim_group = column_idx // weight_quant_args.group_size

# Since we're only applying quantization to a slice, this
# ends up being a channelwise application
altered_qargs = copy(weight_quant_args)
altered_qargs.strategy = QuantizationStrategy.CHANNEL
q = fake_quantize(
q,
scale[:, input_dim_group],
zero_point[:, input_dim_group],
altered_qargs,
)

strategy = quant_scheme.weights.strategy

if strategy == QuantizationStrategy.TENSOR:
q = fake_quantize(
q,
scale,
zero_point,
self.layer.quantization_scheme.weights,
)
elif strategy == QuantizationStrategy.CHANNEL:
# TODO: for channelwise why isn't this just a 1d tensor?
q = fake_quantize(
q,
scale[:, 0],
zero_point[:, 0],
quant_scheme.weights,
)
else: # strategy == QuantizationStrategy.GROUP
# get the group index for the current column
column_idx = i1 + i
input_dim_group = (
column_idx // quant_scheme.weights.group_size
)

# Since we're only applying quantization to a slice, this
# ends up being a channelwise application
altered_qargs = copy(quant_scheme.weights)
altered_qargs.strategy = QuantizationStrategy.CHANNEL
q = fake_quantize(
q,
scale[:, input_dim_group],
zero_point[:, input_dim_group],
altered_qargs,
)

# propagate column error
Q1[:, i] = q
Losses1[:, i] = (w - q) ** 2 / d**2
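For the GROUP strategy handled above, the column-to-group mapping is a plain integer division; a small worked example (values assumed, not taken from this PR):

# illustrative only: with group_size = 128, column 300 belongs to input group 2
group_size = 128
column_idx = 300
input_dim_group = column_idx // group_size  # -> 2
# scale[:, 2] and zero_point[:, 2] are then applied to this single column, which is
# why the copied quantization args temporarily switch to QuantizationStrategy.CHANNEL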

@@ -218,6 +221,7 @@ def compress(
W1[:, i:] -= w1_err
Err1[:, i] = err1

# propagate block error
W[:, i1:i2] = Q1
Losses += torch.sum(Losses1, 1) / 2

@@ -228,48 +232,55 @@ def compress(
W[:, i2:] -= w_err

if "METRIC" in logger._core.levels.keys():
logger.log("METRIC", "time %.2f" % (time.time() - tick))
logger.log("METRIC", "error %.2f" % torch.sum(Losses).item())

gpu_usage = get_GPU_memory_usage()
if len(gpu_usage) > 0:
for i in range(len(gpu_usage)):
perc = gpu_usage[i][0] * 100
total_memory = int(gpu_usage[i][1]) # GB
logger.log(
"METRIC",
(
f"GPU {i} | usage: {perc:.2f}%"
f" | total memory: {total_memory} GB"
),
)

logger.log(
"METRIC",
f"Compressed layer size: {get_layer_size_bytes(self.layer)} MB",
)
self.log_metrics(tick, Losses)

if isinstance(self.layer, transformers.Conv1D):
W = W.t()
W.transpose_(0, 1)
W = W.reshape(final_shape).to(final_dtype)

# This is a bit hacky, but FSDP updates only work if we change the weight in
# place, clone() or direct assignment won't work
self.layer.weight -= self.layer.weight
self.layer.weight += W
update_parameter_data(self.layer, scale, "weight_scale")
update_parameter_data(self.layer, zero_point, "weight_zero_point")

if is_module_offloaded(self.layer):
device = get_offloaded_device(self.layer)
update_prefix_dict(self.layer, "weight", self.layer.weight.to(device))
self.layer._hf_hook.post_forward(self.layer, None)

del W
del Losses
del diag

def free(self):
"""
Free the Hessian memory after the layer is complete
"""
delattr(self, "H")
super().free()

def log_metrics(self, start_tick: float, losses: torch.Tensor):
"""
Log metrics related to the compression algorithm

:param start_tick: time when the algorithm started
:param losses: losses produced by the algorithm
"""
logger.log("METRIC", "time %.2f" % (time.time() - start_tick))
logger.log("METRIC", "error %.2f" % torch.sum(losses).item())

gpu_usage = get_GPU_memory_usage()
if len(gpu_usage) > 0:
for i in range(len(gpu_usage)):
perc = gpu_usage[i][0] * 100
total_memory = int(gpu_usage[i][1]) # GB
logger.log(
"METRIC",
(
f"GPU {i} | usage: {perc:.2f}%"
f" | total memory: {total_memory} GB"
),
)

logger.log(
"METRIC",
f"Compressed layer size: {get_layer_size_bytes(self.layer)} MB",
)
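The METRIC records above rely on a custom loguru level; a minimal sketch of registering it so these logs actually appear (llmcompressor may already register the level elsewhere, so this is an assumption):

# illustrative sketch, not part of this PR
from loguru import logger

logger.level("METRIC", no=25)  # register the custom level once, severity between INFO and WARNING
logger.log("METRIC", "time %.2f" % 1.23)  # emits a record like the ones logged above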
33 changes: 33 additions & 0 deletions src/llmcompressor/utils/helpers.py
@@ -58,6 +58,7 @@
"parse_kwarg_tuples",
"is_package_available",
"import_from_path",
"getattr_chain",
]


@@ -1008,3 +1009,35 @@ def import_from_path(path: str) -> str:
return getattr(module, class_name)
except AttributeError:
raise AttributeError(f"Cannot find {class_name} in {_path}")


def getattr_chain(obj: Any, chain_str: str, *args, **kwargs) -> Any:
"""
Chain multiple getattr calls, separated by `.`

:param obj: base object whose attributes are being retrieved
:param chain_str: attribute names separated by `.`
:param default: optional default value returned when an attribute in the
    chain is missing; if no default is provided, an AttributeError is raised
:return: the value at the end of the attribute chain, or the default
"""
if len(args) >= 1:
has_default = True
default = args[0]
elif "default" in kwargs:
has_default = True
default = kwargs["default"]
else:
has_default = False

attr_names = chain_str.split(".")

res = obj
for attr_name in attr_names:
if not hasattr(res, attr_name):
if has_default:
return default
else:
raise AttributeError(f"{res} object has no attribute {attr_name}")
res = getattr(res, attr_name)

return res
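A quick usage sketch of the new helper; the objects below are hypothetical stand-ins for a quantized layer:

# illustrative usage, not part of this PR
class _Weights:
    group_size = 128

class _Scheme:
    weights = _Weights()

class _Layer:
    quantization_scheme = _Scheme()

layer = _Layer()
getattr_chain(layer, "quantization_scheme.weights.group_size")  # -> 128
getattr_chain(layer, "quantization_scheme.input_activations", None)  # -> None (default)
getattr_chain(layer, "quantization_scheme.input_activations")  # raises AttributeError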