
Commit 450ae2a

support bitsandbytes 8-bit and FP4 quantized models
1 parent 4ddc474 commit 450ae2a

File tree

5 files changed, +414 -186 lines changed


tests/quantization/test_bitsandbytes.py

Lines changed: 89 additions & 63 deletions
@@ -5,82 +5,108 @@
 import pytest
 import torch
 
+from tests.conftest import VllmRunner
 from tests.quantization.utils import is_quant_method_supported
 from vllm import SamplingParams
 
-models_to_test = [
+models_4bit_to_test = [
     ('huggyllama/llama-7b', 'quantize model inflight'),
-    ('lllyasviel/omost-llama-3-8b-4bits', 'read pre-quantized model'),
+    ('lllyasviel/omost-llama-3-8b-4bits',
+     'read pre-quantized 4-bit NF4 model'),
+    ('PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed',
+     'read pre-quantized 4-bit FP4 model'),
+]
+
+models_8bit_to_test = [
+    ('meta-llama/Llama-Guard-3-8B-INT8', 'read pre-quantized 8-bit model'),
 ]
 
 
 @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                     reason='bitsandbytes is not supported on this GPU type.')
-@pytest.mark.parametrize("model_name, description", models_to_test)
-def test_load_bnb_model(vllm_runner, model_name, description) -> None:
+@pytest.mark.parametrize("model_name, description", models_4bit_to_test)
+def test_load_4bit_bnb_model(vllm_runner, model_name, description) -> None:
     with vllm_runner(model_name,
                      quantization='bitsandbytes',
                      load_format='bitsandbytes',
                      enforce_eager=True) as llm:
         model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
 
         # check the weights in MLP & SelfAttention are quantized to torch.uint8
-        qweight = model.model.layers[0].mlp.gate_up_proj.qweight
-        assert qweight.dtype == torch.uint8, (
-            f'Expected gate_up_proj dtype torch.uint8 but got {qweight.dtype}')
-
-        qweight = model.model.layers[0].mlp.down_proj.qweight
-        assert qweight.dtype == torch.uint8, (
-            f'Expected down_proj dtype torch.uint8 but got {qweight.dtype}')
-
-        qweight = model.model.layers[0].self_attn.o_proj.qweight
-        assert qweight.dtype == torch.uint8, (
-            f'Expected o_proj dtype torch.uint8 but got {qweight.dtype}')
-
-        qweight = model.model.layers[0].self_attn.qkv_proj.qweight
-        assert qweight.dtype == torch.uint8, (
-            f'Expected qkv_proj dtype torch.uint8 but got {qweight.dtype}')
-
-        # some weights should not be quantized
-        weight = model.lm_head.weight
-        assert weight.dtype != torch.uint8, (
-            'lm_head weight dtype should not be torch.uint8')
-
-        weight = model.model.embed_tokens.weight
-        assert weight.dtype != torch.uint8, (
-            'embed_tokens weight dtype should not be torch.uint8')
-
-        weight = model.model.layers[0].input_layernorm.weight
-        assert weight.dtype != torch.uint8, (
-            'input_layernorm weight dtype should not be torch.uint8')
-
-        weight = model.model.layers[0].post_attention_layernorm.weight
-        assert weight.dtype != torch.uint8, (
-            'input_layernorm weight dtype should not be torch.uint8')
-
-        # check the output of the model is expected
-        sampling_params = SamplingParams(temperature=0.0,
-                                         logprobs=1,
-                                         prompt_logprobs=1,
-                                         max_tokens=8)
-
-        prompts = ['That which does not kill us', 'To be or not to be,']
-        expected_outputs = [
-            'That which does not kill us makes us stronger.',
-            'To be or not to be, that is the question.'
-        ]
-        outputs = llm.generate(prompts, sampling_params=sampling_params)
-        assert len(outputs) == len(prompts)
-
-        for index in range(len(outputs)):
-            # compare the first line of the output
-            actual_output = outputs[index][1][0].split('\n', 1)[0]
-            expected_output = expected_outputs[index].split('\n', 1)[0]
-
-            assert len(actual_output) >= len(expected_output), (
-                f'Actual {actual_output} should be larger than or equal to '
-                f'expected {expected_output}')
-            actual_output = actual_output[:len(expected_output)]
-
-            assert actual_output == expected_output, (
-                f'Expected: {expected_output}, but got: {actual_output}')
+        validate_model_weight_type(model, torch.uint8)
+
+        validate_model_output(llm)
+
+
+@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
+                    reason='bitsandbytes is not supported on this GPU type.')
+@pytest.mark.parametrize("model_name, description", models_8bit_to_test)
+def test_load_8bit_bnb_model(vllm_runner, model_name, description) -> None:
+    with vllm_runner(model_name,
+                     quantization='bitsandbytes',
+                     load_format='bitsandbytes',
+                     enforce_eager=True) as llm:
+        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
+
+        # check the weights in MLP & SelfAttention are quantized to torch.int8
+        validate_model_weight_type(model, torch.int8)
+
+        validate_model_output(llm)
+
+
+def validate_model_weight_type(model, quantized_dtype=torch.uint8):
+    # Check quantized weights
+    quantized_layers = [('mlp.gate_up_proj.qweight',
+                         model.model.layers[0].mlp.gate_up_proj.qweight),
+                        ('mlp.down_proj.qweight',
+                         model.model.layers[0].mlp.down_proj.qweight),
+                        ('self_attn.o_proj.qweight',
+                         model.model.layers[0].self_attn.o_proj.qweight),
+                        ('self_attn.qkv_proj.qweight',
+                         model.model.layers[0].self_attn.qkv_proj.qweight)]
+
+    for name, qweight in quantized_layers:
+        assert qweight.dtype == quantized_dtype, (
+            f'Expected {name} dtype {quantized_dtype} but got {qweight.dtype}')
+
+    # Check non-quantized weights
+    non_quantized_layers = [
+        ('lm_head.weight', model.lm_head.weight),
+        ('embed_tokens.weight', model.model.embed_tokens.weight),
+        ('input_layernorm.weight',
+         model.model.layers[0].input_layernorm.weight),
+        ('post_attention_layernorm.weight',
+         model.model.layers[0].post_attention_layernorm.weight)
+    ]
+
+    for name, weight in non_quantized_layers:
+        assert weight.dtype != quantized_dtype, (
+            f'{name} dtype should not be {quantized_dtype}')
+
+
+def validate_model_output(llm: VllmRunner):
+    sampling_params = SamplingParams(temperature=0.0,
+                                     logprobs=1,
+                                     prompt_logprobs=1,
+                                     max_tokens=8)
+
+    prompts = ['That which does not kill us', 'To be or not to be,']
+    expected_outputs = [
+        'That which does not kill us makes us stronger.',
+        'To be or not to be, that is the question.'
+    ]
+    outputs = llm.generate(prompts, sampling_params=sampling_params)
+    assert len(outputs) == len(prompts)
+
+    for index in range(len(outputs)):
+        # compare the first line of the output
+        actual_output = outputs[index][1][0].split('\n', 1)[0]
+        expected_output = expected_outputs[index].split('\n', 1)[0]
+
+        assert len(actual_output) >= len(expected_output), (
+            f'Actual {actual_output} should be larger than or equal to '
+            f'expected {expected_output}')
+        actual_output = actual_output[:len(expected_output)]
+
+        assert actual_output == expected_output, (
+            f'Expected: {expected_output}, but got: {actual_output}')
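
The new tests exercise the quantized path through the test runner; outside the test suite, the same configuration can be reached through vLLM's public API. A minimal usage sketch, assuming a bitsandbytes-capable GPU (the checkpoint name and generation settings are taken from the test lists above and are illustrative only, not part of the commit):

# Minimal sketch: loading one of the pre-quantized checkpoints from the
# test lists above via vLLM's public API.
from vllm import LLM, SamplingParams

llm = LLM(model='lllyasviel/omost-llama-3-8b-4bits',  # or the FP4 / INT8 checkpoints
          quantization='bitsandbytes',
          load_format='bitsandbytes',
          enforce_eager=True)

outputs = llm.generate(['That which does not kill us'],
                       SamplingParams(temperature=0.0, max_tokens=8))
print(outputs[0].outputs[0].text)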

vllm/config.py

Lines changed: 2 additions & 0 deletions
@@ -326,6 +326,8 @@ def verify_with_parallel_config(
             raise ValueError(
                 "BitAndBytes quantization with TP or PP is not supported yet.")
 
+        # Remove the constraint after the bitsandbytes issue is fixed:
+        # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1308
         if self.quantization == "bitsandbytes" and self.enforce_eager is False:
             logger.warning("CUDA graph is not supported on BitAndBytes yet, "
                            "fallback to the eager mode.")

vllm/model_executor/layers/linear.py

Lines changed: 10 additions & 8 deletions
@@ -31,9 +31,9 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
     return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
 
 
-def adjust_bitsandbytes_shard(param: Parameter,
-                              qkv_offsets: Dict[str, Tuple[int, int]],
-                              loaded_shard_id: str) -> Tuple[int, int]:
+def adjust_bitsandbytes_4bit_shard(param: Parameter,
+                                   qkv_offsets: Dict[str, Tuple[int, int]],
+                                   loaded_shard_id: str) -> Tuple[int, int]:
     """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
 
     total, _ = qkv_offsets["total"]
@@ -497,8 +497,9 @@ def weight_loader(self,
                 shard_size, shard_offset = adjust_marlin_shard(
                     param, shard_size, shard_offset)
 
-            use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
-            if use_bitsandbytes:
+            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
+                                            False)
+            if use_bitsandbytes_4bit:
                 shard_size = loaded_weight.shape[output_dim]
                 shard_offset = loaded_weight.shape[output_dim] * \
                     loaded_shard_id
@@ -843,8 +844,9 @@ def weight_loader(self,
                 shard_size, shard_offset = adjust_marlin_shard(
                     param, shard_size, shard_offset)
 
-            use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
-            if use_bitsandbytes:
+            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
+                                            False)
+            if use_bitsandbytes_4bit:
                 orig_qkv_offsets = {
                     "q": (0, self.num_heads * self.head_size),
                     "k": (self.num_heads * self.head_size,
@@ -856,7 +858,7 @@ def weight_loader(self,
                     ((self.num_heads + 2 * self.num_kv_heads) * self.head_size,
                      0)
                 }
-                shard_size, shard_offset = adjust_bitsandbytes_shard(
+                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                     param, orig_qkv_offsets, loaded_shard_id)
 
         if is_gguf_weight:
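
The renamed helper maps the logical q/k/v slice offsets onto the packed 4-bit parameter, whose leading dimension is smaller than the unquantized layout. A standalone sketch of that rescaling follows; only the signature and the "total" lookup come from the diff above, the rest is an illustrative reconstruction rather than vLLM's exact code:

from typing import Dict, Tuple

import torch


def adjust_4bit_shard_sketch(packed_param: torch.Tensor,
                             qkv_offsets: Dict[str, Tuple[int, int]],
                             loaded_shard_id: str) -> Tuple[int, int]:
    # "total" holds the full unpacked output size; the packed tensor's
    # leading dimension is proportionally smaller (two 4-bit values per byte).
    total, _ = qkv_offsets["total"]
    offset, size = qkv_offsets[loaded_shard_id]
    quantized_total = packed_param.shape[0]
    # Rescale the logical offset/size onto the packed storage.
    quantized_offset = offset * quantized_total // total
    quantized_size = size * quantized_total // total
    return quantized_size, quantized_offset


# Example with 32 heads, 8 KV heads, head_size 128, packed to half the rows:
num_heads, num_kv_heads, head_size = 32, 8, 128
offsets = {
    "q": (0, num_heads * head_size),
    "k": (num_heads * head_size, num_kv_heads * head_size),
    "v": ((num_heads + num_kv_heads) * head_size, num_kv_heads * head_size),
    "total": ((num_heads + 2 * num_kv_heads) * head_size, 0),
}
packed = torch.empty(((num_heads + 2 * num_kv_heads) * head_size) // 2, 1,
                     dtype=torch.uint8)
print(adjust_4bit_shard_sketch(packed, offsets, "k"))  # (512, 2048)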
