Implement Missing Modifier Methods #166

Merged
16 commits, merged Sep 30, 2024
19 changes: 11 additions & 8 deletions src/llmcompressor/modifiers/modifier.py
@@ -1,3 +1,4 @@
from abc import ABC, abstractmethod
from typing import Optional

from pydantic import BaseModel
@@ -9,7 +10,7 @@
__all__ = ["Modifier"]


class Modifier(BaseModel, ModifierInterface):
class Modifier(BaseModel, ModifierInterface, ABC):
"""
A base class for all modifiers to inherit from.
Modifiers are used to modify the training process for a model.
@@ -224,15 +225,17 @@ def should_end(self, event: Event):
def on_initialize_structure(self, state: State, **kwargs):
"""
on_initialize_structure is called before the model is initialized
with the modifier structure. Must be implemented by the inheriting
modifier.
with the modifier structure.

TODO: Deprecate this function as part of the lifecycle

:param state: The current state of the model
:param kwargs: Additional arguments for initializing the structure
of the model in question
"""
raise NotImplementedError()
pass

@abstractmethod
def on_initialize(self, state: State, **kwargs) -> bool:
"""
on_initialize is called on modifier initialization and
@@ -255,7 +258,7 @@ def on_finalize(self, state: State, **kwargs) -> bool:
:return: True if the modifier was finalized successfully,
False otherwise
"""
raise NotImplementedError()
return True

def on_start(self, state: State, event: Event, **kwargs):
"""
@@ -266,7 +269,7 @@ def on_start(self, state: State, event: Event, **kwargs):
:param event: The event that triggered the start
:param kwargs: Additional arguments for starting the modifier
"""
raise NotImplementedError()
pass

def on_update(self, state: State, event: Event, **kwargs):
"""
@@ -278,7 +281,7 @@ def on_update(self, state: State, event: Event, **kwargs):
:param event: The event that triggered the update
:param kwargs: Additional arguments for updating the model
"""
raise NotImplementedError()
pass

def on_end(self, state: State, event: Event, **kwargs):
"""
@@ -289,7 +292,7 @@ def on_end(self, state: State, event: Event, **kwargs):
:param event: The event that triggered the end
:param kwargs: Additional arguments for ending the modifier
"""
raise NotImplementedError()
pass

def on_event(self, state: State, event: Event, **kwargs):
"""
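
With this change, on_initialize is the only abstract hook on the base class; on_finalize defaults to returning True, and on_start/on_update/on_end/on_event default to no-ops. The snippet below is a minimal illustrative sketch (ExampleModifier is a hypothetical name, not part of this PR) of what a subclass now has to provide; it assumes the imports already used elsewhere in this diff (Modifier from llmcompressor.modifiers, State and Event from llmcompressor.core).

from llmcompressor.core import Event, State
from llmcompressor.modifiers import Modifier


class ExampleModifier(Modifier):
    # Required: on_initialize is the only @abstractmethod on the new base class.
    def on_initialize(self, state: State, **kwargs) -> bool:
        # Inspect or prepare state.model here; return True on success.
        return True

    # Optional: override only the lifecycle hooks the modifier actually uses;
    # the rest fall back to the base-class defaults added in this PR.
    def on_update(self, state: State, event: Event, **kwargs):
        pass
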
26 changes: 1 addition & 25 deletions src/llmcompressor/modifiers/obcq/base.py
@@ -6,7 +6,7 @@
from torch.nn import Module
from tqdm import tqdm

from llmcompressor.core.state import State
from llmcompressor.core import State
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.obcq.utils.sgpt_wrapper import SparseGptWrapper
from llmcompressor.modifiers.utils.layer_compressor import LayerCompressor
@@ -83,25 +83,12 @@ class SparseGPTModifier(Modifier):
prunem_: Optional[int] = None
compressible_layers_: Optional[List] = None

def on_initialize_structure(self, state: State, **kwargs):
"""
Initialize the structure of the model for compression.
This modifier does not modify the model structure, so this method
is a no-op.

:param state: session state storing input model and calibration data
"""
return True

def on_initialize(self, state: "State", **kwargs) -> bool:
"""
Initialize and run the OBCQ algorithm on the current state

:param state: session state storing input model and calibration data
"""
if not self.initialized_structure_:
self.on_initialize_structure(state, **kwargs)

if self.sparsity == 0.0:
raise ValueError(
"To use the SparseGPTModifier, target sparsity must be > 0.0"
@@ -121,17 +108,6 @@ def on_initialize(self, state: "State", **kwargs) -> bool:

return True

def on_finalize(self, state: State, **kwargs):
"""
Nothing to do on finalize, on this level.
Quantization Modifier if any will be finalized in the subclass

:param state: session state storing input model and calibration data
:param kwargs: additional arguments
:return: True
"""
return True

def initialize_compression(
self,
model: Module,
3 changes: 0 additions & 3 deletions src/llmcompressor/modifiers/pruning/constant/base.py
@@ -20,9 +20,6 @@ class ConstantPruningModifier(Modifier, LayerParamMasking):
_save_masks: bool = False
_use_hooks: bool = False

def on_initialize_structure(self, state: State, **kwargs):
pass # nothing needed for this modifier

def on_initialize(self, state: State, **kwargs) -> bool:
if "save_masks" in kwargs:
self._save_masks = kwargs["save_masks"]
3 changes: 0 additions & 3 deletions src/llmcompressor/modifiers/pruning/magnitude/base.py
@@ -35,9 +35,6 @@ class MagnitudePruningModifier(Modifier, LayerParamMasking):
mask_creator_function_: MaskCreatorType = None
current_sparsity_: float = None

def on_initialize_structure(self, state: State, **kwargs):
pass # nothing needed for this modifier

def on_initialize(self, state: State, **kwargs) -> bool:
if self.apply_globally:
raise NotImplementedError("global pruning not implemented yet for PyTorch")
21 changes: 1 addition & 20 deletions src/llmcompressor/modifiers/pruning/wanda/base.py
@@ -6,7 +6,7 @@
from torch.nn import Module
from tqdm import tqdm

from llmcompressor.core.state import State
from llmcompressor.core import State
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.pruning.wanda.utils.wanda_wrapper import WandaWrapper
from llmcompressor.modifiers.utils.layer_compressor import LayerCompressor
@@ -61,15 +61,6 @@ class WandaPruningModifier(Modifier):
prunen_: Optional[int] = None
prunem_: Optional[int] = None

def on_initialize_structure(self, state: State, **kwargs):
"""
This modifier does not alter the model structure.
This method is a no-op.

:param state: Unused, kept to conform to the parent method signature
:param kwargs: Unused, kept to conform to the parent method signature
"""

def on_initialize(self, state: State, **kwargs) -> bool:
"""
Initialize and run the WANDA algorithm on the current state
@@ -91,16 +82,6 @@ def on_initialize(self, state: State, **kwargs) -> bool:

return True

def on_finalize(self, state: State, **kwargs):
"""
Nothing to clean up for this module

:param state: Unused, kept to conform to the parent method signature
:param kwargs: Unused, kept to conform to the parent method signature
"""

return True

def compressible_layers(self) -> Dict:
"""
Retrieves the modules corresponding to a list of
4 changes: 3 additions & 1 deletion src/llmcompressor/modifiers/quantization/gptq/base.py
@@ -12,7 +12,7 @@
from pydantic import Field, field_validator
from torch.nn import Module

from llmcompressor.core.state import State
from llmcompressor.core import State
from llmcompressor.modifiers import Modifier, ModifierFactory
from llmcompressor.modifiers.quantization.gptq.utils import (
GPTQWrapper,
@@ -130,6 +130,8 @@ def on_initialize_structure(self, state: State, **kwargs):
Check the model's quantization state matches that expected by this modifier,
adding a default quantization scheme if needed

TODO: Deprecate and fold into `on_initialize`

:param state: session state storing input model and calibration data
"""
quantization_already_active = qat_active(state.model)
9 changes: 0 additions & 9 deletions src/llmcompressor/modifiers/quantization/quantization/base.py
@@ -72,9 +72,6 @@ class QuantizationModifier(Modifier):
calibration_dataloader_: Any = None
calibration_function_: Any = None

def on_initialize_structure(self, state: State, **kwargs):
pass

def on_initialize(self, state: State, **kwargs) -> bool:
if self.end and self.end != -1:
raise ValueError(
@@ -99,9 +96,6 @@ def on_initialize(self, state: State, **kwargs) -> bool:

return True

def on_finalize(self, state: State, **kwargs) -> bool:
return True

def on_start(self, state: State, event: Event, **kwargs):
module = state.model
module.apply(set_module_for_calibration)
@@ -116,9 +110,6 @@ def on_end(self, state: State, event: Event, **kwargs):
module = state.model
module.apply(freeze_module_quantization)

def on_event(self, state: State, event: Event, **kwargs):
pass

def create_init_config(self) -> QuantizationConfig:
if self.targets is not None and isinstance(self.targets, str):
self.targets = [self.targets]
17 changes: 1 addition & 16 deletions src/llmcompressor/modifiers/smoothquant/base.py
@@ -5,7 +5,7 @@
from loguru import logger
from torch.nn import Module

from llmcompressor.core import Event, State
from llmcompressor.core import State
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward
from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
@@ -102,9 +102,6 @@ class SmoothQuantModifier(Modifier):
resolved_mappings_: Optional[List] = None
scales_: Optional[Dict] = None

def on_initialize_structure(self, state: State, **kwargs):
pass # nothing needed for this modifier

def on_initialize(self, state: State, **kwargs) -> bool:
"""
Initialize and run SmoothQuant on the given state
@@ -136,18 +133,6 @@ def on_initialize(self, state: State, **kwargs) -> bool:

return True

def on_start(self, state: State, event: Event, **kwargs):
pass

def on_update(self, state: State, event: Event, **kwargs):
pass

def on_end(self, state: State, event: Event, **kwargs):
pass

def on_event(self, state: State, event: Event, **kwargs):
pass

def on_finalize(self, state: State, **kwargs) -> bool:
"""
Clean up by clearing the scale and mapping data
6 changes: 4 additions & 2 deletions tests/llmcompressor/recipe/test_recipe.py
@@ -101,11 +101,13 @@ def test_recipe_can_be_created_from_modifier_instances():


class A_FirstDummyModifier(Modifier):
pass
def on_initialize(self, *args, **kwargs) -> bool:
return True


class B_SecondDummyModifier(Modifier):
pass
def on_initialize(self, *args, **kwargs) -> bool:
return True


def test_create_recipe_string_from_modifiers_with_default_group_name():
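
The dummy modifiers in this test now implement on_initialize because the base class declares it abstract; since pydantic's model metaclass derives from ABCMeta, a subclass that omits the override should fail to instantiate. A rough sketch of that behavior follows (the class names are illustrative, not from the test suite, and it assumes Modifier's fields all have defaults so it can be constructed without arguments):

from llmcompressor.modifiers import Modifier


class IncompleteDummyModifier(Modifier):
    pass  # no on_initialize override


class CompleteDummyModifier(Modifier):
    def on_initialize(self, *args, **kwargs) -> bool:
        return True


try:
    IncompleteDummyModifier()
except TypeError as err:
    # Expected along the lines of: can't instantiate abstract class with abstract method on_initialize
    print(f"blocked: {err}")

CompleteDummyModifier()  # constructs normally once on_initialize is provided
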