
Commit cc01461

Chenyaaang authored and m-misiura committed
[TPU] Add TPU specific var VLLM_TPU_MOST_MODEL_LEN (vllm-project#19919)
Signed-off-by: Chenyaaang <[email protected]>
1 parent 2360d7b · commit cc01461

File tree

5 files changed: +185 -77 lines changed


tests/v1/tpu/worker/test_tpu_model_runner.py

Lines changed: 14 additions & 0 deletions
@@ -587,3 +587,17 @@ def test_init_kv_cache_with_kv_sharing_valid():
     assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
     assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
     assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
+
+
+def test_most_model_len(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_TPU_MOST_MODEL_LEN", "2048")
+    vllm_config = get_vllm_config()
+    vllm_config.model_config.max_model_len = 32000
+    vllm_config.scheduler_config.max_num_seqs = 1200
+    model_runner = get_model_runner(vllm_config)
+
+    # verify model runner will adjust num_reqs to avoid SMEM OOM.
+    assert model_runner.num_reqs_most_model_len == 1200
+    # num_page_per_req = 32k // 128
+    # num_reqs = 1024 ** 2 // 2 // num_page_per_req // 4 = 524
+    assert model_runner.num_reqs_max_model_len == 524
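For reference, a quick back-of-the-envelope check of the 524 figure asserted above (my own sketch, not part of the commit; it assumes a 128-token page size, as implied by the "32k // 128" comment):

# Standalone sketch reproducing the arithmetic from the test comment above.
def cdiv(a: int, b: int) -> int:
    return -(-a // b)

page_size = 128                                          # assumed page size
max_model_len = 32000
num_page_per_req = cdiv(max_model_len, page_size)        # 250
num_reqs = 1024 * 1024 // 2 // num_page_per_req // 4     # 524
assert num_reqs == 524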

vllm/envs.py

Lines changed: 3 additions & 0 deletions
@@ -119,6 +119,7 @@
     VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
     VLLM_V0_USE_OUTLINES_CACHE: bool = False
     VLLM_TPU_BUCKET_PADDING_GAP: int = 0
+    VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None
     VLLM_USE_DEEP_GEMM: bool = False
     VLLM_XGRAMMAR_CACHE_MB: int = 0
     VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
@@ -833,6 +834,8 @@ def get_vllm_port() -> Optional[int]:
     "VLLM_TPU_BUCKET_PADDING_GAP":
     lambda: int(os.environ["VLLM_TPU_BUCKET_PADDING_GAP"])
     if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ else 0,
+    "VLLM_TPU_MOST_MODEL_LEN":
+    lambda: maybe_convert_int(os.environ.get("VLLM_TPU_MOST_MODEL_LEN", None)),

     # Allow use of DeepGemm kernels for fused moe ops.
     "VLLM_USE_DEEP_GEMM":

vllm/platforms/tpu.py

Lines changed: 0 additions & 10 deletions
@@ -122,16 +122,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 PallasAttentionBackend)
             cache_config.block_size = PallasAttentionBackend.get_page_size(
                 vllm_config)  # type: ignore[assignment]
-            min_page_size = PallasAttentionBackend.get_min_page_size(
-                vllm_config)
-            if min_page_size > cache_config.block_size:
-                logger.warning(
-                    "Increase the page size from %s to %s to make sure there's"
-                    "no SMEM OOM",
-                    cache_config.block_size,
-                    min_page_size,
-                )
-                cache_config.block_size = min_page_size  # type: ignore[assignment]

         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config

vllm/v1/attention/backends/pallas.py

Lines changed: 5 additions & 0 deletions
@@ -71,6 +71,11 @@ def get_min_page_size(vllm_config: VllmConfig) -> int:
         min_page_size = 1 << (min_page_size - 1).bit_length()
         return min_page_size

+    @staticmethod
+    def get_max_num_seqs(model_len: int, page_size: int) -> int:
+        num_page_per_req = cdiv(model_len, page_size)
+        return 1024 * 1024 // 2 // num_page_per_req // 4
+
     # TPU has limited SREGs (scalar registers), if page_size is too small, we
     # can spill SREGs easily which leads to bad performance. The strategy we
     # apply here is trying to split max-model-len to 16 pages which make the
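To tie the new helper back to the test above, a self-contained sketch (mine, not the commit's model-runner code; the min() against max_num_seqs is inferred from the test's expectations, and page_size=128 is assumed):

def cdiv(a: int, b: int) -> int:
    return -(-a // b)

def get_max_num_seqs(model_len: int, page_size: int) -> int:
    # Same formula as the staticmethod added above.
    num_page_per_req = cdiv(model_len, page_size)
    return 1024 * 1024 // 2 // num_page_per_req // 4

page_size = 128        # assumed, matching the test's comment
max_num_seqs = 1200    # scheduler_config.max_num_seqs in the test

# Short-request bucket: bounded by VLLM_TPU_MOST_MODEL_LEN (2048 in the test).
num_reqs_most = min(max_num_seqs, get_max_num_seqs(2048, page_size))    # 1200
# Full-length bucket: bounded by max_model_len (32000 in the test).
num_reqs_max = min(max_num_seqs, get_max_num_seqs(32000, page_size))    # 524

assert (num_reqs_most, num_reqs_max) == (1200, 524)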
