
Replace LayerCompressor with HooksMixin #1038

Merged
merged 29 commits on Feb 5, 2025

Commits (29)
de278ce
extend remove_hooks to remove subsets
kylesayrs Dec 31, 2024
2754145
change arg type
kylesayrs Jan 1, 2025
b2e98c3
implement keep argument
kylesayrs Jan 1, 2025
3ab5323
use lazy value assignment rather than container, update docstring
kylesayrs Jan 1, 2025
b605db5
make keeps composable
kylesayrs Jan 1, 2025
59bdb66
squash
kylesayrs Jan 23, 2025
669965e
Merge remote-tracking branch 'origin' into kylesayrs/remove-layer-com…
kylesayrs Jan 23, 2025
54067ab
Merge branch 'main' into kylesayrs/hooks-mixin-keep
kylesayrs Jan 24, 2025
1b11b54
Merge branch 'main' into kylesayrs/hooks-mixin-remove-subsets
kylesayrs Jan 24, 2025
e3623cc
Merge branch 'main' into kylesayrs/hooks-mixin-keep
kylesayrs Jan 24, 2025
e12d4da
Merge branch 'main' into kylesayrs/remove-layer-compressor
kylesayrs Jan 24, 2025
f4f3d26
fix tests
kylesayrs Jan 27, 2025
46cc9bc
Merge remote-tracking branch 'origin' into kylesayrs/remove-layer-com…
kylesayrs Jan 27, 2025
1eea2ab
style
kylesayrs Jan 27, 2025
0f5c8ad
Merge branch 'main' into kylesayrs/remove-layer-compressor
kylesayrs Jan 29, 2025
eb83e67
update docstring
kylesayrs Jan 29, 2025
2d6e366
Merge branch 'kylesayrs/hooks-mixin-remove-subsets', remote-tracking …
kylesayrs Jan 29, 2025
5070615
fix merge
kylesayrs Jan 29, 2025
077c68e
Merge remote-tracking branch 'origin' into kylesayrs/remove-layer-com…
kylesayrs Jan 30, 2025
922ea62
Merge remote-tracking branch 'origin' into kylesayrs/hooks-mixin-keep
kylesayrs Jan 30, 2025
ecee510
ensure the random weight is not 2:4 sparse
kylesayrs Jan 31, 2025
54fd6fb
remove leftover comment
kylesayrs Jan 31, 2025
20c6c00
add ignore
kylesayrs Feb 3, 2025
ea4f2a2
update docstring with more examples
kylesayrs Feb 3, 2025
46ae8eb
Merge branch 'main' into kylesayrs/hooks-mixin-keep
dsikka Feb 4, 2025
a2934b3
use immutable default
kylesayrs Feb 5, 2025
da9df2e
Merge branch 'kylesayrs/hooks-mixin-keep' into kylesayrs/remove-layer…
kylesayrs Feb 5, 2025
5fb18d9
Merge branch 'main' into kylesayrs/remove-layer-compressor
dsikka Feb 5, 2025
ab00e52
Merge branch 'main' into kylesayrs/remove-layer-compressor
dsikka Feb 5, 2025
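
The commit trail above lands two HooksMixin features this PR depends on: remove_hooks() can now delete a subset of handles, and disable_hooks() accepts a keep argument so a chosen subset stays active while all other hooks are silenced. Below is a minimal sketch of that usage, mirroring _get_activations in the diff further down; the StatCollector class, record callback, and toy model are illustrative only, not part of the PR.

import torch
from functools import partial

from llmcompressor.modifiers.utils.hooks import HooksMixin


class StatCollector(HooksMixin):
    """Toy HooksMixin subclass; real modifiers inherit the same API."""


collector = StatCollector()
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))


def record(module, args, name: str):
    print(f"forward_pre fired for {name}")


# register one forward-pre hook per Linear layer, collecting the handles
hooks = set(
    collector.register_hook(mod, partial(record, name=name), "forward_pre")
    for name, mod in model.named_modules()
    if isinstance(mod, torch.nn.Linear)
)

# disable all registered hooks except this subset (the new `keep` argument)
with HooksMixin.disable_hooks(keep=hooks):
    model(torch.randn(1, 8))

# remove only this subset; hooks owned by other modifiers stay registered
collector.remove_hooks(hooks)
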
389 changes: 115 additions & 274 deletions src/llmcompressor/modifiers/obcq/base.py

Large diffs are not rendered by default.

254 changes: 254 additions & 0 deletions src/llmcompressor/modifiers/obcq/sgpt_mixin.py
@@ -0,0 +1,254 @@
import warnings
from collections import defaultdict
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy
import torch
from loguru import logger
from pydantic import Field, field_validator, model_validator

from llmcompressor.core import State
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.utils.hooks import HooksMixin
from llmcompressor.pipelines.basic import run_pipeline as run_basic
from llmcompressor.pipelines.layer_sequential import (
run_pipeline as run_layer_sequential,
)
from llmcompressor.pipelines.sequential import run_pipeline as run_sequential
from llmcompressor.utils.pytorch.module import (
get_layers,
get_no_split_params,
get_prunable_layers,
)


class SparsityModifierMixin(HooksMixin):
# modifier arguments
sparsity: Optional[Union[float, List[float]]] = None
sparsity_profile: Optional[str] = None
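    # "0:0" means unstructured sparsity; "N:M" (e.g. "2:4") zeroes N of every M weights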
mask_structure: str = "0:0"
owl_m: Optional[int] = None
owl_lmbda: Optional[float] = None

# data pipeline arguments
sequential_update: Optional[bool] = False # deprecated
sequential_targets: Union[str, List[str], None] = None
    targets: Union[str, List[str], None] = None  # alias for sequential_targets
ignore: List[str] = Field(default_factory=list)

@field_validator("sequential_update", mode="before")
def validate_sequential_update(cls, value: bool) -> bool:
if not value:
warnings.warn(
"`sequential_update=False` is no longer supported, setting "
"sequential_update=True",
DeprecationWarning,
)

return True

@field_validator("sparsity_profile", mode="before")
    def validate_sparsity_profile(cls, value: Optional[str]) -> Optional[str]:
if value is None:
return value

value = value.lower()

profile_options = ["owl"]
if value not in profile_options:
raise ValueError(f"Please choose profile from {profile_options}")

return value

@model_validator(mode="after")
def validate_model_after(model: "Modifier") -> "Modifier":
sparsity = model.sparsity
profile = model.sparsity_profile
owl_m = model.owl_m
owl_lmbda = model.owl_lmbda
mask_structure = model.mask_structure
targets = model.targets
sequential_targets = model.sequential_targets

if profile == "owl" and ((owl_m is not None) ^ (owl_lmbda is not None)):
raise ValueError("Must provide both `owl_m` and `owl_lmbda` or neither")

if profile != "owl" and (owl_m is not None or owl_lmbda is not None):
raise ValueError("Must provide both `owl_m` and `owl_lmbda`")

if owl_m is not None and sparsity is not None:
raise ValueError("Cannot provide both sparsity and owl parameters")

if targets is not None:
if sequential_targets is not None:
raise ValueError("Cannot use both `targets` and `sequential_targets`")
model.sequential_targets = targets
model.targets = None

model._prune_n, model._prune_m = model._split_mask_structure(mask_structure)

return model

def on_initialize(self, state: "State", **kwargs) -> bool:
"""
Initialize and run the OBCQ algorithm on the current state

:param state: session state storing input model and calibration data
"""
model = state.model
dataloader = state.data.calib

# infer module and sequential targets
self.sequential_targets = self._infer_sequential_targets(model)

# infer layer sparsities
if self.sparsity_profile == "owl":
logger.info(
"Using OWL to infer target layer-wise sparsities from "
f"{len(dataloader) if dataloader else 0} calibration samples..."
)
            activations = self._get_activations(model, dataloader)
            self.sparsity = self._infer_owl_layer_sparsity(model, activations)

# get layers and validate sparsity
layers = get_layers(self.sequential_targets, model)
if isinstance(self.sparsity, (list, dict)) and len(layers) != len(
self.sparsity
):
raise ValueError(
f"{self.__repr_name__} was initialized with {len(self.sparsity)} "
f"sparsities values, but model only has {len(layers)} layers"
)

# register hooks
        for index, (layer_name, layer) in enumerate(layers.items()):
            if isinstance(self.sparsity, dict):
                layer_sparsity = self.sparsity[layer_name]
elif isinstance(self.sparsity, list):
layer_sparsity = self.sparsity[index]
else:
layer_sparsity = self.sparsity

for name, module in get_prunable_layers(layer).items():
self._module_names[module] = name
self._module_sparsities[module] = layer_sparsity
self.register_hook(module, self.calibrate_module, "forward")

# infer and run pipeline
model_name = state.model.__class__.__name__
input_names = dataloader.dataset.column_names
unfixable_errors = (torch.OutOfMemoryError, torch._C._LinAlgError)
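        # preferred pipeline order: sequential -> layer_sequential -> basic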
try:
run_sequential(
state.model,
state.data.calib,
self.sequential_targets,
self.ignore,
self,
)
return True

except Exception as exception:
if isinstance(exception, torch.fx.proxy.TraceError):
warnings.warn(f"Failed to trace {model_name} with inputs {input_names}")
if isinstance(exception, unfixable_errors):
raise exception

warnings.warn("Falling back to layer_sequential pipeline")
try:
run_layer_sequential(
state.model,
state.data.calib,
self.sequential_targets,
self,
)
return True

except Exception as exception:
if isinstance(exception, TypeError):
warnings.warn(f"{model_name} fails layer-wise assumptions")
if isinstance(exception, unfixable_errors):
raise exception

warnings.warn(
"Falling back to basic pipeline, which requires extra memory and "
"may result in decreased accuracy"
)
run_basic(state.model, state.data.calib, self)
                return True

def _infer_sequential_targets(
self, model: torch.nn.Module
) -> Union[str, List[str]]:
if self.sequential_targets is None:
return get_no_split_params(model)
if isinstance(self.sequential_targets, str):
return [self.sequential_targets]
return self.sequential_targets

    def _infer_owl_layer_sparsity(
        self, model: torch.nn.Module, activations: Dict[str, torch.Tensor]
    ) -> Dict[str, float]:
        layers = get_layers(self.sequential_targets, model)
        groups = {}
        for name, layer in layers.items():
prunable_layers = get_prunable_layers(layer)
z = [
m.weight.abs() * activations[f"{name}.{n}"].unsqueeze(0)
for n, m in prunable_layers.items()
]
groups[name] = torch.cat([item.flatten().cpu() for item in z])

del activations

outlier_ratios = {}
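        # outlier ratio: percentage of |weight| * activation scores that exceed
        # `owl_m` times the layer-group mean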
for group in groups:
threshold = torch.mean(groups[group]) * self.owl_m
outlier_ratios[group] = (
100 * (groups[group] > threshold).sum().item() / groups[group].numel()
)
outlier_ratios_arr = numpy.array([outlier_ratios[k] for k in outlier_ratios])
for k in outlier_ratios:
outlier_ratios[k] = (outlier_ratios[k] - outlier_ratios_arr.min()) * (
1
/ (outlier_ratios_arr.max() - outlier_ratios_arr.min())
* self.owl_lmbda
* 2
)
outlier_ratios_arr = numpy.array([outlier_ratios[k] for k in outlier_ratios])
sparsities = {
k: 1
- (
outlier_ratios[k]
- numpy.mean(outlier_ratios_arr)
+ (1 - float(self.sparsity))
)
for k in outlier_ratios
}
logger.info(f"OWL sparsities for sp={self.sparsity} are:")
for k in sparsities:
logger.info(f"Sparsity for {k}: {sparsities[k]}")
return sparsities

    def _get_activations(
        self, model, dataloader, nsamples=128
    ) -> Dict[str, torch.Tensor]:
acts = defaultdict(int)

def save_acts(_module, input: Union[Tuple[Any, ...], torch.Tensor], name: str):
nonlocal acts
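            # accumulate a per-channel input statistic (scaled l2 norm),
            # averaged over `nsamples` calibration samples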
if isinstance(input, tuple):
input = input[0]
acts[name] += 1.0 / nsamples * input.pow(2).sum(dim=(0, 1)).sqrt()

hooks = set(
self.register_hook(mod, partial(save_acts, name=name), "forward_pre")
for name, mod in model.named_modules()
if isinstance(mod, torch.nn.Linear) and "lm_head" not in name
)
with HooksMixin.disable_hooks(keep=hooks):
run_basic(model, dataloader)
self.remove_hooks(hooks)

return acts

def _split_mask_structure(self, mask_structure: str) -> Tuple[int, int]:
n, m = mask_structure.split(":")
return int(n), int(m)