Skip to content

Use KV cache constant names provided by compressed tensors #1200

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 11, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions src/llmcompressor/modifiers/quantization/calibration.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from typing import Any, Dict, Optional, Tuple

import torch
from compressed_tensors.quantization import QuantizationStatus, is_attention_module
from compressed_tensors.quantization import (
KVCacheScaleType,
QuantizationStatus,
is_attention_module,
)
from compressed_tensors.quantization.lifecycle.forward import forward_quantize
from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
from compressed_tensors.utils.offload import is_module_offloaded, update_parameter_data
Expand Down Expand Up @@ -194,8 +198,10 @@ def calibrate_kv_cache_output_hook(module: Module, _args: Any, _output: torch.Te
Hook to update k_scale and v_scale parameters when running kv_cache quantization.
"""
kv_cache = getattr(module, "kv_cache")
update_parameter_data(module, kv_cache.k_scales[module.layer_idx], "k_scale")
update_parameter_data(module, kv_cache.v_scales[module.layer_idx], "v_scale")
k_scale = kv_cache.k_scales[module.layer_idx]
v_scale = kv_cache.v_scales[module.layer_idx]
update_parameter_data(module, k_scale, KVCacheScaleType.KEY.value)
update_parameter_data(module, v_scale, KVCacheScaleType.VALUE.value)


def set_unset_kv_cache(module: Module):
Expand Down