@@ -23,9 +23,7 @@
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
-from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
-                                               NUM_QUERIES_PER_BLOCK,
-                                               PallasAttentionBackend,
+from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
                                                PallasMetadata)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
@@ -78,10 +76,8 @@ def __init__(
         self.block_size = cache_config.block_size
         self.max_model_len = model_config.max_model_len
         self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
-        self.max_num_tokens = _get_padded_number(
-            scheduler_config.max_num_batched_tokens, NUM_QUERIES_PER_BLOCK)
-        self.max_num_reqs = _get_padded_number(scheduler_config.max_num_seqs,
-                                               NUM_QUERIES_PER_BLOCK)
+        self.max_num_tokens = scheduler_config.max_num_batched_tokens
+        self.max_num_reqs = scheduler_config.max_num_seqs
 
         # Model-related.
         self.num_attn_layers = model_config.get_num_layers_by_block_type(
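Both removed call sites rounded the scheduler limits up to a multiple of `NUM_QUERIES_PER_BLOCK`; with this change the configured values are used as-is. A minimal sketch of the padding being dropped, assuming `_get_padded_number` is the usual ceil-to-multiple helper (its body is not shown in this diff, and `32` below is an illustrative stand-in for the real constant):

```python
# Assumed shape of the helper the old code called; the real definition
# lives elsewhere in tpu_model_runner.py.
def _get_padded_number(x: int, multiple: int) -> int:
    # Round x up to the nearest multiple (ceil-to-multiple).
    return ((x + multiple - 1) // multiple) * multiple

NUM_QUERIES_PER_BLOCK = 32  # illustrative value only

# Old behavior: both limits were padded up to block multiples.
max_num_tokens_old = _get_padded_number(8192, NUM_QUERIES_PER_BLOCK)  # 8192
max_num_reqs_old = _get_padded_number(100, NUM_QUERIES_PER_BLOCK)     # 128

# New behavior: the scheduler limits pass through unchanged.
max_num_tokens_new = 8192  # scheduler_config.max_num_batched_tokens
max_num_reqs_new = 100     # scheduler_config.max_num_seqs
```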
@@ -142,16 +138,8 @@ def __init__(
             device="cpu")
         self.slot_mapping_np = self.slot_mapping_cpu.numpy()
 
-        # self.input_batch.block_table has a shape of [max_num_reqs,
-        # max_num_blocks_per_req]. To reduce the number of recompilation,
-        # we want the block_table.shape[0] to be num_tokens.
-        # To make the block_table to be compatible with the paged attention
-        # kernel, we want the block_table[1] to be multiple of
-        # NUM_KV_PAGES_PER_BLOCK.
-        padded_max_num_blocks_per_req = _get_padded_number(
-            self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK)
         self.block_table_cpu = torch.zeros(
-            (self.max_num_tokens, padded_max_num_blocks_per_req),
+            (self.max_num_tokens, self.max_num_blocks_per_req),
             dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
             device="cpu")
 