
Commit fec5cee

vanbasten23 authored and DamonFool committed
[V1][TPU] Remove unnecessary padding for running on TPU. (vllm-project#14467)
1 parent 25d2769 commit fec5cee

2 files changed: +6 −18 lines


vllm/v1/attention/backends/pallas.py

Lines changed: 2 additions & 2 deletions
@@ -12,8 +12,8 @@
 from vllm.attention.backends.utils import CommonAttentionState
 
 # These are the 2 tunable parameters of the paged attention Pallas kernel.
-NUM_QUERIES_PER_BLOCK = 16
-NUM_KV_PAGES_PER_BLOCK = 256
+NUM_QUERIES_PER_BLOCK = 32
+NUM_KV_PAGES_PER_BLOCK = 128
 
 
 class PallasAttentionBackend(AttentionBackend):
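
Note on the hunk above: the comment in pallas.py calls NUM_QUERIES_PER_BLOCK and NUM_KV_PAGES_PER_BLOCK the two tunable parameters of the paged attention Pallas kernel. As a rough, hypothetical sketch (the kernel's real signature is not shown in this diff), block-size tunables like these determine how many query tokens and KV pages each program instance covers, and hence the launch grid:

import math

NUM_QUERIES_PER_BLOCK = 32    # new value in this commit (was 16)
NUM_KV_PAGES_PER_BLOCK = 128  # new value in this commit (was 256)

def sketch_grid(num_query_tokens: int, num_kv_pages: int) -> tuple[int, int]:
    """Hypothetical helper: one program per (query block, KV-page block) pair."""
    return (math.ceil(num_query_tokens / NUM_QUERIES_PER_BLOCK),
            math.ceil(num_kv_pages / NUM_KV_PAGES_PER_BLOCK))

# e.g. 1024 query tokens over 512 KV pages -> a (32, 4) grid under this sketch.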

vllm/v1/worker/tpu_model_runner.py

Lines changed: 4 additions & 16 deletions
@@ -23,9 +23,7 @@
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
-from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
-                                               NUM_QUERIES_PER_BLOCK,
-                                               PallasAttentionBackend,
+from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
                                                PallasMetadata)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
@@ -78,10 +76,8 @@ def __init__(
         self.block_size = cache_config.block_size
         self.max_model_len = model_config.max_model_len
         self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
-        self.max_num_tokens = _get_padded_number(
-            scheduler_config.max_num_batched_tokens, NUM_QUERIES_PER_BLOCK)
-        self.max_num_reqs = _get_padded_number(scheduler_config.max_num_seqs,
-                                               NUM_QUERIES_PER_BLOCK)
+        self.max_num_tokens = scheduler_config.max_num_batched_tokens
+        self.max_num_reqs = scheduler_config.max_num_seqs
 
         # Model-related.
         self.num_attn_layers = model_config.get_num_layers_by_block_type(
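
The assignments removed in the hunk above padded the scheduler limits up to a multiple of NUM_QUERIES_PER_BLOCK (16 at the time); after this commit the raw scheduler values are used directly. A minimal sketch of the round-up helper those lines call, assuming _get_padded_number rounds x up to the nearest multiple (its definition is not part of this diff):

def _get_padded_number(x: int, multiple: int) -> int:
    # Assumed semantics: round x up to the nearest multiple of `multiple`.
    return ((x + multiple - 1) // multiple) * multiple

# Illustrative effect (hypothetical scheduler setting max_num_seqs=100):
# before: max_num_reqs = _get_padded_number(100, 16) == 112
# after:  max_num_reqs = scheduler_config.max_num_seqs == 100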
@@ -142,16 +138,8 @@ def __init__(
             device="cpu")
         self.slot_mapping_np = self.slot_mapping_cpu.numpy()
 
-        # self.input_batch.block_table has a shape of [max_num_reqs,
-        # max_num_blocks_per_req]. To reduce the number of recompilation,
-        # we want the block_table.shape[0] to be num_tokens.
-        # To make the block_table to be compatible with the paged attention
-        # kernel, we want the block_table[1] to be multiple of
-        # NUM_KV_PAGES_PER_BLOCK.
-        padded_max_num_blocks_per_req = _get_padded_number(
-            self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK)
         self.block_table_cpu = torch.zeros(
-            (self.max_num_tokens, padded_max_num_blocks_per_req),
+            (self.max_num_tokens, self.max_num_blocks_per_req),
             dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
             device="cpu")
 
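
The comment removed in the hunk above explains why the block table's second dimension used to be padded up to a multiple of NUM_KV_PAGES_PER_BLOCK. A rough illustration of what dropping that padding saves, under assumed config values (max_model_len=2048, block_size=16) and the old constant value of 256; none of these concrete numbers come from the diff:

# Illustration only; max_model_len and block_size below are assumptions.
max_model_len = 2048
block_size = 16
max_num_blocks_per_req = -(-max_model_len // block_size)      # cdiv -> 128

# Before: column count padded up to a multiple of 256 -> 256 columns.
old_cols = ((max_num_blocks_per_req + 256 - 1) // 256) * 256  # 256
# After: the unpadded value is used directly -> 128 columns.
new_cols = max_num_blocks_per_req                              # 128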
