SqueezeLLM Support #1326
Merged
merged 19 commits on Oct 22, 2023
2 changes: 1 addition & 1 deletion benchmarks/benchmark_latency.py
@@ -70,7 +70,7 @@ def run_to_completion(profile: bool = False):
parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--quantization',
'-q',
choices=['awq', None],
choices=['awq', 'squeezellm', None],
default=None)
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
parser.add_argument('--input-len', type=int, default=32)
2 changes: 1 addition & 1 deletion benchmarks/benchmark_throughput.py
@@ -201,7 +201,7 @@ def main(args: argparse.Namespace):
parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument('--quantization',
'-q',
choices=['awq', None],
choices=['awq', 'squeezellm', None],
default=None)
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
parser.add_argument("--n",
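Both benchmark scripts accept the new value through the same argparse pattern. Below is a minimal, self-contained sketch of that pattern (a standalone illustration, not the benchmark scripts themselves):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--quantization',
                    '-q',
                    choices=['awq', 'squeezellm', None],
                    default=None)
args = parser.parse_args(['--quantization', 'squeezellm'])
print(args.quantization)  # -> 'squeezellm'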
12 changes: 8 additions & 4 deletions csrc/quantization.cpp
@@ -7,9 +7,13 @@ torch::Tensor awq_gemm(
torch::Tensor _zeros,
int split_k_iters);

void squeezellm_gemm(
torch::Tensor vec,
torch::Tensor mat,
torch::Tensor mul,
torch::Tensor lookup_table);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"awq_gemm",
&awq_gemm,
"Quantized GEMM for AWQ");
m.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
m.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
}
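The binding makes squeezellm_gemm callable from Python through the quantization_ops extension module. Below is a hedged sketch of a direct call with the tensor shapes the kernel expects; the shapes mirror the SqueezeLLM linear layers added later in this PR, and the random data is illustrative only:

import torch
from vllm import quantization_ops  # extension built from csrc/quantization.cpp

batch, in_features, out_features = 4, 4096, 4096
pack_factor = 8  # eight 4-bit indices per int32
vec = torch.randn(batch, in_features, dtype=torch.float16, device="cuda")
qweight = torch.randint(0, 2**31 - 1, (in_features // pack_factor, out_features),
                        dtype=torch.int32, device="cuda")
lookup_table = torch.randn(out_features, 16, dtype=torch.float16, device="cuda")
# The kernel accumulates with atomicAdd, so the output must be zero-initialized.
out = torch.zeros(batch, out_features, dtype=torch.float16, device="cuda")
quantization_ops.squeezellm_gemm(vec, qweight, out, lookup_table)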
148 changes: 148 additions & 0 deletions csrc/quantization/squeezellm/quant_cuda_kernel.cu
@@ -0,0 +1,148 @@
#include <torch/all.h>
#include <torch/python.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>

// half-tensor
#include <c10/cuda/CUDAStream.h>
#include <ATen/cuda/CUDATensorMethods.cuh>

#define BLOCKWIDTH 128
#define BLOCKHEIGHT4 16

namespace vllm {
namespace squeezellm {

__device__ inline unsigned int as_unsigned(int i) {
return *reinterpret_cast<unsigned int*>(&i);
}

// 4-bit matvec kernel (LUT-based)
__global__ void NUQ4MatMulKernel(
const half2* __restrict__ vec,
const int* __restrict__ mat,
half2* __restrict__ mul,
const __half* __restrict__ lookup_table,
int height,
int width,
int batch,
int vec_height
) {

const int blockwidth2 = BLOCKWIDTH / 2;

int row = BLOCKHEIGHT4 * blockIdx.x;
int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;

__shared__ half2 blockvec[blockwidth2];

__shared__ __half deq2[16][BLOCKWIDTH];
int off = threadIdx.x;
int column_offset = col * 16;
for (int val = 0; val < 16; val += 1) {
int lut_index = column_offset + val;
deq2[val][off] = lookup_table[lut_index];
}

__half res;
half2 res2;
half2 tmp2;

int i;
int k;

unsigned int tmp1;
unsigned int lut_index1, lut_index2;

for (int b = 0; b < batch; ++b){
i = width * row + col;
res = __int2half_rd(0);
k = 0;

__syncthreads();
if (threadIdx.x < blockwidth2)
blockvec[threadIdx.x] = vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 + threadIdx.x];
__syncthreads();

while (k < blockwidth2) {
tmp1 = as_unsigned(mat[i]);

res2 = {};
tmp2 = {};

lut_index1 = tmp1 & 0xF;
lut_index2 = (tmp1 >> 4) & 0xF;
tmp2.x = deq2[lut_index1][off];
tmp2.y = deq2[lut_index2][off];
res2 = __hfma2(tmp2, blockvec[k + 0], res2);

lut_index1 = (tmp1 >> 8) & 0xF;
lut_index2 = (tmp1 >> 12) & 0xF;
tmp2.x = deq2[lut_index1][off];
tmp2.y = deq2[lut_index2][off];
res2 = __hfma2(tmp2, blockvec[k + 1], res2);

lut_index1 = (tmp1 >> 16) & 0xF;
lut_index2 = (tmp1 >> 20) & 0xF;
tmp2.x = deq2[lut_index1][off];
tmp2.y = deq2[lut_index2][off];
res2 = __hfma2(tmp2, blockvec[k + 2], res2);

lut_index1 = (tmp1 >> 24) & 0xF;
lut_index2 = (tmp1 >> 28) & 0xF;
tmp2.x = deq2[lut_index1][off];
tmp2.y = deq2[lut_index2][off];
res2 = __hfma2(tmp2, blockvec[k + 3], res2);

res = __hadd(__hadd(res2.x, res2.y), res);

i += width;
k += 4;
}

// col%2 -> only set one of the two values
half2 res3 = {};
if (col % 2 == 0) {
res3.x = res;
} else {
res3.y = res;
}

atomicAdd(&mul[b * width / 2 + col / 2], res3);
}
}

} // namespace squeezellm
} // namespace vllm

// 4-bit matvec kernel (LUT-based)
void squeezellm_gemm(
torch::Tensor vec,
torch::Tensor mat,
torch::Tensor mul,
torch::Tensor lookup_table
) {
int height = mat.size(0);
int width = mat.size(1);

int batch = vec.size(0);
int vec_height = vec.size(1);

dim3 blocks(
(height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
(width + BLOCKWIDTH - 1) / BLOCKWIDTH
);
dim3 threads(BLOCKWIDTH);

vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads>>>(
(half2*) vec.data<at::Half>(),
mat.data_ptr<int>(),
(half2*) mul.data<at::Half>(),
(__half*) lookup_table.data<at::Half>(),
height, width, batch, vec_height
);
}

#undef BLOCKWIDTH
#undef BLOCKHEIGHT4
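To spell out the packing scheme the kernel assumes: each int32 in the quantized matrix holds eight 4-bit indices along the input dimension (least-significant nibble first), and each output channel has its own 16-entry half-precision lookup table. The following pure-PyTorch reference reproduces the same math on the CPU; it is an illustrative sketch, not part of the PR, and it returns the product directly instead of accumulating into a preallocated output with atomicAdd.

import torch

def squeezellm_matmul_ref(vec, qweight, lookup_table):
    # vec:          (batch, in_features) activations
    # qweight:      (in_features // 8, out_features) int32, eight 4-bit LUT indices per int32
    # lookup_table: (out_features, 16) per-output-channel centroids
    packed_rows, out_features = qweight.shape
    shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=qweight.device)
    # Unpack nibbles; masking with 0xF keeps the arithmetic shift of negative int32 correct.
    idx = ((qweight.unsqueeze(1) >> shifts.view(1, -1, 1)) & 0xF).long()
    idx = idx.reshape(packed_rows * 8, out_features)          # (in_features, out_features)
    # Dequantize: weight[i, j] = lookup_table[j, idx[i, j]]
    weight = torch.gather(lookup_table.t().contiguous(), 0, idx)
    return (vec.float() @ weight.float()).to(vec.dtype)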
1 change: 1 addition & 0 deletions setup.py
@@ -200,6 +200,7 @@ def get_torch_arch_list() -> Set[str]:
sources=[
"csrc/quantization.cpp",
"csrc/quantization/awq/gemm_kernels.cu",
"csrc/quantization/squeezellm/quant_cuda_kernel.cu",
],
extra_compile_args={
"cxx": CXX_FLAGS,
2 changes: 1 addition & 1 deletion vllm/config.py
@@ -103,7 +103,7 @@ def _verify_tokenizer_mode(self) -> None:
self.tokenizer_mode = tokenizer_mode

def _verify_quantization(self) -> None:
supported_quantization = ["awq"]
supported_quantization = ["awq", "squeezellm"]
if self.quantization is None:
return
quantization = self.quantization.lower()
2 changes: 1 addition & 1 deletion vllm/engine/arg_utils.py
@@ -168,7 +168,7 @@ def add_cli_args(
parser.add_argument('--quantization',
'-q',
type=str,
choices=['awq', None],
choices=['awq', 'squeezellm', None],
default=None,
help='Method used to quantize the weights')
return parser
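With the config check and the CLI flag in place, enabling the new path from user code would look roughly like the sketch below (hedged; the checkpoint path is hypothetical and must point to a SqueezeLLM-quantized model):

from vllm import LLM, SamplingParams

# Hypothetical checkpoint path; the weights must already be SqueezeLLM-quantized.
llm = LLM(model="path/to/squeezellm-llama-7b", quantization="squeezellm")
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)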
4 changes: 4 additions & 0 deletions vllm/model_executor/layers/quantized_linear/__init__.py
@@ -1,10 +1,14 @@
from vllm.model_executor.layers.quantized_linear.awq import (
AWQColumnParallelLinear, AWQRowParallelLinear)
from vllm.model_executor.layers.quantized_linear.squeezellm import (
SqueezeLLMColumnParallelLinear, SqueezeLLMRowParallelLinear)
from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
RowParallelLinear)

_QUANTIZED_LINEAR_REGISTRY = {
"awq": (AWQColumnParallelLinear, AWQRowParallelLinear),
"squeezellm":
(SqueezeLLMColumnParallelLinear, SqueezeLLMRowParallelLinear),
}


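The registry maps a quantization method name to its (column-parallel, row-parallel) pair of linear classes. A small illustrative lookup, run in the context of the module above (the helper name is hypothetical, not part of the PR):

def _get_quantized_linear_classes(quant_method: str):
    # Hypothetical helper: resolve the pair of parallel linear classes for a method.
    try:
        return _QUANTIZED_LINEAR_REGISTRY[quant_method]
    except KeyError as e:
        raise ValueError(f"Unsupported quantization method: {quant_method}") from e

col_cls, row_cls = _get_quantized_linear_classes("squeezellm")
# -> (SqueezeLLMColumnParallelLinear, SqueezeLLMRowParallelLinear)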
14 changes: 9 additions & 5 deletions vllm/model_executor/layers/quantized_linear/awq.py
@@ -11,9 +11,11 @@
class AWQColumnParallelLinear(ColumnParallelLinear):

def create_weights(self, dtype: torch.dtype) -> None:
assert self.input_size % self.quant_config.weight_bits == 0
assert (self.output_size_per_partition %
self.quant_config.pack_factor == 0)
assert self.input_size % self.quant_config.group_size == 0
if self.output_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError(
"The tensor parallel size is not aligned with the quantized "
"weight shape. Please use a different tensor parallel size.")
self.qweight = Parameter(
torch.empty(
self.input_size,
@@ -62,9 +64,11 @@ def apply_weights(
class AWQRowParallelLinear(RowParallelLinear):

def create_weights(self, dtype: torch.dtype) -> None:
assert (self.input_size_per_partition %
self.quant_config.weight_bits == 0)
assert self.output_size % self.quant_config.pack_factor == 0
if self.input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError(
"The tensor parallel size is not aligned with the quantized "
"weight shape. Please use a different tensor parallel size.")
self.qweight = Parameter(
torch.empty(
self.input_size_per_partition,
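The AWQ layers now raise a descriptive ValueError instead of asserting when the partitioned dimension does not line up with group_size or pack_factor. A hedged numeric sketch of the check (sizes are illustrative; group_size 128 and pack_factor 8 are typical 4-bit AWQ settings):

group_size, pack_factor = 128, 8
output_size, tp_size = 11008, 4
output_size_per_partition = output_size // tp_size        # 2752
if output_size_per_partition % pack_factor != 0:
    raise ValueError("The tensor parallel size is not aligned with the quantized "
                     "weight shape. Please use a different tensor parallel size.")
# 2752 % 8 == 0, so tp_size=4 is accepted for this projection.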
84 changes: 84 additions & 0 deletions vllm/model_executor/layers/quantized_linear/squeezellm.py
@@ -0,0 +1,84 @@
from typing import Optional

import torch
from torch.nn.parameter import Parameter

from vllm import quantization_ops
from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
RowParallelLinear)


class SqueezeLLMColumnParallelLinear(ColumnParallelLinear):

def create_weights(self, dtype: torch.dtype) -> None:
assert self.input_size % self.quant_config.pack_factor == 0
self.qweight = Parameter(
torch.empty(
self.input_size // self.quant_config.pack_factor,
self.output_size_per_partition,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.lookup_table = Parameter(
torch.empty(
self.output_size_per_partition,
self.quant_config.weight_bits**2,
device="cuda",
dtype=dtype,
),
requires_grad=False,
)

def apply_weights(
self,
x: torch.Tensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
out_shape = x.shape[:-1] + (self.qweight.shape[-1], )
reshaped_x = x.reshape(-1, x.shape[-1])
# NOTE: The output tensor should be zero-initialized.
out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
quantization_ops.squeezellm_gemm(reshaped_x, self.qweight, out,
self.lookup_table)

if bias is not None:
out = out + bias
return out.reshape(out_shape)


class SqueezeLLMRowParallelLinear(RowParallelLinear):

def create_weights(self, dtype: torch.dtype) -> None:
if self.input_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError(
"The tensor parallel size is not aligned with the quantized "
"weight shape. Please use a different tensor parallel size.")
self.qweight = Parameter(
torch.empty(
self.input_size_per_partition // self.quant_config.pack_factor,
self.output_size,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.lookup_table = Parameter(
torch.empty(
self.output_size,
self.quant_config.weight_bits**2,
device="cuda",
dtype=dtype,
),
requires_grad=False,
)

def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
reshaped_x = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (self.qweight.shape[-1], )
# NOTE: The output tensor should be zero-initialized.
out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
quantization_ops.squeezellm_gemm(reshaped_x, self.qweight, out,
self.lookup_table)
return out.reshape(out_shape)
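A hedged shape walkthrough for the new layers, taking tensor parallel size 1 for simplicity and illustrative layer sizes: with pack_factor 8, a 4096 x 11008 projection stores qweight as a (512, 11008) int32 tensor and lookup_table as an (11008, 16) tensor in the model dtype; for 4-bit weights, weight_bits ** 2 and 2 ** weight_bits both evaluate to 16 centroids per output channel.

# Illustrative shapes for SqueezeLLMColumnParallelLinear with tensor parallel size 1.
in_features, out_features = 4096, 11008
weight_bits, pack_factor = 4, 8

qweight_shape = (in_features // pack_factor, out_features)   # (512, 11008), torch.int32
lookup_table_shape = (out_features, weight_bits ** 2)        # (11008, 16), model dtype
print(qweight_shape, lookup_table_shape)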
24 changes: 15 additions & 9 deletions vllm/model_executor/models/llama.py
@@ -314,17 +314,21 @@ def load_weights(self,
load_format: str = "auto",
revision: Optional[str] = None):
if self.quant_config is None:
weight_suffixes = ["weight"]
col_weight_suffixes = ["weight"]
row_weight_suffixes = ["weight"]
else:
weight_suffixes = self.quant_config.get_tp_tensor_names()
col_weight_suffixes = (
self.quant_config.get_col_parallel_tensor_names())
row_weight_suffixes = (
self.quant_config.get_row_parallel_tensor_names())

column_parallel_weights: List[str] = []
for layer in self._column_parallel_layers:
for suffix in weight_suffixes:
for suffix in col_weight_suffixes:
column_parallel_weights.append(f"{layer}.{suffix}")
row_parallel_weights: List[str] = []
for layer in self._row_parallel_layers:
for suffix in weight_suffixes:
for suffix in row_weight_suffixes:
row_parallel_weights.append(f"{layer}.{suffix}")

tp_size = get_tensor_model_parallel_world_size()
@@ -351,10 +355,10 @@ def load_weights(self,
if "rotary_emb.inv_freq" in name:
continue

is_packed = False
packed_dim = None
is_transposed = False
if self.quant_config is not None:
is_packed = self.quant_config.is_packed(name)
packed_dim = self.quant_config.get_packed_dim(name)
is_transposed = self.quant_config.is_transposed(name)
if is_transposed:
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
Expand All @@ -368,9 +372,11 @@ def load_weights(self,
if is_transposed:
param = param.T

if is_packed:
shard_size //= self.quant_config.pack_factor
offset //= self.quant_config.pack_factor
if packed_dim is not None:
shard_dim = 0 if not is_transposed else 1
if packed_dim == shard_dim:
shard_size //= self.quant_config.pack_factor
offset //= self.quant_config.pack_factor

if weight_name in ["k_proj", "v_proj"]:
shard_id = tp_rank // num_kv_heads_replicas
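The loader change replaces the boolean is_packed with packed_dim so that the pack_factor rescaling is applied only when the packed dimension is the one being sharded. A hedged sketch of just that conditional (all values are illustrative, not taken from a specific layer):

pack_factor = 8
shard_size, offset = 4096, 8192     # slice of the checkpoint tensor along the shard dim
is_transposed = True                # illustrative: quantized tensors may be stored transposed
packed_dim = 0                      # illustrative: dimension along which values are packed
shard_dim = 0 if not is_transposed else 1

if packed_dim is not None and packed_dim == shard_dim:
    # Packed and sharded dimensions coincide: rescale the slice bookkeeping.
    shard_size //= pack_factor
    offset //= pack_factor

print(shard_size, offset)   # 4096 8192 here, because packed_dim != shard_dim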