Skip to content

[Bugfix] Fix kv_cache_dtype=fp8 without scales for FP8 checkpoints #6761

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions tests/quantization/test_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,20 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str):

@pytest.mark.skipif(not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.")
def test_load_fp16_model(vllm_runner) -> None:
with vllm_runner("facebook/opt-125m", quantization="fp8") as llm:
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
def test_load_fp16_model(vllm_runner, kv_cache_dtype: str) -> None:
with vllm_runner("facebook/opt-125m",
quantization="fp8",
kv_cache_dtype=kv_cache_dtype) as llm:

model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
fc1 = model.model.decoder.layers[0].fc1
assert isinstance(fc1.quant_method, Fp8LinearMethod)
if kv_cache_dtype == "fp8":
attn = model.model.decoder.layers[0].self_attn.attn
assert isinstance(attn.quant_method, Fp8KVCacheMethod)
assert attn._k_scale == 1.0
assert attn._v_scale == 1.0

capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
Expand Down
6 changes: 2 additions & 4 deletions vllm/model_executor/layers/quantization/kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
# If no scales were loaded (both scales are invalid negative
# values), use the default value of 1.0
k_scale = torch.nn.Parameter(torch.tensor(1.0),
requires_grad=False)
v_scale = torch.nn.Parameter(torch.tensor(1.0),
requires_grad=False)
k_scale = 1.0
v_scale = 1.0
else:
# If we find a single kv_scale in the checkpoint, we remap
# kv_scale to k_scale during weight loading, and duplicate
Expand Down
Loading