
Add FP8 Support #4

Merged: 2 commits merged on Jun 25, 2024
1 change: 1 addition & 0 deletions .gitignore
@@ -801,3 +801,4 @@ nm_temp_test_logs/*
sparse_logs/*
wandb/
output_finetune/
env_log.json
49 changes: 49 additions & 0 deletions examples/quantization/llama7b_fp8_quantization.py
@@ -0,0 +1,49 @@
import torch
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationScheme,
    QuantizationType,
)
from datasets import load_dataset
from transformers import AutoTokenizer

from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot

model_stub = "meta-llama/Meta-Llama-3-8B-Instruct"
output_dir = "Meta-Llama-3-8B-Instruct-FP8-Compressed"
num_calibration_samples = 512

tokenizer = AutoTokenizer.from_pretrained(model_stub, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token


def preprocess(batch):
    text = tokenizer.apply_chat_template(batch["messages"], tokenize=False)
    tokenized = tokenizer(text, padding=True, truncation=True, max_length=2048)
    return tokenized


ds = load_dataset("mgoin/ultrachat_2k", split="train_sft")
examples = ds.map(preprocess, remove_columns=ds.column_names)

# FP8 scheme: QuantizationType.FLOAT with the compressed-tensors default of
# num_bits=8, applied to both the weights and input activations of every
# Linear layer, with the lm_head left unquantized
quant_args = QuantizationArgs(type=QuantizationType.FLOAT)
quant_scheme = QuantizationScheme(
    weights=quant_args, input_activations=quant_args, targets=["Linear"]
)
recipe = QuantizationModifier(
    config_groups={"group_0": quant_scheme}, ignore=["lm_head"]
)

model = SparseAutoModelForCausalLM.from_pretrained(
    model_stub, torch_dtype=torch.bfloat16, device_map="auto"
)

# Calibrate in one shot and save the model with compressed FP8 weights
oneshot(
    model=model,
    dataset=examples,
    recipe=recipe,
    output_dir=output_dir,
    num_calibration_samples=num_calibration_samples,
    save_compressed=True,
)
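
For reference, the compressed FP8 checkpoint that this example writes to output_dir can be reloaded with the same SparseAutoModelForCausalLM API used above; a minimal sketch under that assumption (not part of the PR diff):

from llmcompressor.transformers import SparseAutoModelForCausalLM

# Reload the FP8-compressed checkpoint produced by the example script
compressed_model = SparseAutoModelForCausalLM.from_pretrained(
    "Meta-Llama-3-8B-Instruct-FP8-Compressed", device_map="auto"
)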
20 changes: 12 additions & 8 deletions src/llmcompressor/transformers/compression/quantization_format.py
@@ -45,9 +45,11 @@ def infer_quantization_format(
        return quantization_format

    if save_compressed:
-        quant_depths = _get_quant_depths(model)
-        if quant_depths == [4]:  # save packed if everything is int4
+        quant_types = _get_quant_types(model)
+        if quant_types == ["int4"]:  # save packed if everything is int4
            return CompressionFormat.pack_quantized
+        elif quant_types == ["float8"]:
+            return CompressionFormat.float_quantized

        # otherwise just quantize to int8
        return CompressionFormat.int_quantized
@@ -56,17 +58,19 @@ def infer_quantization_format(
    return None


-def _get_quant_depths(model):
+def _get_quant_types(model):
    """
-    Gets a list of all the quantized bit depths present in model
+    Gets a list of all the quantized types present in model
    """
-    quant_depths = []
+    quant_info = []
    for _, submodule in iter_named_leaf_modules(model):
        if is_module_quantized(submodule):
            weight_scheme = submodule.quantization_scheme.weights
            if weight_scheme is not None:
                weight_bit_depth = weight_scheme.num_bits
-                if weight_bit_depth not in quant_depths:
-                    quant_depths.append(weight_bit_depth)
+                weight_type = weight_scheme.type
+                weight_info = f"{weight_type}{weight_bit_depth}"
+                if weight_info not in quant_info:
+                    quant_info.append(weight_info)

-    return quant_depths
+    return quant_info
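
The helper rename above switches the compression-format decision from bit depths alone to "<type><num_bits>" strings such as "int4" or "float8". A standalone sketch of that selection logic, assuming CompressionFormat is imported from compressed-tensors as in the module itself (pick_format is illustrative, not part of the PR):

from compressed_tensors import CompressionFormat


def pick_format(quant_types):
    # Mirrors the save_compressed branch of infer_quantization_format above
    if quant_types == ["int4"]:
        return CompressionFormat.pack_quantized
    if quant_types == ["float8"]:
        return CompressionFormat.float_quantized
    return CompressionFormat.int_quantized


assert pick_format(["float8"]) == CompressionFormat.float_quantized
assert pick_format(["int4"]) == CompressionFormat.pack_quantized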
@@ -0,0 +1,5 @@
cadence: "nightly"
test_type: "regression"
model_stub: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
new_recipe: "tests/llmcompressor/transformers/compression/recipes/new_quant_fp8.yaml"
ppl_threshold: 20
@@ -0,0 +1,5 @@
cadence: "commit"
test_type: "regression"
model_stub: "Xenova/llama2.c-stories15M"
new_recipe: "tests/llmcompressor/transformers/compression/recipes/new_quant_fp8.yaml"
ppl_threshold: 21000
12 changes: 12 additions & 0 deletions tests/llmcompressor/transformers/compression/recipes/new_quant_fp8.yaml
@@ -0,0 +1,12 @@
quant_stage:
  quant_modifiers:
    QuantizationModifier:
      ignore: ["lm_head"]
      config_groups:
        group_0:
          weights:
            num_bits: 8
            type: "float"
            symmetric: true
            strategy: channel
          targets: ["Linear"]
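
Outside the test harness, a recipe file like the one above can also be passed to oneshot by path rather than built in Python. A minimal sketch, assuming the tiny Xenova/llama2.c-stories15M stub from the commit-cadence config; the dataset name, output directory, and sample count are illustrative:

from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot

model = SparseAutoModelForCausalLM.from_pretrained("Xenova/llama2.c-stories15M")
oneshot(
    model=model,
    dataset="open_platypus",  # illustrative calibration dataset
    recipe="tests/llmcompressor/transformers/compression/recipes/new_quant_fp8.yaml",
    output_dir="./stories15M-fp8-compressed",  # illustrative
    num_calibration_samples=64,  # illustrative
    save_compressed=True,
)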
@@ -108,7 +108,7 @@ def test_quantization_reload(self):
            n_scale, n_zp, n_weight = reloaded_weights[name]
            assert o_scale.dtype == n_scale.dtype == self.weight_dtype
            assert torch.equal(o_scale, n_scale)
-            assert o_zp.dtype == n_zp.dtype == torch.int8
+            assert o_zp.dtype == n_zp.dtype
            assert torch.equal(o_zp, n_zp)

            # we don't expect an exact match here because o_weight still has the
@@ -119,7 +119,7 @@ def test_quantization_reload(self):
            n_scale, n_zp = reloaded_inputs[name]
            assert o_scale.dtype == n_scale.dtype == self.weight_dtype
            assert torch.equal(o_scale, n_scale)
-            assert o_zp.dtype == n_zp.dtype == torch.int8
+            assert o_zp.dtype == n_zp.dtype
            assert torch.equal(o_zp, n_zp)

    def _get_dataloader(self, data_args, tokenizer):