@@ -38,7 +38,8 @@ def __init__(
         from lmcache.integration.vllm.utils import ENGINE_NAME
         from lmcache.integration.vllm.vllm_adapter import (
             RetrieveStatus, StoreStatus, init_lmcache_engine,
-            lmcache_retrieve_kv, lmcache_should_store, lmcache_store_kv)
+            lmcache_retrieve_kv, lmcache_should_retrieve, lmcache_should_store,
+            lmcache_store_kv)
         logger.info("Initializing LMCacheConfig under kv_transfer_config %s",
                     self.transfer_config)
 
@@ -54,6 +55,7 @@ def __init__(
         self.cache_config = config.cache_config
         self.lmcache_retrieve_kv = lmcache_retrieve_kv
         self.lmcache_store_kv = lmcache_store_kv
+        self.lmcache_should_retrieve = lmcache_should_retrieve
         self.lmcache_should_store = lmcache_should_store
         self.store_status = StoreStatus
         self.retrieve_status = RetrieveStatus
@@ -65,15 +67,11 @@ def recv_kv_caches_and_hidden_states(
     ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
                "ModelInputForGPUWithSamplingMetadata"]:
 
-        hidden_or_intermediate_states = None
-
-        # TODO (Jiayi): Need to support chunked prefill
-        retrieve_status = self.retrieve_status.PREFILL
-
-        model_input, bypass_model_exec = self.lmcache_retrieve_kv(
-            model_executable, model_input, self.cache_config, kv_caches,
-            retrieve_status)
-
+        retrieve_status = self.lmcache_should_retrieve(model_input)
+        model_input, bypass_model_exec, hidden_or_intermediate_states = \
+            self.lmcache_retrieve_kv(
+                model_executable, model_input, self.cache_config, kv_caches,
+                retrieve_status)
         return hidden_or_intermediate_states, bypass_model_exec, model_input
 
     def send_kv_caches_and_hidden_states(
@@ -84,15 +82,7 @@ def send_kv_caches_and_hidden_states(
         hidden_or_intermediate_states: Union[torch.Tensor,
                                              IntermediateTensors],
     ) -> None:
-        num_reqs = 0
-        seq_group_list = model_input.sampling_metadata.seq_groups
-        assert seq_group_list is not None
-        for seq_group in seq_group_list:
-            seq_ids = seq_group.seq_ids
-            for seq_id in seq_ids:
-                num_reqs += 1
-
-        # TODO (Jiayi): Only normal prefill is supported for now
+
         store_status = self.lmcache_should_store(model_input)
         self.lmcache_store_kv(
             self.model_config,
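
Taken together, the diff replaces the hard-coded prefill-only assumption with adapter-side decisions: lmcache_should_retrieve and lmcache_should_store now say whether to act, and lmcache_retrieve_kv additionally returns the hidden states so the connector can bypass model execution on a cache hit. Below is a minimal, self-contained sketch of that gate-then-act pattern; the stub types, signatures, and return values are illustrative assumptions, not the real vLLM/LMCache API.

# Hypothetical stand-ins for the adapter API used in the diff; names follow the
# diff, but the signatures and types here are assumptions for illustration only.
from dataclasses import dataclass
from enum import Enum, auto
from typing import List, Optional, Tuple


class RetrieveStatus(Enum):
    PREFILL = auto()   # retrieve cached KV for this prefill
    NONE = auto()      # nothing to retrieve (e.g. a decode step)


@dataclass
class ModelInput:
    # Stand-in for ModelInputForGPUWithSamplingMetadata.
    is_prefill: bool
    tokens: List[int]


def lmcache_should_retrieve(model_input: ModelInput) -> RetrieveStatus:
    # The adapter, not the connector, decides whether retrieval applies.
    return RetrieveStatus.PREFILL if model_input.is_prefill else RetrieveStatus.NONE


def lmcache_retrieve_kv(
        model_input: ModelInput,
        retrieve_status: RetrieveStatus) -> Tuple[ModelInput, bool, Optional[list]]:
    # Returns (model_input, bypass_model_exec, hidden_states), mirroring the
    # three-value unpacking introduced in recv_kv_caches_and_hidden_states.
    if retrieve_status is RetrieveStatus.NONE:
        return model_input, False, None
    hidden_states = [0.0] * len(model_input.tokens)  # pretend full cache hit
    return model_input, True, hidden_states


def recv_kv_caches_and_hidden_states(model_input: ModelInput):
    retrieve_status = lmcache_should_retrieve(model_input)
    model_input, bypass_model_exec, hidden = lmcache_retrieve_kv(
        model_input, retrieve_status)
    return hidden, bypass_model_exec, model_input


print(recv_kv_caches_and_hidden_states(ModelInput(is_prefill=True, tokens=[1, 2, 3])))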