Commit a2ae496

[CPU] Support FP8 KV cache (#14741)
Signed-off-by: jiang1.li <[email protected]>
1 parent 877e352 commit a2ae496

8 files changed (+122, -36 lines)

csrc/cpu/cache.cpp

Lines changed: 21 additions & 17 deletions

@@ -3,6 +3,12 @@
 
 #include "cpu_types.hpp"
 
+#if defined(__x86_64__)
+  #define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES_WITH_E5M2
+#else
+  #define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES
+#endif
+
 namespace {
 template <typename scalar_t>
 void copy_blocks_cpu_impl(std::vector<torch::Tensor> const& key_caches,
@@ -95,13 +101,12 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
   }
 
   const int element_num_per_block = key_caches[0][0].numel();
-  VLLM_DISPATCH_FLOATING_TYPES(
-      key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
-        CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
-        copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping,
-                                       element_num_per_block, num_layers);
-        CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
-      });
+  DISPATCH_MACRO(key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
+    CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
+    copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping,
+                                   element_num_per_block, num_layers);
+    CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
+  });
 }
 
 void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
@@ -118,16 +123,15 @@ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
   int key_stride = key.stride(0);
   int value_stride = value.stride(0);
 
-  VLLM_DISPATCH_FLOATING_TYPES(
-      key.scalar_type(), "reshape_and_cache_cpu_impl", [&] {
-        CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl)
-        reshape_and_cache_cpu_impl<scalar_t>(
-            key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
-            key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
-            slot_mapping.data_ptr<int64_t>(), num_tokens, key_stride,
-            value_stride, num_heads, head_size, block_size, x);
-        CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl)
-      });
+  DISPATCH_MACRO(key.scalar_type(), "reshape_and_cache_cpu_impl", [&] {
+    CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl)
+    reshape_and_cache_cpu_impl<scalar_t>(
+        key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
+        key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
+        slot_mapping.data_ptr<int64_t>(), num_tokens, key_stride, value_stride,
+        num_heads, head_size, block_size, x);
+    CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl)
+  });
 }
 
 void swap_blocks(torch::Tensor& src, torch::Tensor& dst,

csrc/cpu/cpu_types_x86.hpp

Lines changed: 9 additions & 0 deletions

@@ -16,9 +16,18 @@ namespace vec_op {
   AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
   AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
 
+#define VLLM_DISPATCH_CASE_FLOATING_TYPES_FP8(...)         \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)     \
+  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)  \
+  AT_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__)
+
 #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
 
+#define VLLM_DISPATCH_FLOATING_TYPES_WITH_E5M2(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(TYPE, NAME,                                \
+                     VLLM_DISPATCH_CASE_FLOATING_TYPES_FP8(__VA_ARGS__))
+
 #ifndef CPU_OP_GUARD
 #define CPU_KERNEL_GUARD_IN(NAME)
 #define CPU_KERNEL_GUARD_OUT(NAME)

docs/source/getting_started/installation/cpu.md

Lines changed: 1 addition & 1 deletion

@@ -189,7 +189,7 @@ vLLM CPU backend supports the following vLLM features:
 - Model Quantization (`INT8 W8A8, AWQ, GPTQ`)
 - Chunked-prefill
 - Prefix-caching
-- FP8-E5M2 KV-Caching (TODO)
+- FP8-E5M2 KV cache
 
 ## Related runtime environment variables
 
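The docs now list FP8-E5M2 KV cache as supported on the CPU backend. For reference only (not part of this commit; the model name and prompt are illustrative), a minimal sketch of enabling it through the offline LLM API:

```python
# Hypothetical usage sketch: enable the FP8-E5M2 KV cache on the CPU backend
# via the offline LLM API.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",  # illustrative checkpoint
    dtype="bfloat16",           # CPU FP8 KV cache is paired with bf16 weights
    kv_cache_dtype="fp8_e5m2",  # KV cache dtype added for CPU in this commit
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)
```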

tests/basic_correctness/test_chunked_prefill.py

Lines changed: 2 additions & 2 deletions

@@ -266,7 +266,7 @@ def test_with_prefix_caching(
 
 
 @pytest.mark.parametrize("model", ["facebook/opt-125m"])
-@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("dtype", ["bfloat16", "half"])
 @pytest.mark.parametrize("max_tokens", [32])
 @pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
 @pytest.mark.parametrize("enforce_eager", [False])
@@ -303,7 +303,7 @@ def test_models_cpu(
 @pytest.mark.parametrize("max_tokens", [16])
 @pytest.mark.parametrize("enforce_eager", [False])
 @pytest.mark.parametrize("chunk_size", [30, 32])
-@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("dtype", ["bfloat16", "half"])
 @pytest.mark.cpu_model
 @pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
 def test_with_prefix_caching_cpu(

tests/models/decoder_only/language/test_fp8.py

Lines changed: 61 additions & 0 deletions

@@ -11,6 +11,7 @@
 
 from tests.kernels.utils import override_backend_env_variable
 from tests.quantization.utils import is_quant_method_supported
+from vllm.platforms import current_platform
 
 from ...utils import check_logprobs_close
 
@@ -93,3 +94,63 @@ def test_models(
         name_0="fp16_kv_cache",
         name_1="fp8_kv_cache",
     )
+
+
+@pytest.mark.cpu_model
+@pytest.mark.skipif(not current_platform.is_cpu(),
+                    reason="test for the CPU backend.")
+@pytest.mark.parametrize(
+    "kv_cache_dtype,base_model,test_model",
+    [
+        # Test BF16 checkpoint w. fp8_e5m2 kv-cache.
+        ("fp8_e5m2", "meta-llama/Llama-3.2-1B-Instruct",
+         "meta-llama/Llama-3.2-1B-Instruct"),
+    ])
+# Due to low-precision numerical divergence, we only test logprob of 4 tokens
+@pytest.mark.parametrize("max_tokens", [4])
+# Due to low-precision numerical divergence, this test is too sensitive for
+# the async postprocessor
+@pytest.mark.parametrize("disable_async_output_proc", [True])
+def test_cpu_models(
+    vllm_runner,
+    example_prompts,
+    kv_cache_dtype: str,
+    base_model: str,
+    test_model: str,
+    max_tokens: int,
+    disable_async_output_proc: bool,
+) -> None:
+    """
+    Only checks log probs match to cover the discrepancy in
+    numerical sensitive kernels.
+    """
+
+    MAX_MODEL_LEN = 1024
+    NUM_LOG_PROBS = 8
+
+    with vllm_runner(
+            base_model,
+            max_model_len=MAX_MODEL_LEN,
+            dtype="bfloat16",
+            kv_cache_dtype="auto",
+            disable_async_output_proc=disable_async_output_proc,
+    ) as vllm_model:
+        baseline_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, NUM_LOG_PROBS)
+
+    with vllm_runner(
+            test_model,
+            max_model_len=MAX_MODEL_LEN,
+            dtype="bfloat16",
+            kv_cache_dtype=kv_cache_dtype,
+            disable_async_output_proc=disable_async_output_proc,
+    ) as vllm_model:
+        test_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, NUM_LOG_PROBS)
+
+    check_logprobs_close(
+        outputs_0_lst=baseline_outputs,
+        outputs_1_lst=test_outputs,
+        name_0="bf16_kv_cache",
+        name_1="fp8_kv_cache",
+    )

vllm/attention/backends/torch_sdpa.py

Lines changed: 5 additions & 4 deletions

@@ -17,7 +17,7 @@
                                              is_quantized_kv_cache)
 # yapf: enable
 from vllm.attention.backends.utils import CommonAttentionState
-from vllm.attention.ops.ipex_attn import PagedAttention
+from vllm.attention.ops.ipex_attn import PagedAttention, _use_ipex
 from vllm.attention.ops.paged_attn import PagedAttentionMetadata
 from vllm.logger import init_logger
 from vllm.utils import make_tensor_with_pad
@@ -431,10 +431,11 @@ def __init__(
             raise ValueError(
                 f"Head size {head_size} is not supported by PagedAttention. "
                 f"Supported head sizes are: {supported_head_sizes}.")
-        if is_quantized_kv_cache(kv_cache_dtype):
+
+        if is_quantized_kv_cache(kv_cache_dtype) and not _use_ipex:
             raise NotImplementedError(
-                "Torch SDPA backend does not support FP8 KV cache. "
-                "Please use xFormers backend instead.")
+                "Torch SDPA backend FP8 KV cache requires "
+                "intel_extension_for_pytorch support.")
         self.attn_type = attn_type
 
     def forward(

vllm/platforms/cpu.py

Lines changed: 19 additions & 11 deletions

@@ -60,16 +60,32 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         # Reminder: Please update docs/source/features/compatibility_matrix.md
         # If the feature combo become valid
         if not model_config.enforce_eager:
-            logger.warning(
-                "CUDA graph is not supported on CPU, fallback to the eager "
-                "mode.")
             model_config.enforce_eager = True
 
         cache_config = vllm_config.cache_config
 
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16
 
+        scheduler_config = vllm_config.scheduler_config
+        if ((scheduler_config.chunked_prefill_enabled
+             or cache_config.enable_prefix_caching)
+                and cache_config.cache_dtype != "auto"):
+            raise RuntimeError("Chunked-prefill and prefix-cache on the CPU "
+                               "backend is not compatible with FP8 KV cache.")
+
+        if cache_config.cache_dtype == "fp8_e4m3":
+            cache_config.cache_dtype = "fp8_e5m2"
+            logger.warning(
+                "CPU backend doesn't support fp8_e4m3 KV cache type, "
+                "cast to fp8_e5m2.")
+
+        if (cache_config.cache_dtype != "auto"
+                and model_config.dtype == torch.half):
+            logger.warning("FP8 KV cache on the CPU backend only does not"
+                           " support fp16 for now, cast to bf16.")
+            model_config.dtype = torch.bfloat16
+
         kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE
 
         if kv_cache_space >= 0:
@@ -85,14 +101,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 "Invalid environment variable VLLM_CPU_KVCACHE_SPACE"
                 f" {kv_cache_space}, expect a positive integer value.")
 
-        scheduler_config = vllm_config.scheduler_config
-        if ((scheduler_config.chunked_prefill_enabled
-             or cache_config.enable_prefix_caching)
-                and model_config.dtype == torch.half):
-            logger.warning("Chunked-prefill on the CPU backend only does not"
-                           " support fp16 for now, cast to bf16.")
-            model_config.dtype = torch.bfloat16
-
         parallel_config = vllm_config.parallel_config
         if (parallel_config.distributed_executor_backend is not None
                 and parallel_config.distributed_executor_backend != "mp"):
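To summarize the configuration handling added above, here is a standalone sketch; the helper name and flat arguments are hypothetical stand-ins for the real VllmConfig/CacheConfig/ModelConfig objects. It captures the three rules: FP8 KV cache is rejected together with chunked-prefill or prefix-caching, fp8_e4m3 requests fall back to fp8_e5m2, and an fp16 model dtype is cast to bf16 when an FP8 KV cache is requested.

```python
# Hypothetical standalone sketch of the CPU-backend checks added above;
# the real code mutates VllmConfig fields in place.
def normalize_cpu_kv_cache_config(cache_dtype: str, model_dtype: str,
                                  chunked_prefill: bool,
                                  prefix_caching: bool) -> tuple[str, str]:
    # FP8 KV cache cannot be combined with chunked-prefill or prefix-caching.
    if (chunked_prefill or prefix_caching) and cache_dtype != "auto":
        raise RuntimeError("Chunked-prefill/prefix-caching on the CPU backend "
                           "is not compatible with FP8 KV cache.")
    # Only the e5m2 format is supported; downgrade e4m3 requests.
    if cache_dtype == "fp8_e4m3":
        cache_dtype = "fp8_e5m2"
    # FP8 KV cache does not support fp16 model weights yet; cast to bf16.
    if cache_dtype != "auto" and model_dtype == "float16":
        model_dtype = "bfloat16"
    return cache_dtype, model_dtype

# e.g. normalize_cpu_kv_cache_config("fp8_e4m3", "float16", False, False)
# returns ("fp8_e5m2", "bfloat16")
```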

vllm/worker/cpu_worker.py

Lines changed: 4 additions & 1 deletion

@@ -53,8 +53,11 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,
 
         if cache_config.cache_dtype == "auto":
            self.dtype = model_config.dtype
+        elif cache_config.cache_dtype in ["fp8", "fp8_e5m2"]:
+            self.dtype = torch.float8_e5m2
         else:
-            self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
+            raise NotImplementedError(f"Unsupported KV cache type "
+                                      f"{cache_config.cache_dtype}.")
 
         # Get attention backend.
         self.attn_backend = get_attn_backend(
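Since the CPU cache engine now allocates the KV cache in torch.float8_e5m2, the cache footprint is halved relative to a bf16 cache. A back-of-the-envelope sketch, using made-up model dimensions purely for illustration:

```python
# Sketch with illustrative dimensions: per-token KV cache size for a bf16
# cache vs. the float8_e5m2 cache selected above.
import torch

def kv_bytes_per_token(num_layers: int, num_kv_heads: int, head_size: int,
                       dtype: torch.dtype) -> int:
    # One K vector and one V vector of num_kv_heads * head_size elements
    # per layer, each element stored in `dtype`.
    return 2 * num_layers * num_kv_heads * head_size * dtype.itemsize

bf16_bytes = kv_bytes_per_token(16, 8, 64, torch.bfloat16)    # 32768 bytes
fp8_bytes = kv_bytes_per_token(16, 8, 64, torch.float8_e5m2)  # 16384 bytes
assert bf16_bytes == 2 * fp8_bytes
```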
