Skip to content

Add a generic wrap_hf_model_class utility to support VLMs #185

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions examples/quantization_w8a8_fp8/llama3.2_vision_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from transformers import AutoProcessor, MllamaForConditionalGeneration

from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import oneshot, wrap_hf_model_class

MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct"

# Load model.
model_class = wrap_hf_model_class(MllamaForConditionalGeneration)
model = model_class.from_pretrained(MODEL_ID, device_map="auto", torch_dtype="auto")
processor = AutoProcessor.from_pretrained(MODEL_ID)

# Configure the quantization algorithm and scheme.
# In this case, we:
# * quantize the weights to fp8 with per channel via ptq
# * quantize the activations to fp8 with dynamic per token
recipe = QuantizationModifier(
targets="Linear",
scheme="FP8_DYNAMIC",
ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"],
)

# Apply quantization and save to disk in compressed-tensors format.
SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-Dynamic"
oneshot(model=model, recipe=recipe, output_dir=SAVE_DIR)
processor.save_pretrained(SAVE_DIR)

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=20)
print(processor.decode(output[0]))
print("==========================================")
33 changes: 33 additions & 0 deletions examples/quantization_w8a8_fp8/llava1.5_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from transformers import AutoProcessor, LlavaForConditionalGeneration

from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import oneshot, wrap_hf_model_class

MODEL_ID = "llava-hf/llava-1.5-7b-hf"

# Load model.
model_class = wrap_hf_model_class(LlavaForConditionalGeneration)
model = model_class.from_pretrained(MODEL_ID, device_map="auto", torch_dtype="auto")
processor = AutoProcessor.from_pretrained(MODEL_ID)

# Configure the quantization algorithm and scheme.
# In this case, we:
# * quantize the weights to fp8 with per channel via ptq
# * quantize the activations to fp8 with dynamic per token
recipe = QuantizationModifier(
targets="Linear",
scheme="FP8_DYNAMIC",
ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_tower.*"],
)

# Apply quantization and save to disk in compressed-tensors format.
SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-Dynamic"
oneshot(model=model, recipe=recipe, output_dir=SAVE_DIR)
processor.save_pretrained(SAVE_DIR)

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=20)
print(processor.decode(output[0]))
print("==========================================")
33 changes: 33 additions & 0 deletions examples/quantization_w8a8_fp8/qwen2vl_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import oneshot, wrap_hf_model_class

MODEL_ID = "Qwen/Qwen2-VL-7B-Instruct"

# Load model.
model_class = wrap_hf_model_class(Qwen2VLForConditionalGeneration)
model = model_class.from_pretrained(MODEL_ID, device_map="auto", torch_dtype="auto")
processor = AutoProcessor.from_pretrained(MODEL_ID)

# Configure the quantization algorithm and scheme.
# In this case, we:
# * quantize the weights to fp8 with per channel via ptq
# * quantize the activations to fp8 with dynamic per token
recipe = QuantizationModifier(
targets="Linear",
scheme="FP8_DYNAMIC",
ignore=["re:.*lm_head", "re:visual.*"],
)

# Apply quantization and save to disk in compressed-tensors format.
SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-Dynamic"
oneshot(model=model, recipe=recipe, output_dir=SAVE_DIR)
processor.save_pretrained(SAVE_DIR)

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=20)
print(processor.decode(output[0]))
print("==========================================")
2 changes: 1 addition & 1 deletion src/llmcompressor/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@
# isort: skip_file
# (import order matters for circular import avoidance)
from .utils import *
from .sparsification import SparseAutoModel, SparseAutoModelForCausalLM
from .sparsification import SparseAutoModel, SparseAutoModelForCausalLM, wrap_hf_model_class
from .finetune import *
55 changes: 35 additions & 20 deletions src/llmcompressor/transformers/sparsification/sparse_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,26 @@
resolve_recipe,
)

__all__ = ["SparseAutoModel", "SparseAutoModelForCausalLM", "get_shared_tokenizer_src"]
__all__ = [
"wrap_hf_model_class",
"SparseAutoModel",
"SparseAutoModelForCausalLM",
"get_shared_tokenizer_src"
]


class SparseAutoModelForCausalLM(AutoModelForCausalLM):
def wrap_hf_model_class(hf_model_class: PreTrainedModel) -> PreTrainedModel:
"""
LLM Compressor wrapper for the AutoModelForCausalLM class
Its lifecycle is defined as follows:
1. If pretrained_model_name_or_path is a HuggingFace stub
the appropriate HuggingFace model will be downloaded
(if required) and the path to the deployment directory
of the model will be retrieved
2. The original model definition will be loaded, without
the model weights
3. The appropriate recipe will be applied to the model
if requested or required
4. The appropriate set of weights will be loaded into the model
Wrap a HF PreTrainedModel class to
1. Decompress a compressed model
2. Initialize any saved recipes
3. Wrap the `save_pretrained` method to allow saving as a compressed model

:param hf_model_class: Model class to wrap
:return: Wrapped model class
"""

# Add the from_pretrained class method
@classmethod
def from_pretrained(
cls,
Expand All @@ -51,15 +53,14 @@ def from_pretrained(
**kwargs,
) -> PreTrainedModel:
"""
A wrapper around the AutoModelForCausalLM.from_pretrained method
A wrapper around the PreTrainedModel.from_pretrained method

:param pretrained_model_name_or_path: the name of or path to the model to load
:param recipe: the path to the recipe file to apply to the model. Can be a
string or Path object. If None, a recipe will be searched for in the
pretrained_model_name_or_path directory and applied if found
:return the created model for causal language modeling
"""

def skip(*args, **kwargs):
pass

Expand Down Expand Up @@ -91,15 +92,15 @@ def skip(*args, **kwargs):
transformers_logger.setLevel(level=logging.ERROR)

if kwargs.get("trust_remote_code"):
# By artifically aliasing
# class name SparseAutoModelForCausallLM to
# AutoModelForCausalLM we can "trick" the
# By artifically aliasing the
# class name to the
# hf_model_class we can "trick" the
# `from_pretrained` method into properly
# resolving the logic when
# (has_remote_code and trust_remote_code) == True
cls.__name__ = AutoModelForCausalLM.__name__
cls.__name__ = hf_model_class.__name__

model = super(AutoModelForCausalLM, cls).from_pretrained(
model = super(hf_model_class, cls).from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
)

Expand Down Expand Up @@ -153,6 +154,20 @@ def skip(*args, **kwargs):

return model

# Add the wrapped methods to the new class
wrapped_model_class = type(
hf_model_class.__name__,
(hf_model_class,),
{
"from_pretrained": from_pretrained
}
)

return wrapped_model_class


SparseAutoModelForCausalLM = wrap_hf_model_class(AutoModelForCausalLM)


class SparseAutoModel:
"""
Expand Down
Loading