[TPU] kv cache update kernel supports dynamic grid #20235

Merged · 2 commits · Jul 2, 2025
8 changes: 6 additions & 2 deletions tests/v1/tpu/test_kv_cache_update_kernel.py
@@ -32,6 +32,7 @@ def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int,
new_kv_xla = new_kv_cpu.to(torch_xla.device())
slice_lens = np.array([7, page_size, page_size, 1, 1, 1, 9],
dtype=np.int32)
num_kv_update_slices = len(slice_lens)
kv_cache_start_indices = np.array([
page_size * 2 - 7, page_size * 2, page_size * 3, page_size * 4 + 6,
page_size * 5 + 7, page_size * 6 + 8, page_size * 15 + 3
@@ -52,12 +53,15 @@ def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int,
device="cpu",
dtype=torch.int32)
slot_mapping_xla = slot_mapping_cpu.to(torch_xla.device())
num_kv_update_slices_xla = torch.tensor([num_kv_update_slices],
device=torch_xla.device(),
dtype=torch.int32)
torch_xla.sync()

torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True)
new_kv_cache_xla = torch.ops.xla.kv_cache_update_op(
- new_kv_xla, slot_mapping_xla, kv_cache_xla, page_size,
- num_slices_per_block)
+ new_kv_xla, slot_mapping_xla, kv_cache_xla, num_kv_update_slices_xla,
+ page_size, num_slices_per_block)
kv_cache_xla.copy_(new_kv_cache_xla)
torch_xla.sync()

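The kernel consumes the update as a [3, padded_num_slices] int32 slot mapping whose rows are (kv_cache_start, new_kv_start, slice_len), plus the true slice count. The construction of slot_mapping_cpu is elided above; the sketch below shows one plausible way to assemble it from the test's slice_lens and kv_cache_start_indices, assuming the new_kv start offsets are the running sum of the preceding slice lengths and that unused columns are zero padding (the helper name and the padded width of 16 are illustrative, not part of the test):

    import numpy as np

    def build_slot_mapping(kv_cache_start_indices: np.ndarray,
                           slice_lens: np.ndarray,
                           padded_num_slices: int) -> tuple[np.ndarray, int]:
        # Row 0: start offset of each slice in the flattened KV cache.
        # Row 1: start offset of each slice in new_kv (cumulative sum of lengths).
        # Row 2: length of each slice. Columns past the real count stay zero.
        num_kv_update_slices = len(slice_lens)
        new_kv_start = np.concatenate(([0], np.cumsum(slice_lens)[:-1]))
        mapping = np.zeros((3, padded_num_slices), dtype=np.int32)
        mapping[0, :num_kv_update_slices] = kv_cache_start_indices
        mapping[1, :num_kv_update_slices] = new_kv_start
        mapping[2, :num_kv_update_slices] = slice_lens
        return mapping, num_kv_update_slices

    # With the test's data (page_size = 32) this yields 7 real slices.
    page_size = 32
    slice_lens = np.array([7, page_size, page_size, 1, 1, 1, 9], dtype=np.int32)
    starts = np.array([page_size * 2 - 7, page_size * 2, page_size * 3,
                       page_size * 4 + 6, page_size * 5 + 7, page_size * 6 + 8,
                       page_size * 15 + 3], dtype=np.int32)
    mapping, num_kv_update_slices = build_slot_mapping(starts, slice_lens, 16)
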
9 changes: 6 additions & 3 deletions vllm/attention/ops/pallas_kv_cache_update.py
@@ -7,11 +7,13 @@
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

from vllm.utils import cdiv


def _kv_cache_update_kernel(
# Prefetch
- slices_ref,  # [3, num_slices], list of (kv_cache_start, new_kv_start,
-              # slice_len)
+ slices_ref,  # [3, padded_num_slices], list of (kv_cache_start,
+              # new_kv_start, slice_len)
# Input
new_kv_hbm_ref, # [num_tokens, num_combined_kv_heads, head_dim]
kv_cache_hbm_ref, # [total_num_pages * page_size, num_combined_kv_heads,
@@ -70,6 +72,7 @@ def kv_cache_update(
Array, # [3, slices], list of (kv_cache_start, new_kv_start, slice_len)
kv_cache: jax.
Array, # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
num_kv_update_slices: jax.Array, # [1]
*,
page_size: int = 32,
num_slices_per_block: int = 8,
@@ -107,7 +110,7 @@
num_scalar_prefetch=len(scalar_prefetches),
in_specs=in_specs,
out_specs=out_specs,
- grid=(slices.shape[1] // num_slices_per_block, ),
+ grid=(cdiv(num_kv_update_slices[0], num_slices_per_block), ),
scratch_shapes=scratch_shapes,
),
out_shape=out_shape,
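
The substantive change is the grid: instead of always launching slices.shape[1] // num_slices_per_block steps over the padded slot-mapping width, the kernel now launches cdiv(num_kv_update_slices[0], num_slices_per_block) steps, so padding columns beyond the real slice count no longer cost grid iterations. A rough illustration of the arithmetic (plain Python ints stand in for the traced scalar; the padded width of 64 is made up for the example):

    def cdiv(a: int, b: int) -> int:
        # Ceiling division, matching vllm.utils.cdiv for positive ints.
        return -(-a // b)

    padded_num_slices = 64     # static width of the slot-mapping array (illustrative)
    num_kv_update_slices = 7   # slices actually scheduled this step
    num_slices_per_block = 8   # default per-block slice count from this file

    old_grid = padded_num_slices // num_slices_per_block         # 8 steps, mostly padding
    new_grid = cdiv(num_kv_update_slices, num_slices_per_block)  # 1 step
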
34 changes: 22 additions & 12 deletions vllm/v1/attention/backends/pallas.py
@@ -111,6 +111,7 @@ class PallasMetadata:
context_lens: torch.Tensor
query_start_loc: torch.Tensor
num_seqs: torch.Tensor
num_kv_update_slices: torch.Tensor
num_slices_per_kv_cache_update_block: int


@@ -219,7 +220,8 @@ def forward(
slot_mapping = attn_metadata.slot_mapping
write_to_kv_cache(
key, value, kv_cache, slot_mapping,
- attn_metadata.num_slices_per_kv_cache_update_block)
+ attn_metadata.num_slices_per_kv_cache_update_block,
+ attn_metadata.num_kv_update_slices)

output = torch.ops.xla.ragged_paged_attention(
query,
@@ -252,6 +254,7 @@ def write_to_kv_cache(
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
num_slices_per_kv_cache_update_block: int,
num_kv_update_slices: torch.Tensor,
) -> None:
""" Write the key and values to the KV cache.

@@ -271,40 +274,47 @@

kv_cache = kv_cache.flatten(0, 1)
new_kv_cache = torch.ops.xla.kv_cache_update_op(
- kv, slot_mapping, kv_cache, page_size,
+ kv, slot_mapping, kv_cache, num_kv_update_slices, page_size,
num_slices_per_kv_cache_update_block)
# NOTE: the in-place copy will be optimized away by XLA compiler.
kv_cache.copy_(new_kv_cache)


@requires_jax
def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
- kv_cache: torch.Tensor, page_size: int,
+ kv_cache: torch.Tensor,
+ num_kv_update_slices: torch.Tensor, page_size: int,
num_slices_per_block: int):
from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
- new_kv_cache = xb.call_jax(kv_cache_update, (kv, slot_mapping, kv_cache), {
-     "page_size": page_size,
-     "num_slices_per_block": num_slices_per_block
- })
+ new_kv_cache = xb.call_jax(
+     kv_cache_update, (kv, slot_mapping, kv_cache, num_kv_update_slices), {
+         "page_size": page_size,
+         "num_slices_per_block": num_slices_per_block
+     })
return new_kv_cache


XLA_LIB.define(
"kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache, "
"int page_size, int num_slices_per_block) -> Tensor", )
"kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache," \
"Tensor num_kv_update_slices, int page_size, int num_slices_per_block)" \
"-> Tensor", )
Review comment from a Contributor on lines 297 to +300:

medium: It would be helpful to include a space after the concatenation operator (+) for improved readability.

    "Tensor num_kv_update_slices, int page_size, int num_slices_per_block)" \
    "-> Tensor", )



@impl(XLA_LIB, "kv_cache_update_op", "XLA")
def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
- kv_cache: torch.Tensor, page_size: int,
+ kv_cache: torch.Tensor,
+ num_kv_update_slices: torch.Tensor, page_size: int,
num_slices_per_block: int) -> torch.Tensor:
new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache,
- page_size, num_slices_per_block)
+ num_kv_update_slices, page_size,
+ num_slices_per_block)
return new_kv_cache


@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
- kv_cache: torch.Tensor, page_size: int,
+ kv_cache: torch.Tensor,
+ num_kv_update_slices: torch.Tensor,
+ page_size: int,
num_slices_per_block: int) -> torch.Tensor:
return kv_cache
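
For orientation, the semantics the custom op implements can be written as a short eager-mode reference: copy each of the first num_kv_update_slices slices from the new KV tensor into the flattened cache at its start offset. This is a hedged sketch of the op's effect only (not the Pallas kernel, and it ignores blocking and paging details); the slot-mapping row order follows the (kv_cache_start, new_kv_start, slice_len) layout documented in pallas_kv_cache_update.py:

    import torch

    def kv_cache_update_reference(kv: torch.Tensor,           # [num_tokens, heads, head_dim]
                                  slot_mapping: torch.Tensor, # [3, padded_num_slices], int32
                                  kv_cache: torch.Tensor,     # [pages * page_size, heads, head_dim]
                                  num_kv_update_slices: int) -> torch.Tensor:
        # Eager reference for what kv_cache_update_op computes (sketch).
        out = kv_cache.clone()
        for i in range(num_kv_update_slices):
            cache_start, new_start, length = slot_mapping[:, i].tolist()
            out[cache_start:cache_start + length] = kv[new_start:new_start + length]
        return out

The CompositeExplicitAutograd registration above returns kv_cache unchanged, which appears intended to give non-XLA tracing a correctly shaped result without performing the update.
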
8 changes: 8 additions & 0 deletions vllm/v1/worker/tpu_model_runner.py
@@ -711,8 +711,10 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput",
self.device)
block_tables = block_tables.to(self.device)

# Calculate the slot mapping
slot_mapping_metadata = self._get_slot_mapping_metadata(
num_reqs, num_scheduled_tokens_per_req)
num_kv_update_slices = slot_mapping_metadata.shape[0]
padded_num_slices = _get_padded_num_kv_cache_update_slices(
padded_total_num_scheduled_tokens, self.max_num_reqs,
self.block_size)
@@ -743,6 +745,9 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput",
num_seqs=torch.tensor([num_reqs],
dtype=torch.int32,
device=self.device),
num_kv_update_slices=torch.tensor([num_kv_update_slices],
dtype=torch.int32,
device=self.device),
num_slices_per_kv_cache_update_block=
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
)
@@ -1172,6 +1177,8 @@ def _dummy_run(self, num_tokens: int, num_reqs: int,
dtype=torch.int32).to(self.device)
padded_num_slices = _get_padded_num_kv_cache_update_slices(
num_tokens, self.max_num_reqs, self.block_size)
num_kv_update_slices = torch.tensor([padded_num_slices],
dtype=torch.int32).to(self.device)
slot_mapping = torch.zeros((3, padded_num_slices),
dtype=torch.int32).to(self.device)
block_tables = torch.zeros((num_reqs, num_blocks),
@@ -1191,6 +1198,7 @@ def _dummy_run(self, num_tokens: int, num_reqs: int,
context_lens=context_lens,
query_start_loc=query_start_loc,
num_seqs=num_seqs,
num_kv_update_slices=num_kv_update_slices,
num_slices_per_kv_cache_update_block=
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
)
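
On the runner side the slot-mapping tensor keeps its padded, compilation-stable shape of [3, padded_num_slices], while the new num_kv_update_slices tensor carries the true slice count into PallasMetadata so the kernel grid can shrink at run time; in _dummy_run the count is filled with padded_num_slices, presumably so warm-up compilation covers the largest grid. A minimal sketch of that pairing, assuming a [num_slices, 3] host-side metadata layout and treating the padded width as given (the actual _get_padded_num_kv_cache_update_slices formula is not shown in this diff):

    import torch

    def prepare_kv_update_inputs(slot_mapping_metadata: torch.Tensor,  # [num_slices, 3], int32
                                 padded_num_slices: int,
                                 device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
        # Keep the device tensor at a fixed width for compilation, and pass the
        # real count separately so the kernel can size its grid dynamically.
        num_kv_update_slices = slot_mapping_metadata.shape[0]
        padded = torch.zeros((padded_num_slices, 3), dtype=torch.int32)
        padded[:num_kv_update_slices] = slot_mapping_metadata
        slot_mapping = padded.T.contiguous().to(device)  # -> [3, padded_num_slices]
        num_kv_update_slices_t = torch.tensor([num_kv_update_slices],
                                              dtype=torch.int32, device=device)
        return slot_mapping, num_kv_update_slices_t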