2
2
3
3
# ruff: noqa: SIM117
4
4
from pathlib import Path
5
- from typing import List , Optional , Tuple
5
+ from typing import Optional
6
6
7
7
import openvino as ov
8
8
import torch
12
12
from torch import nn
13
13
14
14
import vllm .envs as envs
15
- from vllm .attention .backends .openvino import OpenVINOAttentionMetadata
16
15
from vllm .config import ModelConfig , VllmConfig , set_current_vllm_config
16
+ from vllm .forward_context import get_forward_context
17
17
from vllm .logger import init_logger
18
18
from vllm .model_executor .layers .logits_processor import (LogitsProcessor ,
19
19
_prune_hidden_states )
24
24
logger = init_logger (__name__ )
25
25
26
26
27
- def _flattenize_inputs (inputs ):
27
+ def _flatten_inputs (inputs ):
28
28
"""
29
29
Helper function for flattening nested inputs
30
30
"""
@@ -33,10 +33,9 @@ def _flattenize_inputs(inputs):
33
33
if input_data is None :
34
34
continue
35
35
if isinstance (input_data , (list , tuple )):
36
- flatten_inputs .extend (_flattenize_inputs (input_data ))
36
+ flatten_inputs .extend (_flatten_inputs (input_data ))
37
37
elif isinstance (input_data , dict ):
38
- flatten_inputs .extend (_flattenize_inputs (list (
39
- input_data .values ())))
38
+ flatten_inputs .extend (_flatten_inputs (list (input_data .values ())))
40
39
else :
41
40
flatten_inputs .append (input_data )
42
41
return flatten_inputs
@@ -147,15 +146,15 @@ def forward(
147
146
self ,
148
147
input_ids : torch .Tensor ,
149
148
positions : torch .Tensor ,
150
- kv_caches : List [Tuple [ov .Tensor , ov .Tensor ]],
151
- attn_metadata : OpenVINOAttentionMetadata ,
149
+ kv_caches : list [tuple [ov .Tensor , ov .Tensor ]],
152
150
) -> torch .Tensor :
153
- flatten_kv_cache = _flattenize_inputs (kv_caches )
151
+ flat_kv_caches = _flatten_inputs (kv_caches )
152
+ attn_metadata = get_forward_context ().attn_metadata
154
153
155
154
inputs = [
156
155
input_ids ,
157
156
positions ,
158
- * flatten_kv_cache ,
157
+ * flat_kv_caches ,
159
158
attn_metadata .past_lens ,
160
159
attn_metadata .subsequence_begins ,
161
160
attn_metadata .block_indices ,
0 commit comments