
Commit e85d84b

hmellor authored and lulmer committed
Fix missing kv_caches and attn_metadata in OpenVINOCausalLM (vllm-project#14271)
Signed-off-by: Harry Mellor <[email protected]>
Signed-off-by: Louis Ulmer <[email protected]>
1 parent 8c4ebe3 commit e85d84b

File tree

vllm/model_executor/model_loader/openvino.py
vllm/worker/openvino_model_runner.py

2 files changed: +11 −10

vllm/model_executor/model_loader/openvino.py

Lines changed: 9 additions & 10 deletions

```diff
@@ -2,7 +2,7 @@
 
 # ruff: noqa: SIM117
 from pathlib import Path
-from typing import List, Optional, Tuple
+from typing import Optional
 
 import openvino as ov
 import torch
@@ -12,8 +12,8 @@
 from torch import nn
 
 import vllm.envs as envs
-from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
 from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
+from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.logits_processor import (LogitsProcessor,
                                                          _prune_hidden_states)
@@ -24,7 +24,7 @@
 logger = init_logger(__name__)
 
 
-def _flattenize_inputs(inputs):
+def _flatten_inputs(inputs):
     """
     Helper function for making nested inputs flattens
     """
@@ -33,10 +33,9 @@ def _flattenize_inputs(inputs):
         if input_data is None:
            continue
         if isinstance(input_data, (list, tuple)):
-            flatten_inputs.extend(_flattenize_inputs(input_data))
+            flatten_inputs.extend(_flatten_inputs(input_data))
         elif isinstance(input_data, dict):
-            flatten_inputs.extend(_flattenize_inputs(list(
-                input_data.values())))
+            flatten_inputs.extend(_flatten_inputs(list(input_data.values())))
         else:
             flatten_inputs.append(input_data)
     return flatten_inputs
@@ -147,15 +146,15 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[Tuple[ov.Tensor, ov.Tensor]],
-        attn_metadata: OpenVINOAttentionMetadata,
+        kv_caches: list[tuple[ov.Tensor, ov.Tensor]],
     ) -> torch.Tensor:
-        flatten_kv_cache = _flattenize_inputs(kv_caches)
+        flat_kv_caches = _flatten_inputs(kv_caches)
+        attn_metadata = get_forward_context().attn_metadata
 
         inputs = [
             input_ids,
             positions,
-            *flatten_kv_cache,
+            *flat_kv_caches,
             attn_metadata.past_lens,
             attn_metadata.subsequence_begins,
             attn_metadata.block_indices,
```
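
The core of the fix is visible in the last hunk: `forward()` no longer receives `attn_metadata` as a parameter and instead reads it from the forward context via `get_forward_context().attn_metadata`, while `kv_caches` stays in the signature and is flattened before being handed to the OpenVINO request. The snippet below is a minimal, self-contained sketch of that ambient-context pattern; `AttnMetadata`, `ForwardContext`, `set_forward_context`, and `ToyCausalLM` are illustrative stand-ins, not vLLM's actual implementation.

```python
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Iterator, Optional


@dataclass
class AttnMetadata:
    # Placeholder fields; names mirror what the hunk above reads from attn_metadata.
    past_lens: list
    subsequence_begins: list
    block_indices: list


@dataclass
class ForwardContext:
    attn_metadata: AttnMetadata


_forward_context: Optional[ForwardContext] = None


@contextmanager
def set_forward_context(attn_metadata: AttnMetadata) -> Iterator[None]:
    """Install the ambient context for the duration of one forward pass."""
    global _forward_context
    _forward_context = ForwardContext(attn_metadata)
    try:
        yield
    finally:
        _forward_context = None


def get_forward_context() -> ForwardContext:
    if _forward_context is None:
        raise RuntimeError("forward() called outside a forward context")
    return _forward_context


def _flatten_inputs(inputs):
    """Flatten nested lists/tuples/dicts into a flat list, as in the diff above."""
    flat = []
    for item in inputs:
        if item is None:
            continue
        if isinstance(item, (list, tuple)):
            flat.extend(_flatten_inputs(item))
        elif isinstance(item, dict):
            flat.extend(_flatten_inputs(list(item.values())))
        else:
            flat.append(item)
    return flat


class ToyCausalLM:
    def forward(self, input_ids, positions, kv_caches):
        # attn_metadata is no longer a parameter; it comes from the context.
        attn_metadata = get_forward_context().attn_metadata
        flat_kv_caches = _flatten_inputs(kv_caches)
        return [input_ids, positions, *flat_kv_caches,
                attn_metadata.past_lens,
                attn_metadata.subsequence_begins,
                attn_metadata.block_indices]


if __name__ == "__main__":
    model = ToyCausalLM()
    meta = AttnMetadata(past_lens=[0], subsequence_begins=[0, 3], block_indices=[0])
    kv_caches = [("k0", "v0"), ("k1", "v1")]  # one (key, value) pair per layer
    with set_forward_context(meta):
        print(model.forward([1, 2, 3], [0, 1, 2], kv_caches))
```

Presumably the point of routing `attn_metadata` through a context object is that `forward()` no longer needs a backend-specific metadata type in its signature; `kv_caches`, by contrast, must still be passed explicitly, which is what the second file's change restores.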

vllm/worker/openvino_model_runner.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -346,6 +346,8 @@ def execute_model(
             input_tokens,
             "positions":
             input_positions,
+            "kv_caches":
+            kv_caches,
             **MultiModalKwargs.as_kwargs(multi_modal_kwargs or {},
                                          device=self.device),
         }
```
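
On the runner side, the two added lines make sure the kwargs dict handed to the model actually carries `kv_caches`: since `forward()` still declares it as a parameter (only `attn_metadata` moved into the forward context), omitting it would raise a missing-argument `TypeError`. Below is a small, self-contained toy illustrating that; `toy_forward` and the placeholder values are hypothetical stand-ins for the real model call.

```python
def toy_forward(input_ids, positions, kv_caches):
    """Stand-in for the model's forward(); kv_caches is still a parameter."""
    return {"tokens": len(input_ids), "layers": len(kv_caches)}


input_tokens = [1, 2, 3]
input_positions = [0, 1, 2]
kv_caches = [("k0", "v0"), ("k1", "v1")]  # one (key, value) tensor pair per layer

execute_model_kwargs = {
    "input_ids": input_tokens,
    "positions": input_positions,
    "kv_caches": kv_caches,  # the entry added by this commit
}

print(toy_forward(**execute_model_kwargs))  # {'tokens': 3, 'layers': 2}
# Without the "kv_caches" entry, the call would fail with:
# TypeError: toy_forward() missing 1 required positional argument: 'kv_caches'
```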
