Skip to content

[CB] Override number of Spyre blocks: replace env var with top level argument #331

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions tests/scheduling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ def check_scheduler_inference_steps(
# set env vars
monkeypatch.setenv("VLLM_USE_V1", "1")
monkeypatch.setenv("VLLM_SPYRE_DYNAMO_BACKEND", backend)
if available_blocks > 0:
monkeypatch.setenv("VLLM_SPYRE_N_BLOCKS", str(available_blocks))
if use_cb:
monkeypatch.setenv("VLLM_SPYRE_USE_CB", "1")

Expand Down Expand Up @@ -90,7 +88,9 @@ def check_scheduler_inference_steps(
tokenizer=model,
max_model_len=max_model_len,
block_size=max_model_len,
max_num_seqs=max_num_seqs)
max_num_seqs=max_num_seqs,
num_gpu_blocks_override=available_blocks
if available_blocks > 0 else None)
vllm_config = engine_args.create_engine_config()
executor_class = Executor.get_class(vllm_config)
engine_core = EngineCore(vllm_config=vllm_config,
Expand Down
5 changes: 0 additions & 5 deletions vllm_spyre/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
VLLM_SPYRE_WARMUP_NEW_TOKENS: Optional[list[int]] = None
VLLM_SPYRE_WARMUP_BATCH_SIZES: Optional[list[int]] = None
VLLM_SPYRE_USE_CB: bool = False
VLLM_SPYRE_N_BLOCKS: int = 0
VLLM_SPYRE_PERF_METRIC_LOGGING_ENABLED: int = 0
VLLM_SPYRE_PERF_METRIC_LOGGING_DIR: str = "/tmp"
VLLM_SPYRE_OVERRIDE_SIGNALS_HANDLER: bool = False
Expand Down Expand Up @@ -75,10 +74,6 @@ def _backend_backwards_compat() -> str:
"VLLM_SPYRE_USE_CB":
lambda: bool(int(os.getenv("VLLM_SPYRE_USE_CB", "0"))),

# Overriding the number of KV cache blocks available on Spyre (and CPU)
"VLLM_SPYRE_N_BLOCKS":
lambda: int(os.getenv("VLLM_SPYRE_N_BLOCKS", 0)),

# Enable performance metric logging. This captures startup information
# such as warmup times, and loading times. It is turned off by default.
"VLLM_SPYRE_PERF_METRIC_LOGGING_ENABLED":
Expand Down
5 changes: 3 additions & 2 deletions vllm_spyre/model_executor/model_loader/spyre.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,8 +331,9 @@ def __init__(

def _set_past_key_value_states(self, num_blocks) -> None:
# overwrite num_blocks for testing scheduler constraints
if envs_spyre.VLLM_SPYRE_N_BLOCKS > 0:
num_blocks = envs_spyre.VLLM_SPYRE_N_BLOCKS
num_blocks_override = SpyrePlatform.get_num_spyre_blocks_override()
if num_blocks_override > 0:
num_blocks = num_blocks_override

# List[layers] of Tuple[k,v] of
# Tensor[num_blocks, block_size, num_kv_heads, head_dim]
Expand Down
9 changes: 9 additions & 0 deletions vllm_spyre/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class SpyrePlatform(Platform):
supported_quantization: list[str] = ["gptq"]
_warmup_shapes: Optional[tuple[dict[str, int], ...]] = None
_block_size: int = 64 # hardcoded Spyre constraint for now
_num_spyre_blocks_override: int = -1 # override num of KV cache blocks
_config: VllmConfig = None

@classmethod
Expand Down Expand Up @@ -136,6 +137,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
# budget available to schedule a full batch
if cache_config is not None:
if envs.VLLM_USE_V1:
# remember a user-supplied override of the number of Spyre KV cache blocks (skipped when unset or 0), because the config field is overwritten below
if cache_config.num_gpu_blocks_override:
cls._num_spyre_blocks_override = \
cache_config.num_gpu_blocks_override
# The V1 scheduler actually needs 2 blocks for each sequence...
cache_config.num_gpu_blocks_override = \
scheduler_config.max_num_seqs * 2
Expand Down Expand Up @@ -237,6 +242,10 @@ def get_warmup_shapes(
def get_block_size(cls) -> int:
    """Return the platform block size stored in the class attribute
    ``_block_size``."""
    return cls._block_size

@classmethod
def get_num_spyre_blocks_override(cls) -> int:
    """Return the stored override for the number of Spyre KV cache blocks.

    The value is the class attribute ``_num_spyre_blocks_override``, which
    defaults to -1 and is set in ``check_and_update_config`` from
    ``cache_config.num_gpu_blocks_override`` when that field is truthy.
    Callers treat only values greater than 0 as an active override.
    """
    return cls._num_spyre_blocks_override

@classmethod
def supports_v1(cls, model_config: ModelConfig) -> bool:
"""Returns whether the current platform can support v1 for the supplied
Expand Down
7 changes: 4 additions & 3 deletions vllm_spyre/v1/worker/spyre_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,11 +725,12 @@ def finish_warmup(self) -> None:

def _set_blocks(self, num_blocks: int) -> None:
# overwrite num_blocks for testing scheduler constraints
if envs_spyre.VLLM_SPYRE_N_BLOCKS > 0:
num_blocks_override = SpyrePlatform.get_num_spyre_blocks_override()
if num_blocks_override > 0:
logger.info(
"[WARMUP] Overriding number of KV cache blocks on "
"Spyre/CPU to %d.", envs_spyre.VLLM_SPYRE_N_BLOCKS)
num_blocks = envs_spyre.VLLM_SPYRE_N_BLOCKS
"Spyre/CPU to %d.", num_blocks_override)
num_blocks = num_blocks_override

# set number of available blocks and populate block_pool
self.n_blocks = num_blocks
Expand Down
Loading