
Fp8 Quantization Support #62


Merged · 26 commits · Jun 20, 2024
src/compressed_tensors/compressors/__init__.py (5 additions, 1 deletion)

@@ -17,8 +17,12 @@
from .base import Compressor
from .dense import DenseCompressor
from .helpers import load_compressed, save_compressed, save_compressed_model
from .int_quantized import IntQuantizationCompressor
from .marlin_24 import Marlin24Compressor
from .model_compressor import ModelCompressor, map_modules_to_quant_args
from .naive_quantized import (
FloatQuantizationCompressor,
IntQuantizationCompressor,
QuantizationCompressor,
)
from .pack_quantized import PackedQuantizationCompressor
from .sparse_bitmask import BitmaskCompressor, BitmaskTensor
src/compressed_tensors/compressors/model_compressor.py (20 additions, 0 deletions)

@@ -16,9 +16,12 @@
import logging
import operator
import os
import re
from copy import deepcopy
from typing import Any, Dict, Optional, Union

import torch
import transformers
from compressed_tensors.base import (
COMPRESSION_CONFIG_NAME,
QUANTIZATION_CONFIG_NAME,
@@ -236,6 +239,11 @@ def compress(
compressed_state_dict
)

# HACK: Override the dtype_byte_size function in transformers to
# support float8 types. Fix is posted upstream
# https://github.com/huggingface/transformers/pull/30488
transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size

return compressed_state_dict

def decompress(self, model_path: str, model: Module):
@@ -313,3 +321,15 @@ def map_modules_to_quant_args(model: Module) -> Dict:
quantized_modules_to_args[name] = submodule.quantization_scheme.weights

return quantized_modules_to_args


# HACK: Override the dtype_byte_size function in transformers to support float8 types
# Fix is posted upstream https://github.com/huggingface/transformers/pull/30488
def new_dtype_byte_size(dtype):
if dtype == torch.bool:
return 1 / 8
bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
if bit_search is None:
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
bit_size = int(bit_search.groups()[0])
return bit_size // 8
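
Note: the override above only needs the bit width to appear somewhere in the dtype name, so it also handles float8 variants such as torch.float8_e4m3fn, whose bit width does not sit at the end of the name (the case the linked transformers PR fixes upstream). A quick sanity check, assuming a PyTorch build with float8 dtypes:

```python
import re

import torch


def new_dtype_byte_size(dtype):
    # Mirrors the override above: grab the first run of digits in the dtype
    # name instead of requiring the digits to sit at the end of the string.
    if dtype == torch.bool:
        return 1 / 8
    bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
    if bit_search is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
    return int(bit_search.groups()[0]) // 8


print(new_dtype_byte_size(torch.float16))        # 2
print(new_dtype_byte_size(torch.float8_e4m3fn))  # 1 (the "8" sits mid-name)
print(new_dtype_byte_size(torch.bool))           # 0.125
```
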
src/compressed_tensors/compressors/naive_quantized.py (renamed from int_quantized.py)

@@ -27,17 +27,21 @@
from tqdm import tqdm


__all__ = ["IntQuantizationCompressor"]
__all__ = [
"QuantizationCompressor",
"IntQuantizationCompressor",
"FloatQuantizationCompressor",
]

_LOGGER: logging.Logger = logging.getLogger(__name__)


@Compressor.register(name=CompressionFormat.int_quantized.value)
class IntQuantizationCompressor(Compressor):
@Compressor.register(name=CompressionFormat.naive_quantized.value)
class QuantizationCompressor(Compressor):
"""
Integer compression for quantized models. Weight of each quantized layer is
converted from its original float type to the format specified by the layer's
quantization scheme.
    Implements naive compression for quantized models. The weight of each
    quantized layer is converted from its original float type to the closest
    available PyTorch dtype for the type specified by the layer's QuantizationArgs.
"""

COMPRESSION_PARAM_NAMES = ["weight", "weight_scale", "weight_zero_point"]
@@ -77,7 +81,7 @@ def compress(
scale=scale,
zero_point=zp,
args=quant_args,
dtype=torch.int8,
dtype=quant_args.pytorch_dtype(),
)
elif name.endswith("zero_point"):
if torch.all(value == 0):
@@ -114,13 +118,27 @@ def decompress(
if "weight_scale" in weight_data:
zero_point = weight_data.get("weight_zero_point", None)
scale = weight_data["weight_scale"]
if zero_point is None:
# zero_point assumed to be 0 if not included in state_dict
zero_point = torch.zeros_like(scale)

decompressed = dequantize(
x_q=weight_data["weight"],
scale=scale,
zero_point=zero_point,
)
yield merge_names(weight_name, "weight"), decompressed


@Compressor.register(name=CompressionFormat.int_quantized.value)
class IntQuantizationCompressor(QuantizationCompressor):
"""
Alias for integer quantized models
"""

pass


@Compressor.register(name=CompressionFormat.float_quantized.value)
class FloatQuantizationCompressor(QuantizationCompressor):
"""
Alias for fp quantized models
"""

pass
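
For orientation, here is a minimal sketch of the round trip these compressors perform, written against the quantize/dequantize helpers from this PR's forward.py. The QuantizationArgs field values and the assumption that pytorch_dtype() resolves to torch.float8_e4m3fn for this scheme are inferred from the rest of the diff, not copied from the compressor internals:

```python
import torch

from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
from compressed_tensors.quantization.quant_args import QuantizationArgs

# Hypothetical per-tensor fp8 scheme; field names assumed from usage in this diff.
quant_args = QuantizationArgs(num_bits=8, type="float", symmetric=True)

weight = torch.randn(64, 64)
scale = weight.abs().max() / 448.0    # 448 = float8_e4m3fn max, illustration only
zero_point = torch.zeros_like(scale)  # symmetric, so compress() can drop it entirely

# compress(): store the weight in the closest PyTorch dtype for the scheme
w_q = quantize(
    weight,
    scale=scale,
    zero_point=zero_point,
    args=quant_args,
    dtype=quant_args.pytorch_dtype(),  # assumed to be torch.float8_e4m3fn here
)

# decompress(): the zero point may be absent from the state dict, and
# dequantize now defaults a missing zero point to zero
w_hat = dequantize(w_q, scale=scale)
print((weight - w_hat).abs().max())
```
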
src/compressed_tensors/compressors/pack_quantized.py (0 additions, 4 deletions)

@@ -126,10 +126,6 @@ def decompress(
if "weight_scale" in weight_data:
zero_point = weight_data.get("weight_zero_point", None)
scale = weight_data["weight_scale"]
if zero_point is None:
# zero_point assumed to be 0 if not included in state_dict
zero_point = torch.zeros_like(scale)

weight = weight_data["weight_packed"]
original_shape = torch.Size(weight_data["weight_shape"])
unpacked = unpack_4bit_ints(weight, original_shape)
src/compressed_tensors/config/base.py (2 additions, 0 deletions)

@@ -26,6 +26,8 @@ class CompressionFormat(Enum):
dense = "dense"
sparse_bitmask = "sparse-bitmask"
int_quantized = "int-quantized"
float_quantized = "float-quantized"
naive_quantized = "naive-quantized"
pack_quantized = "pack-quantized"
marlin_24 = "marlin-24"
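
The two new formats give the fp8 path its own identifier, while naive-quantized names the shared base compressor. Based on the registrations earlier in this diff, the mapping is summarized below; the exact registry-lookup API is not shown here, so only the enum side is exercised:

```python
from compressed_tensors.config.base import CompressionFormat

# Registered in naive_quantized.py above:
#   naive-quantized -> QuantizationCompressor
#   int-quantized   -> IntQuantizationCompressor (alias)
#   float-quantized -> FloatQuantizationCompressor (alias)
print(CompressionFormat.naive_quantized.value)  # "naive-quantized"
print(CompressionFormat.float_quantized.value)  # "float-quantized"
```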

src/compressed_tensors/quantization/lifecycle/apply.py (4 additions, 8 deletions)

@@ -215,15 +215,11 @@ def _load_quant_args_from_state_dict(
scale = getattr(module, scale_name, None)
zp = getattr(module, zp_name, None)
if scale is not None:
state_dict_scale = state_dict.get(f"{module_name}.{scale_name}")
if state_dict_scale is not None:
scale.data = state_dict_scale.to(device).to(scale.dtype)
else:
scale.data = scale.data.to(device)

state_dict_scale = state_dict[f"{module_name}.{scale_name}"]
scale.data = state_dict_scale.to(device).to(scale.dtype)
if zp is not None:
zp_from_state = state_dict.get(f"{module_name}.{zp_name}", None)
if zp_from_state is not None: # load the non-zero zero points
zp.data = state_dict[f"{module_name}.{zp_name}"].to(device)
zp.data = zp_from_state.to(device).to(zp.dtype)
else: # fill with zeros matching scale shape
zp.data = torch.zeros_like(scale, dtype=torch.int8).to(device)
zp.data = torch.zeros_like(scale, dtype=zp.dtype).to(device)
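
In short, the scale is now required in the checkpoint, while a missing zero point is tolerated and synthesized in the parameter's own dtype rather than hard-coded int8 (which would be wrong for an fp8 scheme). A minimal sketch of that fallback, with hypothetical names and shapes:

```python
import torch

# Hypothetical per-channel scale and fp8 zero-point parameters, shaped the way
# a module initialized for a float8 scheme might hold them.
scale = torch.nn.Parameter(torch.rand(128, 1), requires_grad=False)
zp = torch.nn.Parameter(
    torch.empty(128, 1, dtype=torch.float8_e4m3fn), requires_grad=False
)

# Symmetric checkpoint: only the scale was saved.
state_dict = {"model.layer.weight_scale": scale.data * 0.5}

zp_from_state = state_dict.get("model.layer.weight_zero_point", None)
if zp_from_state is not None:  # load the non-zero zero points
    zp.data = zp_from_state.to(zp.dtype)
else:  # fill with zeros matching the scale shape and the zero point's dtype
    zp.data = torch.zeros_like(scale, dtype=zp.dtype)
```
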
src/compressed_tensors/quantization/lifecycle/forward.py (52 additions, 29 deletions)

@@ -17,9 +17,11 @@
from typing import Optional

import torch
from compressed_tensors.quantization.observers.helpers import calculate_range
from compressed_tensors.quantization.quant_args import (
QuantizationArgs,
QuantizationStrategy,
round_to_quantized_type,
)
from compressed_tensors.quantization.quant_config import QuantizationStatus
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
@@ -80,8 +82,9 @@ def quantize(
def dequantize(
x_q: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
zero_point: torch.Tensor = None,
args: QuantizationArgs = None,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
"""
Dequantize a quantized input tensor x_q based on the strategy specified in args. If
@@ -91,16 +94,9 @@ def dequantize(
:param scale: scale tensor
:param zero_point: zero point tensor
:param args: quantization args used to quantize x_q
:param dtype: optional dtype to cast the dequantized output to
:return: dequantized float tensor
"""
# ensure all tensors are on the same device
# assumes that the target device is the input
# tensor's device
if x_q.device != scale.device:
scale = scale.to(x_q.device)
if x_q.device != zero_point.device:
zero_point = zero_point.to(x_q.device)

if args is None:
if scale.ndim == 0 or scale.ndim == 1:
args = QuantizationArgs(strategy=QuantizationStrategy.TENSOR)
@@ -115,15 +111,20 @@
else:
raise ValueError(
f"Could not infer a quantization strategy from scale with {scale.ndim} "
"dimmensions. Expected 0-2 dimmensions."
"dimmensions. Expected 0 or 2 dimmensions."
)

if dtype is None:
dtype = scale.dtype

return _process_quantization(
x=x_q,
scale=scale,
zero_point=zero_point,
args=args,
do_quantize=False,
do_dequantize=True,
dtype=dtype,
)


@@ -167,19 +168,13 @@ def _process_quantization(
do_quantize: bool = True,
do_dequantize: bool = True,
) -> torch.Tensor:
bit_range = 2**args.num_bits
q_max = torch.tensor(bit_range / 2 - 1, device=x.device)
q_min = torch.tensor(-bit_range / 2, device=x.device)

q_min, q_max = calculate_range(args, x.device)
group_size = args.group_size

if args.strategy == QuantizationStrategy.GROUP:

if do_dequantize and not do_quantize:
# if dequantizing a quantized type infer the output type from the scale
output = torch.zeros_like(x, dtype=scale.dtype)
else:
output_dtype = dtype if dtype is not None else x.dtype
output = torch.zeros_like(x, dtype=output_dtype)
output_dtype = dtype if dtype is not None else x.dtype
output = torch.zeros_like(x).to(output_dtype)

# TODO: vectorize the for loop
        # TODO: fix generic assumption about the tensor size for computing group
@@ -189,7 +184,7 @@
while scale.ndim < 2:
# pad scale and zero point dims for slicing
scale = scale.unsqueeze(1)
zero_point = zero_point.unsqueeze(1)
zero_point = zero_point.unsqueeze(1) if zero_point is not None else None

columns = x.shape[1]
if columns >= group_size:
@@ -202,12 +197,18 @@
# scale.shape should be [nchan, ndim]
# sc.shape should be [nchan, 1] after unsqueeze
sc = scale[:, i].view(-1, 1)
zp = zero_point[:, i].view(-1, 1)
zp = zero_point[:, i].view(-1, 1) if zero_point is not None else None

idx = i * group_size
if do_quantize:
output[:, idx : (idx + group_size)] = _quantize(
x[:, idx : (idx + group_size)], sc, zp, q_min, q_max, dtype=dtype
x[:, idx : (idx + group_size)],
sc,
zp,
q_min,
q_max,
args,
dtype=dtype,
)
if do_dequantize:
input = (
@@ -219,7 +220,15 @@

else: # covers channel, token and tensor strategies
if do_quantize:
output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype)
output = _quantize(
x,
scale,
zero_point,
q_min,
q_max,
args,
dtype=dtype,
)
if do_dequantize:
output = _dequantize(output if do_quantize else x, scale, zero_point)

@@ -313,14 +322,18 @@ def _quantize(
zero_point: torch.Tensor,
q_min: torch.Tensor,
q_max: torch.Tensor,
args: QuantizationArgs,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
quantized_value = torch.clamp(
torch.round(x / scale + zero_point),

scaled = x / scale + zero_point.to(x.dtype)
# clamp first because cast isn't guaranteed to be saturated (ie for fp8)
clamped_value = torch.clamp(
scaled,
q_min,
q_max,
)

quantized_value = round_to_quantized_type(clamped_value, args)
if dtype is not None:
quantized_value = quantized_value.to(dtype)

@@ -331,6 +344,16 @@
def _dequantize(
x_q: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
zero_point: torch.Tensor = None,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
return (x_q - zero_point) * scale

dequant_value = x_q
if zero_point is not None:
dequant_value = dequant_value - zero_point.to(scale.dtype)
dequant_value = dequant_value.to(scale.dtype) * scale

if dtype is not None:
dequant_value = dequant_value.to(dtype)

return dequant_value
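
To make the new flow concrete: _quantize scales, clamps into the target range, and only then rounds and casts (the comment above notes the fp8 cast is not guaranteed to saturate), while _dequantize treats the zero point as optional and computes the result in the scale's dtype. A small self-contained illustration of that pattern, assuming a PyTorch build with torch.float8_e4m3fn; torch.finfo stands in for calculate_range() and a plain cast stands in for round_to_quantized_type(), both of which live outside this diff:

```python
import torch

x = torch.randn(4, 8)
finfo = torch.finfo(torch.float8_e4m3fn)
q_min, q_max = finfo.min, finfo.max

scale = x.abs().max() / q_max          # per-tensor scale, illustration only
zero_point = torch.tensor(0.0)         # symmetric

# quantize: scale, clamp into the representable range *before* casting,
# because the cast itself may not saturate out-of-range values
scaled = x / scale + zero_point
clamped = torch.clamp(scaled, q_min, q_max)
x_q = clamped.to(torch.float8_e4m3fn)  # plain cast in place of round_to_quantized_type

# dequantize: move back to the scale's dtype, remove the zero point, rescale
x_hat = (x_q.to(scale.dtype) - zero_point) * scale
print((x - x_hat).abs().max())         # small fp8 round-trip error
```
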
src/compressed_tensors/quantization/lifecycle/initialize.py (2 additions, 1 deletion)

@@ -120,8 +120,9 @@ def _initialize_scale_zero_point_observer(
)
module.register_parameter(f"{base_name}_scale", init_scale)

zp_dtype = quantization_args.pytorch_dtype()
init_zero_point = Parameter(
torch.empty(expected_shape, device=device, dtype=int),
torch.empty(expected_shape, device=device, dtype=zp_dtype),
requires_grad=False,
)
module.register_parameter(f"{base_name}_zero_point", init_zero_point)