
Commit 95c19ec

YaoJiayi authored and DamonFool committed
[Feat] Support chunked prefill for LMCache connector (vllm-project#14505)
Signed-off-by: YaoJiayi <[email protected]>
1 parent fec5cee · commit 95c19ec

1 file changed: +9 -19 lines changed


vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py (9 additions, 19 deletions)

@@ -38,7 +38,8 @@ def __init__(
         from lmcache.integration.vllm.utils import ENGINE_NAME
         from lmcache.integration.vllm.vllm_adapter import (
             RetrieveStatus, StoreStatus, init_lmcache_engine,
-            lmcache_retrieve_kv, lmcache_should_store, lmcache_store_kv)
+            lmcache_retrieve_kv, lmcache_should_retrieve, lmcache_should_store,
+            lmcache_store_kv)
         logger.info("Initializing LMCacheConfig under kv_transfer_config %s",
                     self.transfer_config)

@@ -54,6 +55,7 @@ def __init__(
         self.cache_config = config.cache_config
         self.lmcache_retrieve_kv = lmcache_retrieve_kv
         self.lmcache_store_kv = lmcache_store_kv
+        self.lmcache_should_retrieve = lmcache_should_retrieve
         self.lmcache_should_store = lmcache_should_store
         self.store_status = StoreStatus
         self.retrieve_status = RetrieveStatus
@@ -65,15 +67,11 @@ def recv_kv_caches_and_hidden_states(
     ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
                "ModelInputForGPUWithSamplingMetadata"]:

-        hidden_or_intermediate_states = None
-
-        # TODO (Jiayi): Need to support chunked prefill
-        retrieve_status = self.retrieve_status.PREFILL
-
-        model_input, bypass_model_exec = self.lmcache_retrieve_kv(
-            model_executable, model_input, self.cache_config, kv_caches,
-            retrieve_status)
-
+        retrieve_status = self.lmcache_should_retrieve(model_input)
+        model_input, bypass_model_exec, hidden_or_intermediate_states =\
+            self.lmcache_retrieve_kv(
+                model_executable, model_input, self.cache_config, kv_caches,
+                retrieve_status)
         return hidden_or_intermediate_states, bypass_model_exec, model_input

     def send_kv_caches_and_hidden_states(
@@ -84,15 +82,7 @@ def send_kv_caches_and_hidden_states(
         hidden_or_intermediate_states: Union[torch.Tensor,
                                              IntermediateTensors],
     ) -> None:
-        num_reqs = 0
-        seq_group_list = model_input.sampling_metadata.seq_groups
-        assert seq_group_list is not None
-        for seq_group in seq_group_list:
-            seq_ids = seq_group.seq_ids
-            for seq_id in seq_ids:
-                num_reqs += 1
-
-        # TODO (Jiayi): Only normal prefill is supported for now
+
         store_status = self.lmcache_should_store(model_input)
         self.lmcache_store_kv(
             self.model_config,
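
Why the change: the old code hard-coded retrieve_status = self.retrieve_status.PREFILL, which assumes a request's entire prompt is prefilled in one forward pass. With chunked prefill a prompt is split across several scheduler steps, so the connector now asks LMCache to classify each step via lmcache_should_retrieve and lmcache_should_store. The toy sketch below only illustrates that per-step decision; the names FakeModelInput and should_retrieve and the enum members other than PREFILL are hypothetical, not the real logic in lmcache.integration.vllm.vllm_adapter.

# Illustrative sketch only (hypothetical names); the real decision logic lives
# in lmcache.integration.vllm.vllm_adapter.
from dataclasses import dataclass
from enum import Enum, auto


class RetrieveStatus(Enum):
    PREFILL = auto()        # whole prompt prefilled in one forward pass
    CHUNK_PREFILL = auto()  # only a slice of the prompt is in this pass
    NONE = auto()           # decode step, nothing to retrieve


@dataclass
class FakeModelInput:
    """Hypothetical stand-in for ModelInputForGPUWithSamplingMetadata."""
    num_prompt_tokens: int        # prompt tokens scheduled in this step
    prompt_len: int               # total prompt length of the request
    tokens_already_computed: int  # prompt tokens handled in earlier steps


def should_retrieve(model_input: FakeModelInput) -> RetrieveStatus:
    """Classify the current step instead of assuming a full prefill."""
    if model_input.num_prompt_tokens == 0:
        return RetrieveStatus.NONE
    end = model_input.tokens_already_computed + model_input.num_prompt_tokens
    if model_input.tokens_already_computed == 0 and end == model_input.prompt_len:
        return RetrieveStatus.PREFILL
    return RetrieveStatus.CHUNK_PREFILL


# With chunked prefill, a 6-token prompt may arrive as two 3-token steps:
print(should_retrieve(FakeModelInput(3, 6, 0)))  # RetrieveStatus.CHUNK_PREFILL
print(should_retrieve(FakeModelInput(3, 6, 3)))  # RetrieveStatus.CHUNK_PREFILL
print(should_retrieve(FakeModelInput(6, 6, 0)))  # RetrieveStatus.PREFILL

This also explains why the removed per-sequence counting loop in send_kv_caches_and_hidden_states is gone: the connector no longer needs to derive the request count itself, since lmcache_should_store now decides the store status from model_input.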

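To exercise this path end to end, the connector is typically enabled together with chunked prefill roughly as in the LMCache example scripts shipped with vLLM. Treat the environment variables, the KVTransferConfig.from_cli helper, and the model name below as version-dependent assumptions rather than a guaranteed API.

# Hedged sketch based on vLLM's LMCache example scripts around this release;
# exact flags and env vars may differ in your version.
import os

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig

# LMCache settings (assumed names from the example scripts).
os.environ["LMCACHE_CHUNK_SIZE"] = "256"
os.environ["LMCACHE_LOCAL_CPU"] = "True"
os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0"

ktc = KVTransferConfig.from_cli(
    '{"kv_connector": "LMCacheConnector", "kv_role": "kv_both"}')

llm = LLM(
    model="mistralai/Mistral-7B-Instruct-v0.2",  # assumed example model
    kv_transfer_config=ktc,
    enable_chunked_prefill=True,  # the case this commit adds support for
    gpu_memory_utilization=0.8,
)

outputs = llm.generate(
    ["Explain KV-cache sharing in one sentence."],
    SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)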