
Commit fc4dd68

kzawora-intel authored and xuechendi committed
Fix kv_cache_dtype handling for out-of-tree HPU plugin (vllm-project#21302)
Signed-off-by: Konrad Zawora <[email protected]>
Signed-off-by: Chendi.Xue <[email protected]>
Co-authored-by: Chendi.Xue <[email protected]>
1 parent f206a26 commit fc4dd68

5 files changed: 30 additions & 16 deletions


vllm/engine/arg_utils.py

Lines changed: 2 additions & 16 deletions
@@ -1352,22 +1352,8 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
 
         # No Fp8 KV cache so far.
         if self.kv_cache_dtype != "auto":
-            fp8_attention = self.kv_cache_dtype.startswith("fp8")
-            will_use_fa = (
-                current_platform.is_cuda()
-                and not envs.is_set("VLLM_ATTENTION_BACKEND")
-            ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
-            supported = False
-            if (current_platform.is_rocm()
-                    or (current_platform.is_cuda()
-                        and current_platform.is_device_capability(100))
-                    or current_platform.is_tpu()):
-                supported = True
-            elif fp8_attention and will_use_fa:
-                from vllm.attention.utils.fa_utils import (
-                    flash_attn_supports_fp8)
-                supported = flash_attn_supports_fp8()
+            supported = current_platform.is_kv_cache_dtype_supported(
+                self.kv_cache_dtype)
             if not supported:
                 _raise_or_fallback(feature_name="--kv-cache-dtype",
                                    recommend_to_remove=False)
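
The engine-side check now delegates to the active platform object instead of hard-coding per-backend rules. A minimal sketch of how the new hook is consulted, assuming only the method added in this commit (the loop and the dtype strings are illustrative examples of `--kv-cache-dtype` values):

```python
# Sketch: ask the current platform whether an explicit kv-cache dtype is
# supported, mirroring the `!= "auto"` guard in _is_v1_supported_oracle.
# Only current_platform.is_kv_cache_dtype_supported is taken from this
# commit; everything else here is illustrative.
from vllm.platforms import current_platform

for kv_cache_dtype in ("auto", "fp8", "fp8_e4m3", "fp8_e5m2"):
    if kv_cache_dtype == "auto":
        # "auto" is accepted unconditionally; the hook is only consulted
        # for explicit dtypes.
        print(f"{kv_cache_dtype}: accepted without a platform check")
    else:
        ok = current_platform.is_kv_cache_dtype_supported(kv_cache_dtype)
        print(f"{kv_cache_dtype}: supported={ok}")
```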

vllm/platforms/cuda.py

Lines changed: 13 additions & 0 deletions
@@ -586,6 +586,19 @@ def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
             " not found. Assuming no NVLink available.")
         return False
 
+    @classmethod
+    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
+        fp8_attention = kv_cache_dtype.startswith("fp8")
+        will_use_fa = (not envs.is_set("VLLM_ATTENTION_BACKEND")
+                       ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
+        supported = False
+        if cls.is_device_capability(100):
+            supported = True
+        elif fp8_attention and will_use_fa:
+            from vllm.attention.utils.fa_utils import flash_attn_supports_fp8
+            supported = flash_attn_supports_fp8()
+        return supported
+
 
 # Autodetect either NVML-enabled or non-NVML platform
 # based on whether NVML is available.
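
The CUDA override keeps the logic previously inlined in `arg_utils.py`: devices reporting compute capability 10.0 accept any cache dtype, and on other GPUs an fp8 cache is accepted only when the FlashAttention backend will be used and reports fp8 support. A simplified restatement with the environment and device queries passed in as parameters, to make the precedence explicit (illustrative only, not the shipped code):

```python
# Simplified restatement of CudaPlatform.is_kv_cache_dtype_supported above.
# The helper name and parameters are illustrative; the real method queries
# envs.VLLM_ATTENTION_BACKEND and cls.is_device_capability directly.
def cuda_kv_cache_dtype_supported(kv_cache_dtype: str,
                                  is_capability_100: bool,
                                  backend_env: str | None,
                                  fa_supports_fp8: bool) -> bool:
    if is_capability_100:          # cls.is_device_capability(100)
        return True
    # FlashAttention is assumed when the backend env var is unset or is
    # explicitly set to FLASH_ATTN_VLLM_V1.
    will_use_fa = backend_env is None or backend_env == "FLASH_ATTN_VLLM_V1"
    if kv_cache_dtype.startswith("fp8") and will_use_fa:
        return fa_supports_fp8     # flash_attn_supports_fp8()
    return False
```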

vllm/platforms/interface.py

Lines changed: 7 additions & 0 deletions
@@ -543,6 +543,13 @@ def stateless_init_device_torch_dist_pg(
         """
         raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
 
+    @classmethod
+    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
+        """
+        Returns if the kv_cache_dtype is supported by the current platform.
+        """
+        return False
+
 
 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
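
The `return False` default is what enables the out-of-tree case: a platform plugin (such as the HPU plugin this commit targets) can advertise its own supported cache dtypes by overriding the hook, rather than relying on hard-coded checks in `arg_utils.py`. A hypothetical sketch of such an override; the class name and the accepted dtype list are invented for illustration and are not taken from the actual HPU plugin:

```python
# Hypothetical out-of-tree platform overriding the new hook; the class name
# and the accepted dtype list are illustrative assumptions.
from vllm.platforms.interface import Platform, PlatformEnum


class MyAcceleratorPlatform(Platform):
    _enum = PlatformEnum.OOT  # out-of-tree platforms use the OOT enum

    @classmethod
    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
        # Advertise exactly the cache dtypes this hardware handles; the
        # engine's --kv-cache-dtype oracle now consults this method.
        return kv_cache_dtype in ("fp8", "fp8_e4m3")
```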

vllm/platforms/rocm.py

Lines changed: 4 additions & 0 deletions
@@ -454,3 +454,7 @@ def stateless_init_device_torch_dist_pg(
     @classmethod
     def device_count(cls) -> int:
         return cuda_device_count_stateless()
+
+    @classmethod
+    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
+        return True

vllm/platforms/tpu.py

Lines changed: 4 additions & 0 deletions
@@ -190,6 +190,10 @@ def validate_request(
                 and params.sampling_type == SamplingType.RANDOM_SEED):
             raise ValueError("Torch XLA does not support per-request seed.")
 
+    @classmethod
+    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
+        return True
+
 
 try:
     from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
