TP/quantization/weight loading refactor part 2 - Refactor quantized linear logic and extend quantization support to all models #1622

Merged
53 commits merged into main from the refactor-quantization branch on Nov 16, 2023
Changes from 46 commits
Commits (53)
6541618
Create linear method
zhuohan123 Nov 3, 2023
a97ede8
Support llama with the new quantization scheme
zhuohan123 Nov 3, 2023
4671286
make awq work
zhuohan123 Nov 3, 2023
4579d67
Fix squeezellm
zhuohan123 Nov 3, 2023
4406447
Remove unused codes
zhuohan123 Nov 3, 2023
5a535e3
Fix mistral
zhuohan123 Nov 3, 2023
14e66f8
Fix format
zhuohan123 Nov 3, 2023
f464375
New weight loading method, working for llama
zhuohan123 Nov 8, 2023
a5852ef
Fix awq loading
zhuohan123 Nov 8, 2023
7bf933f
Fix squeeze llm
zhuohan123 Nov 8, 2023
8af8b60
fix quantization
zhuohan123 Nov 8, 2023
686dafb
new weight loader
zhuohan123 Nov 9, 2023
e474020
Fix vocab loading
zhuohan123 Nov 9, 2023
d107613
clean up llama loader
zhuohan123 Nov 9, 2023
d4aa8c9
fix awq
zhuohan123 Nov 9, 2023
f48381b
wip fix squeezellm
zhuohan123 Nov 9, 2023
c5a9f9c
fix squeeze llm
zhuohan123 Nov 9, 2023
92155da
fix weight loader for embedding
zhuohan123 Nov 9, 2023
e528dbc
fix
zhuohan123 Nov 9, 2023
772ab72
support mistral
zhuohan123 Nov 9, 2023
0a08e66
fix
zhuohan123 Nov 9, 2023
7d7aa4b
Fix aqulia
zhuohan123 Nov 9, 2023
1df5d6b
fix vocab loader
zhuohan123 Nov 9, 2023
93685f4
fix baichuan
zhuohan123 Nov 9, 2023
5f5ea90
fix bloom
zhuohan123 Nov 9, 2023
31af3ea
fix qwen
zhuohan123 Nov 10, 2023
68f5a3f
fix qwen
zhuohan123 Nov 10, 2023
4f68d07
fix opt
zhuohan123 Nov 10, 2023
23099e2
fix mpt
zhuohan123 Nov 10, 2023
d7d108d
fix internlm
zhuohan123 Nov 10, 2023
ed44156
fix gpt2
zhuohan123 Nov 10, 2023
a75dea1
fix gpt neox
zhuohan123 Nov 10, 2023
1f6ca33
fix gptj
zhuohan123 Nov 10, 2023
b118a2f
fix falcon
zhuohan123 Nov 10, 2023
fb595c7
clean up
zhuohan123 Nov 10, 2023
d5ffe88
Fix GPT Bigcode
zhuohan123 Nov 11, 2023
036bee8
Merge branch 'main' into refactor-quantization
zhuohan123 Nov 11, 2023
7acf443
Fix chatglm and yi models
zhuohan123 Nov 11, 2023
c33e0f0
format
zhuohan123 Nov 11, 2023
63af93c
Simplify code logic
zhuohan123 Nov 11, 2023
82c76b1
Simplify code
zhuohan123 Nov 11, 2023
f53469b
fix
zhuohan123 Nov 11, 2023
d4c0798
Add comment for linear.py
zhuohan123 Nov 11, 2023
f0e7f44
Add comments
zhuohan123 Nov 11, 2023
247252c
code cleanup
zhuohan123 Nov 11, 2023
dfb4a81
Add comment
zhuohan123 Nov 11, 2023
79a6a9a
Merge branch 'main' into refactor-quantization
zhuohan123 Nov 15, 2023
f750166
Fix review comments
zhuohan123 Nov 16, 2023
fd4f4d5
fix naming
zhuohan123 Nov 16, 2023
a7dd7f4
fix comment
zhuohan123 Nov 16, 2023
18898f7
rename
zhuohan123 Nov 16, 2023
2d01ce0
Fix issues in PR #1640
zhuohan123 Nov 16, 2023
241bfa8
Fix config
zhuohan123 Nov 16, 2023
541 changes: 541 additions & 0 deletions vllm/model_executor/layers/linear.py

Large diffs are not rendered by default.
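
Note: the new vllm/model_executor/layers/linear.py (541 added lines) is collapsed above. As orientation only, here is a minimal sketch, inferred from how awq.py and base_config.py below use it, of the two pieces the quantization code depends on, LinearMethodBase and set_weight_attrs; the exact signatures, docstrings, and the unquantized default implementation in the real file may differ.

# Sketch only: inferred from call sites in this PR, not the rendered
# contents of linear.py.
from typing import Any, Dict, Optional

import torch


def set_weight_attrs(weight: torch.Tensor,
                     weight_attrs: Optional[Dict[str, Any]]) -> None:
    # Attach metadata such as input_dim/output_dim/packed_dim/pack_factor to a
    # parameter so the weight loader knows how to shard and unpack it.
    if weight_attrs is None:
        return
    for key, value in weight_attrs.items():
        setattr(weight, key, value)


class LinearMethodBase:
    # A "linear method" owns the (possibly quantized) weights of a linear
    # layer: it creates them and applies them to an input.

    def create_weights(self, input_size: int, output_size: int,
                       params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
        raise NotImplementedError

    def apply_weights(self,
                      weights: Dict[str, torch.Tensor],
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        raise NotImplementedError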

51 changes: 16 additions & 35 deletions vllm/model_executor/layers/quantized_linear/__init__.py
@@ -1,41 +1,22 @@
-from vllm.model_executor.layers.quantized_linear.awq import (
-    AWQColumnParallelLinear, AWQRowParallelLinear)
-from vllm.model_executor.layers.quantized_linear.squeezellm import (
-    SqueezeLLMColumnParallelLinear, SqueezeLLMRowParallelLinear)
-from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
-                                                       RowParallelLinear)
-
-_QUANTIZED_LINEAR_REGISTRY = {
-    "awq": (AWQColumnParallelLinear, AWQRowParallelLinear),
-    "squeezellm":
-    (SqueezeLLMColumnParallelLinear, SqueezeLLMRowParallelLinear),
-}
-
-
-class ParallelLinear:
-
-    @classmethod
-    def column(cls, *args, **kwargs) -> ColumnParallelLinear:
-        quant_config = kwargs.get("quant_config", None)
-        if quant_config is None:
-            return ColumnParallelLinear(*args, **kwargs)
-
-        name = quant_config.get_name()
-        if name not in _QUANTIZED_LINEAR_REGISTRY:
-            raise ValueError(f"No quantized linear is found for {name}")
-
-        quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][0]
-        return quant_linear_cls(*args, **kwargs)
-
-    @classmethod
-    def row(cls, *args, **kwargs) -> RowParallelLinear:
-        quant_config = kwargs.get("quant_config", None)
-        if quant_config is None:
-            return RowParallelLinear(*args, **kwargs)
-
-        name = quant_config.get_name()
-        if name not in _QUANTIZED_LINEAR_REGISTRY:
-            raise ValueError(f"No quantized linear is found for {name}")
-
-        quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][1]
-        return quant_linear_cls(*args, **kwargs)
+from typing import Type
+
+from vllm.model_executor.layers.quantized_linear.awq import AWQConfig
+from vllm.model_executor.layers.quantized_linear.squeezellm import SqueezeLLMConfig
+from vllm.model_executor.layers.quantized_linear.base_config import QuantizationConfig
+
+_QUANTIZATION_CONFIG_REGISTRY = {
+    "awq": AWQConfig,
+    "squeezellm": SqueezeLLMConfig,
+}
+
+
+def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
+    if quantization not in _QUANTIZATION_CONFIG_REGISTRY:
+        raise ValueError(f"Invalid quantization method: {quantization}")
+    return _QUANTIZATION_CONFIG_REGISTRY[quantization]
+
+
+__all__ = [
+    "QuantizationConfig",
+    "get_quantization_config",
+]
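
As a quick illustration (not code from this PR), resolving a method through the new registry looks roughly like this; the hf_quant_config dict below is a made-up example of what an AWQ checkpoint's quant_config.json might contain.

# Illustrative usage of the new entry point; only get_quantization_config,
# from_config, and get_linear_method come from this PR.
from vllm.model_executor.layers.quantized_linear import get_quantization_config

hf_quant_config = {"w_bit": 4, "q_group_size": 128, "zero_point": True}

quant_config_cls = get_quantization_config("awq")    # -> AWQConfig
quant_config = quant_config_cls.from_config(hf_quant_config)
linear_method = quant_config.get_linear_method()     # -> AWQLinearMethod
print(quant_config)
# AWQConfig(weight_bits=4, group_size=128, zero_point=True)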
194 changes: 123 additions & 71 deletions vllm/model_executor/layers/quantized_linear/awq.py
@@ -1,106 +1,158 @@
-from typing import Optional
+from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from vllm import quantization_ops
-from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
-                                                       RowParallelLinear)
+from vllm.model_executor.layers.linear import (LinearMethodBase,
+                                               set_weight_attrs)
+from vllm.model_executor.layers.quantized_linear.base_config import QuantizationConfig


-class AWQColumnParallelLinear(ColumnParallelLinear):
-
-    def create_weights(self, dtype: torch.dtype) -> None:
-        assert self.input_size % self.quant_config.group_size == 0
-        if self.output_size_per_partition % self.quant_config.pack_factor != 0:
-            raise ValueError(
-                "The tensor parallel size is not aligned with the quantized "
-                "weight shape. Please use a different tensor parallel size.")
-        self.qweight = Parameter(
-            torch.empty(
-                self.input_size,
-                self.output_size_per_partition //
-                self.quant_config.pack_factor,
-                device="cuda",
-                dtype=torch.int32,
-            ),
-            requires_grad=False,
-        )
-        self.qzeros = Parameter(
-            torch.empty(
-                self.input_size // self.quant_config.group_size,
-                self.output_size_per_partition //
-                self.quant_config.pack_factor,
-                device="cuda",
-                dtype=torch.int32,
-            ),
-            requires_grad=False,
-        )
-        self.scales = Parameter(
-            torch.empty(
-                self.input_size // self.quant_config.group_size,
-                self.output_size_per_partition,
-                device="cuda",
-                dtype=dtype,
-            ),
-            requires_grad=False,
-        )
-
-    def apply_weights(
-        self,
-        x: torch.Tensor,
-        bias: Optional[torch.Tensor],
-    ) -> torch.Tensor:
-        pack_factor = self.quant_config.pack_factor
-        out_shape = (x.shape[:-1] + (self.qweight.shape[-1] * pack_factor, ))
-        reshaped_x = x.reshape(-1, x.shape[-1])
-        out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
-                                        self.qzeros, pack_factor)
-        if bias is not None:
-            out = out + bias
-        return out.reshape(out_shape)
-
-
-class AWQRowParallelLinear(RowParallelLinear):
-
-    def create_weights(self, dtype: torch.dtype) -> None:
-        assert self.output_size % self.quant_config.pack_factor == 0
-        if self.input_size_per_partition % self.quant_config.group_size != 0:
-            raise ValueError(
-                "The tensor parallel size is not aligned with the quantized "
-                "weight shape. Please use a different tensor parallel size.")
-        self.qweight = Parameter(
-            torch.empty(
-                self.input_size_per_partition,
-                self.output_size // self.quant_config.pack_factor,
-                device="cuda",
-                dtype=torch.int32,
-            ),
-            requires_grad=False,
-        )
-        self.qzeros = Parameter(
-            torch.empty(
-                self.input_size_per_partition // self.quant_config.group_size,
-                self.output_size // self.quant_config.pack_factor,
-                device="cuda",
-                dtype=torch.int32,
-            ),
-            requires_grad=False,
-        )
-        self.scales = Parameter(
-            torch.empty(
-                self.input_size_per_partition // self.quant_config.group_size,
-                self.output_size,
-                device="cuda",
-                dtype=dtype,
-            ),
-            requires_grad=False,
-        )
-
-    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
-        pack_factor = self.quant_config.pack_factor
-        out_shape = (x.shape[:-1] + (self.qweight.shape[-1] * pack_factor, ))
-        reshaped_x = x.reshape(-1, x.shape[-1])
-        out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
-                                        self.qzeros, pack_factor)
-        return out.reshape(out_shape)
+class AWQConfig(QuantizationConfig):
+    """Config class for AWQ.
+
+    Reference: https://arxiv.org/abs/2306.00978
+    """
+
+    def __init__(
+        self,
+        weight_bits: int,
+        group_size: int,
+        zero_point: bool,
+    ) -> None:
+        self.weight_bits = weight_bits
+        self.group_size = group_size
+        self.zero_point = zero_point
+
+        if self.weight_bits != 4:
+            raise ValueError(
+                "Currently, only 4-bit weight quantization is supported for "
+                f"AWQ, but got {self.weight_bits} bits.")
+        self.pack_factor = 32 // self.weight_bits
+
+    def __repr__(self) -> str:
+        return (f"AWQConfig(weight_bits={self.weight_bits}, "
+                f"group_size={self.group_size}, "
+                f"zero_point={self.zero_point})")
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "awq"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.half]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        # The AWQ kernel only supports Turing or newer GPUs.
+        return 75
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return [
+            "quant_config.json",  # E.g., casperhansen/vicuna-7b-v1.5-awq
+            "quantize_config.json",  # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq  # pylint: disable=line-too-long
+        ]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
+        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
+        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
+        zero_point = cls.get_from_keys(config, ["zero_point"])
+        return cls(weight_bits, group_size, zero_point)
+
+    def get_linear_method(self) -> "AWQLinearMethod":
+        return AWQLinearMethod(self)
+
+
+class AWQLinearMethod(LinearMethodBase):
+    """Linear method for AWQ.
+
+    Args:
+        quant_config: The AWQ quantization config.
+    """
+
+    def __init__(self, quant_config: AWQConfig):
+        self.quant_config = quant_config
+
+    def create_weights(self, input_size: int, output_size: int,
+                       params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
+        if input_size % self.quant_config.group_size != 0:
+            raise ValueError(
+                "The input size is not aligned with the quantized "
+                "weight shape. This can be caused by too large "
+                "tensor parallel size.")
+        if output_size % self.quant_config.pack_factor != 0:
+            raise ValueError(
+                "The output size is not aligned with the quantized "
+                "weight shape. This can be caused by too large "
+                "tensor parallel size.")
+
+        qweight = Parameter(
+            torch.empty(
+                input_size,
+                output_size // self.quant_config.pack_factor,
+                device="cuda",
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        set_weight_attrs(
+            qweight, {
+                "input_dim": 0,
+                "output_dim": 1,
+                "packed_dim": 1,
+                "pack_factor": self.quant_config.pack_factor,
+            })
+        qzeros = Parameter(
+            torch.empty(
+                input_size // self.quant_config.group_size,
+                output_size // self.quant_config.pack_factor,
+                device="cuda",
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        set_weight_attrs(
+            qzeros, {
+                "input_dim": 0,
+                "output_dim": 1,
+                "packed_dim": 1,
+                "pack_factor": self.quant_config.pack_factor,
+            })
+        scales = Parameter(
+            torch.empty(
+                input_size // self.quant_config.group_size,
+                output_size,
+                device="cuda",
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        set_weight_attrs(scales, {
+            "input_dim": 0,
+            "output_dim": 1,
+        })
+        return {
+            "qweight": qweight,
+            "qzeros": qzeros,
+            "scales": scales,
+        }
+
+    def apply_weights(self,
+                      weights: Dict[str, torch.Tensor],
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        qweight = weights["qweight"]
+        qzeros = weights["qzeros"]
+        scales = weights["scales"]
+        pack_factor = self.quant_config.pack_factor
+        out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
+        reshaped_x = x.reshape(-1, x.shape[-1])
+        out = quantization_ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
+                                        pack_factor)
+        if bias is not None:
+            out = out + bias
+        return out.reshape(out_shape)
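
To make the packing arithmetic concrete: with 4-bit weights the pack factor is 32 // 4 = 8, so eight quantized values share one int32 element and qweight/qzeros shrink by 8x along the packed dimension. A small hedged check, assuming a CUDA device is available (create_weights allocates on "cuda") and using illustrative sizes that are not taken from this PR:

# Example shape check for the AWQ weight layout above; sizes are illustrative.
import torch

from vllm.model_executor.layers.quantized_linear.awq import AWQConfig

config = AWQConfig(weight_bits=4, group_size=128, zero_point=True)
method = config.get_linear_method()

weights = method.create_weights(input_size=4096, output_size=11008,
                                params_dtype=torch.half)
# pack_factor = 32 // 4 = 8: eight 4-bit weights per int32.
assert weights["qweight"].shape == (4096, 11008 // 8)
assert weights["qzeros"].shape == (4096 // 128, 11008 // 8)
assert weights["scales"].shape == (4096 // 128, 11008)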
52 changes: 52 additions & 0 deletions vllm/model_executor/layers/quantized_linear/base_config.py
@@ -0,0 +1,52 @@
from typing import Any, Dict, List

import torch

from vllm.model_executor.layers.linear import LinearMethodBase


class QuantizationConfig:
    """Base class for quantization configs."""

    @classmethod
    def get_name(cls) -> str:
        """Name of the quantization method."""
        raise NotImplementedError

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """List of supported activation dtypes."""
        raise NotImplementedError

    @classmethod
    def get_min_capability(cls) -> int:
        """Minimum GPU capability to support the quantization method.

        E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
        This requirement is due to the custom CUDA kernels used by the
        quantization method.
        """
        raise NotImplementedError

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """List of filenames to search for in the model directory."""
        raise NotImplementedError

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
        """Create a config class from the model's quantization config."""
        raise NotImplementedError

    @staticmethod
    def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
        """Get a value from the model's quantization config."""
        for key in keys:
            if key in config:
                return config[key]
        raise ValueError(f"Cannot find any of {keys} in the model's "
                         "quantization config.")

    def get_linear_method(self) -> LinearMethodBase:
        """Get the linear method to use for the quantized linear layer."""
        raise NotImplementedError
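
As a hypothetical illustration of how a new method would plug into this base class (the name MyInt8Config and its fields are invented, not part of this PR), a subclass only needs to fill in these hooks and return its own LinearMethodBase from get_linear_method():

# Hypothetical example, not code from this PR.
from typing import Any, Dict, List

import torch

from vllm.model_executor.layers.quantized_linear.base_config import QuantizationConfig


class MyInt8Config(QuantizationConfig):
    # Invented 8-bit method, for illustration only.

    def __init__(self, group_size: int) -> None:
        self.group_size = group_size

    @classmethod
    def get_name(cls) -> str:
        return "myint8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 70  # e.g. Volta or newer

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "MyInt8Config":
        return cls(group_size=cls.get_from_keys(config, ["group_size"]))

    # get_linear_method() would return a LinearMethodBase subclass analogous
    # to AWQLinearMethod above.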