Skip to content

Auto-Infer mappings Argument for SmoothQuantModifier Based on Model Architecture #119

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions src/llmcompressor/modifiers/smoothquant/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@

from llmcompressor.core import State
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.smoothquant.utils import (
get_layer_mappings_from_architecture,
handle_mapping_resolution_errors,
)
from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward
from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
from llmcompressor.utils.pytorch.module import get_layers, get_matching_layer

MINIMUM_SMOOTHING_SCALE = 1e-5

DEFAULT_SMOOTHQUANT_MAPPINGS = [
[["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*input_layernorm"],
[["re:.*gate_proj", "re:.*up_proj"], "re:.*post_attention_layernorm"],
]

__all__ = ["SmoothQuantScale", "SmoothQuantMapping", "SmoothQuantModifier"]

Expand Down Expand Up @@ -81,8 +81,9 @@ class SmoothQuantModifier(Modifier):
Each entry of the mapping list should be a list itself, in which the first
entry is a list of layers who share the same input activation (the one to be
smoothed) and the second entry is the layer whose output is scaled to
achieve the smoothing.
If regex is used, it matches layers with the largest overlap in module name.
achieve the smoothing. If regex is used, it matches layers with the largest
overlap in module name. If not supplied the argument will be inferred from the
model architecture.
:param ignore: list of layers to ignore, even if they match a regex in mappings.
It should match the name of layers whose outputs are scaled to achieve
smoothing (the second entry of the mappings list).
Expand All @@ -93,7 +94,7 @@ class SmoothQuantModifier(Modifier):
"""

smoothing_strength: float = 0.5
mappings: List[Tuple] = DEFAULT_SMOOTHQUANT_MAPPINGS
mappings: Optional[List[Tuple]] = None
ignore: Optional[List[str]] = None
num_calibration_steps: Optional[int] = None
calibration_function: Optional[Callable] = None
Expand Down Expand Up @@ -121,6 +122,7 @@ def on_initialize(self, state: State, **kwargs) -> bool:
)

self.ignore = [] if not self.ignore else self.ignore
self.mappings = self._infer_mappings_from_model(state.model)
self.resolved_mappings_ = self._resolve_mappings(state.model)
self.scales_ = {}

Expand All @@ -147,6 +149,19 @@ def on_finalize(self, state: State, **kwargs) -> bool:

return True

def _infer_mappings_from_model(
self,
model: Module,
) -> List[Tuple]:
if self.mappings is not None:
return self.mappings

logger.info("No SmoothQuantModifier.mappings provided, inferring from model...")
return get_layer_mappings_from_architecture(
architecture=model.__class__.__name__
)

@handle_mapping_resolution_errors
def _resolve_mappings(self, model: Module) -> List:
"""
Transforms the list of activations to smooth and their corresponding weights
Expand Down
80 changes: 80 additions & 0 deletions src/llmcompressor/modifiers/smoothquant/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import functools
import pathlib
from collections import namedtuple
from typing import Dict, List, NamedTuple, Tuple, Union

from loguru import logger

# Names exported when this module is star-imported
__all__ = [
    "get_layer_mappings_from_architecture",
    "MAPPINGS_REGISTRY",
    "DEFAULT_SMOOTHQUANT_MAPPINGS",
]

# Structural alias for one mapping entry: each side is either a single layer
# name pattern or a list of them
LayerMapType = Tuple[Union[List[str], str], Union[List[str], str]]


class LayerMap(NamedTuple):
    """
    One mapping entry. Behaves as a plain 2-tuple, so it can be unpacked as
    ``balance_layers, smooth_layers`` and compared against ordinary tuples.

    :param balance_layers: name pattern(s) of the layers on the balance side of
        the mapping
    :param smooth_layers: name pattern(s) of the layer on the smooth side of the
        mapping
    """

    balance_layers: Union[List[str], str]
    smooth_layers: Union[List[str], str]


# NOTE(review): the "re:" prefix presumably marks the pattern as a regex for the
# layer matcher - confirm against the matching helper
DEFAULT_SMOOTHQUANT_MAPPINGS: List[LayerMap] = [
    LayerMap(
        balance_layers=["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"],
        smooth_layers="re:.*input_layernorm",
    ),
    LayerMap(
        balance_layers=["re:.*gate_proj", "re:.*up_proj"],
        smooth_layers="re:.*post_attention_layernorm",
    ),
]
# Same attention entry as the default; the second entry balances "gate" only
MIXTRAL_MAPPINGS: List[LayerMap] = [
    LayerMap(
        balance_layers=["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"],
        smooth_layers="re:.*input_layernorm",
    ),
    LayerMap(
        balance_layers=["re:.*gate"], smooth_layers="re:.*post_attention_layernorm"
    ),
]


# Registry of layer mappings for different architectures
# Add more mappings here
MAPPINGS_REGISTRY: Dict[str, List[LayerMap]] = {
    "LlamaForCausalLM": DEFAULT_SMOOTHQUANT_MAPPINGS,
    "MixtralForCausalLM": MIXTRAL_MAPPINGS,
    "MistralForCausalLM": DEFAULT_SMOOTHQUANT_MAPPINGS,
    "Qwen2ForCausalLM": DEFAULT_SMOOTHQUANT_MAPPINGS,
}


def get_layer_mappings_from_architecture(architecture: str) -> List[LayerMap]:
    """
    Look up the layer mappings registered for an architecture, falling back to
    the default mappings (and logging that fallback) when none are registered

    :param architecture: str: The architecture of the model
    :return: list: The layer mappings for the given architecture
    """
    # Globals are read at call time so patched registries are honoured
    if architecture in MAPPINGS_REGISTRY:
        return MAPPINGS_REGISTRY[architecture]

    logger.info(
        f"Architecture {architecture} not found in mappings. "
        f"Using default mappings: {DEFAULT_SMOOTHQUANT_MAPPINGS}"
    )
    return DEFAULT_SMOOTHQUANT_MAPPINGS


def handle_mapping_resolution_errors(func):
    """
    Decorator to catch any errors that occur when resolving mappings and provide a
    helpful error message to the user pointing them to the README

    :param func: callable to wrap
    :return: wrapper that forwards all arguments to func and returns its result
        unchanged when func succeeds
    :raises RuntimeError: raised by the wrapper, chained to the original
        exception, whenever func raises an Exception
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as original_exception:
            # Path is built next to this module; its existence is not checked
            readme_location = pathlib.Path(__file__).parent / "README.md"
            # Trailing space keeps the two sentences from running together
            raise RuntimeError(
                "Error resolving mappings for given architecture. "
                f"Please refer to the README at {readme_location} for more information."
            ) from original_exception

    return wrapper
6 changes: 1 addition & 5 deletions tests/llmcompressor/modifiers/smoothquant/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
import pytest

from llmcompressor.modifiers.factory import ModifierFactory
from llmcompressor.modifiers.smoothquant.base import (
DEFAULT_SMOOTHQUANT_MAPPINGS,
SmoothQuantModifier,
)
from llmcompressor.modifiers.smoothquant.base import SmoothQuantModifier
from tests.llmcompressor.modifiers.conf import setup_modifier_factory


Expand Down Expand Up @@ -45,7 +42,6 @@ def setUp(self):
def test_defaults(self):
default_sq = SmoothQuantModifier()
assert default_sq.smoothing_strength == 0.5
assert default_sq.mappings == DEFAULT_SMOOTHQUANT_MAPPINGS

def test_override_defaults(self):
strength = 0.7
Expand Down
39 changes: 39 additions & 0 deletions tests/llmcompressor/modifiers/smoothquant/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from unittest.mock import patch

import pytest

from llmcompressor.modifiers.smoothquant.utils import (
get_layer_mappings_from_architecture,
handle_mapping_resolution_errors,
)

smoothquant_utils = "llmcompressor.modifiers.smoothquant.utils"


@pytest.mark.unit
def test_handle_mapping_resolution_errors():
    """
    A failure inside a decorated function surfaces as a RuntimeError whose
    message points the user at the smoothquant README
    """
    expected_readme_suffix = "llmcompressor/modifiers/smoothquant/README.md"

    @handle_mapping_resolution_errors
    def always_fails():
        raise ValueError("An error occurred")

    with pytest.raises(RuntimeError) as raised:
        always_fails()

    message = str(raised.value)
    for fragment in (
        "Error resolving mappings for given architecture.",
        "Please refer to the README at",
        expected_readme_suffix,
    ):
        assert fragment in message


@pytest.mark.unit
def test_get_layer_mappings_from_architecture():
    """
    Registered architectures resolve to their own mapping; anything else
    resolves to the default mapping
    """
    fake_registry = {"arch1": "mapping1", "arch2": "mapping2"}

    with patch(f"{smoothquant_utils}.MAPPINGS_REGISTRY", fake_registry), patch(
        f"{smoothquant_utils}.DEFAULT_SMOOTHQUANT_MAPPINGS", "default_mapping"
    ):
        # Architecture present in the registry
        registered = get_layer_mappings_from_architecture("arch1")
        assert registered == "mapping1"

        # Architecture absent from the registry
        fallback = get_layer_mappings_from_architecture("arch3")
        assert fallback == "default_mapping"
Loading