TP/quantization/weight loading refactor part 2 - Refactor quantized linear logic and extend quantization support to all models #1622
Merged
Commits (53, all by zhuohan123)
6541618  Create linear method
a97ede8  Support llama with the new quantization scheme
4671286  make awq work
4579d67  Fix squeezellm
4406447  Remove unused codes
5a535e3  Fix mistral
14e66f8  Fix format
f464375  New weight loading method, working for llama
a5852ef  Fix awq loading
7bf933f  Fix squeeze llm
8af8b60  fix quantization
686dafb  new weight loader
e474020  Fix vocab loading
d107613  clean up llama loader
d4aa8c9  fix awq
f48381b  wip fix squeezellm
c5a9f9c  fix squeeze llm
92155da  fix weight loader for embedding
e528dbc  fix
772ab72  support mistral
0a08e66  fix
7d7aa4b  Fix aqulia
1df5d6b  fix vocab loader
93685f4  fix baichuan
5f5ea90  fix bloom
31af3ea  fix qwen
68f5a3f  fix qwen
4f68d07  fix opt
23099e2  fix mpt
d7d108d  fix internlm
ed44156  fix gpt2
a75dea1  fix gpt neox
1f6ca33  fix gptj
b118a2f  fix falcon
fb595c7  clean up
d5ffe88  Fix GPT Bigcode
036bee8  Merge branch 'main' into refactor-quantization
7acf443  Fix chatglm and yi models
c33e0f0  format
63af93c  Simplify code logic
82c76b1  Simplify code
f53469b  fix
d4c0798  Add comment for linear.py
f0e7f44  Add comments
247252c  code cleanup
dfb4a81  Add comment
79a6a9a  Merge branch 'main' into refactor-quantization
f750166  Fix review comments
fd4f4d5  fix naming
a7dd7f4  fix comment
18898f7  rename
2d01ce0  Fix issues in PR #1640
241bfa8  Fix config
The first diff removes the old ParallelLinear factory class and its _QUANTIZED_LINEAR_REGISTRY, which mapped each quantization method to a pair of (column, row) parallel-linear subclasses (AWQColumnParallelLinear/AWQRowParallelLinear and SqueezeLLMColumnParallelLinear/SqueezeLLMRowParallelLinear). The module now only registers the per-method config classes and exposes a lookup helper (41 lines before, 22 after):

from typing import Type

from vllm.model_executor.layers.quantized_linear.awq import AWQConfig
from vllm.model_executor.layers.quantized_linear.squeezellm import SqueezeLLMConfig
from vllm.model_executor.layers.quantized_linear.base_config import QuantizationConfig

_QUANTIZATION_CONFIG_REGISTRY = {
    "awq": AWQConfig,
    "squeezellm": SqueezeLLMConfig,
}


def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
    if quantization not in _QUANTIZATION_CONFIG_REGISTRY:
        raise ValueError(f"Invalid quantization method: {quantization}")
    return _QUANTIZATION_CONFIG_REGISTRY[quantization]


__all__ = [
    "QuantizationConfig",
    "get_quantization_config",
]
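A minimal usage sketch (not part of the diff) based only on the interfaces shown in this PR; the config values are illustrative, and actually running it requires a CUDA build of vLLM, since AWQLinearMethod allocates its tensors on the GPU and calls the awq_gemm kernel:

import torch

from vllm.model_executor.layers.quantized_linear.awq import AWQConfig

# Illustrative checkpoint quantization config; from_config() accepts either
# the "w_bit"/"q_group_size" or the "bits"/"group_size" key spellings.
quant_config = AWQConfig.from_config({
    "w_bit": 4,
    "q_group_size": 128,
    "zero_point": True,
})
print(quant_config)  # AWQConfig(weight_bits=4, group_size=128, zero_point=True)

# Layers now ask the config for a LinearMethod instead of subclassing
# Column/RowParallelLinear per quantization scheme.
linear_method = quant_config.get_linear_method()                # AWQLinearMethod
weights = linear_method.create_weights(4096, 4096, torch.half)  # qweight/qzeros/scales
x = torch.zeros(1, 4096, dtype=torch.half, device="cuda")
y = linear_method.apply_weights(weights, x)                     # shape (1, 4096)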
The AWQ module is rewritten from the AWQColumnParallelLinear and AWQRowParallelLinear subclasses, each of which allocated qweight/qzeros/scales on self and implemented its own forward pass, into a plain AWQConfig plus an AWQLinearMethod that owns weight creation and the quantized matmul (106 lines before, 158 after):

from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from vllm import quantization_ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               set_weight_attrs)
from vllm.model_executor.layers.quantized_linear.base_config import QuantizationConfig


class AWQConfig(QuantizationConfig):
    """Config class for AWQ.

    Reference: https://arxiv.org/abs/2306.00978
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.zero_point = zero_point

        if self.weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"AWQ, but got {self.weight_bits} bits.")
        self.pack_factor = 32 // self.weight_bits

    def __repr__(self) -> str:
        return (f"AWQConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"zero_point={self.zero_point})")

    @classmethod
    def get_name(cls) -> str:
        return "awq"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        # The AWQ kernel only supports Turing or newer GPUs.
        return 75

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return [
            "quant_config.json",  # E.g., casperhansen/vicuna-7b-v1.5-awq
            "quantize_config.json",  # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq  # pylint: disable=line-too-long
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
        zero_point = cls.get_from_keys(config, ["zero_point"])
        return cls(weight_bits, group_size, zero_point)

    def get_linear_method(self) -> "AWQLinearMethod":
        return AWQLinearMethod(self)


class AWQLinearMethod(LinearMethodBase):
    """Linear method for AWQ.

    Args:
        quant_config: The AWQ quantization config.
    """

    def __init__(self, quant_config: AWQConfig):
        self.quant_config = quant_config

    def create_weights(self, input_size: int, output_size: int,
                       params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
        if input_size % self.quant_config.group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")
        if output_size % self.quant_config.pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        qweight = Parameter(
            torch.empty(
                input_size,
                output_size // self.quant_config.pack_factor,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight, {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            })
        qzeros = Parameter(
            torch.empty(
                input_size // self.quant_config.group_size,
                output_size // self.quant_config.pack_factor,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qzeros, {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            })
        scales = Parameter(
            torch.empty(
                input_size // self.quant_config.group_size,
                output_size,
                device="cuda",
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(scales, {
            "input_dim": 0,
            "output_dim": 1,
        })
        return {
            "qweight": qweight,
            "qzeros": qzeros,
            "scales": scales,
        }

    def apply_weights(self,
                      weights: Dict[str, torch.Tensor],
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        qweight = weights["qweight"]
        qzeros = weights["qzeros"]
        scales = weights["scales"]
        pack_factor = self.quant_config.pack_factor
        out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
        reshaped_x = x.reshape(-1, x.shape[-1])
        out = quantization_ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
                                        pack_factor)
        if bias is not None:
            out = out + bias
        return out.reshape(out_shape)
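To make the packing arithmetic concrete: with 4-bit weights, pack_factor is 32 // 4 = 8, so eight quantized weights share one int32. A small self-contained sketch (plain Python, no vLLM or GPU needed) of the shapes create_weights allocates; the layer dimensions are illustrative:

# 4-bit AWQ packing: eight 4-bit values per int32 element.
weight_bits = 4
pack_factor = 32 // weight_bits                 # 8
group_size = 128                                # illustrative group size
input_size, output_size = 4096, 11008           # illustrative layer shape

qweight_shape = (input_size, output_size // pack_factor)               # (4096, 1376) int32
qzeros_shape = (input_size // group_size, output_size // pack_factor)  # (32, 1376) int32
scales_shape = (input_size // group_size, output_size)                 # (32, 11008) params_dtype

print(qweight_shape, qzeros_shape, scales_shape)

The input_dim/output_dim/packed_dim/pack_factor attributes attached via set_weight_attrs record which dimension of each tensor is sharded and which one is compressed, presumably so that the refactored weight-loading path can slice checkpoint tensors correctly under tensor parallelism.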
vllm/model_executor/layers/quantized_linear/base_config.py (new file, 52 additions):

from typing import Any, Dict, List

import torch

from vllm.model_executor.layers.linear import LinearMethodBase


class QuantizationConfig:
    """Base class for quantization configs."""

    @classmethod
    def get_name(cls) -> str:
        """Name of the quantization method."""
        raise NotImplementedError

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """List of supported activation dtypes."""
        raise NotImplementedError

    @classmethod
    def get_min_capability(cls) -> int:
        """Minimum GPU capability to support the quantization method.

        E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
        This requirement is due to the custom CUDA kernels used by the
        quantization method.
        """
        raise NotImplementedError

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """List of filenames to search for in the model directory."""
        raise NotImplementedError

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
        """Create a config class from the model's quantization config."""
        raise NotImplementedError

    @staticmethod
    def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
        """Get a value from the model's quantization config."""
        for key in keys:
            if key in config:
                return config[key]
        raise ValueError(f"Cannot find any of {keys} in the model's "
                         "quantization config.")

    def get_linear_method(self) -> LinearMethodBase:
        """Get the linear method to use for the quantized linear layer."""
        raise NotImplementedError
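To illustrate the contract, a hypothetical subclass might look like the sketch below. FooConfig and foo_quant_config.json are made-up names, not part of this PR or of vLLM:

from typing import Any, Dict, List

import torch

from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.quantized_linear.base_config import QuantizationConfig


class FooConfig(QuantizationConfig):
    """Hypothetical example config, only to show the required overrides."""

    def __init__(self, weight_bits: int) -> None:
        self.weight_bits = weight_bits

    @classmethod
    def get_name(cls) -> str:
        return "foo"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 70  # made-up requirement: Volta or newer

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["foo_quant_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "FooConfig":
        return cls(weight_bits=cls.get_from_keys(config, ["bits", "w_bit"]))

    def get_linear_method(self) -> LinearMethodBase:
        # A real method would return its LinearMethodBase implementation here,
        # as AWQConfig.get_linear_method() returns AWQLinearMethod.
        raise NotImplementedError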