Commit aecd732
[CB] refactoring warmup for batch size 1 (#347)

From a comment on #312, there is a request for a cleaner integration of batch size 1 support during warmup. Most of the necessary code is already on main, hence this PR.

Signed-off-by: Yannick Schnider <[email protected]>

1 parent ee6e224 commit aecd732

File tree

2 files changed (+14 −30 lines)

vllm_spyre/v1/worker/spyre_model_runner.py

Lines changed: 10 additions & 2 deletions
@@ -290,8 +290,13 @@ def load_model(self, prompt_lens: Iterable[int],
         )

     def build_input_batch(self) -> SamplingInputBatch:
+        # Fix for batch size 1: set input batch to fit 2 requests for warmup,
+        # and reset input batch to fit max_num_seqs requests after warmup
+        min_seqs_required = 2 if self.warmup_mode else 1
+
         return SamplingInputBatch(
-            max_num_reqs=self.scheduler_config.max_num_seqs,
+            max_num_reqs=max(min_seqs_required,
+                             self.scheduler_config.max_num_seqs),
             max_model_len=self.model_config.max_model_len,
             device=self.device,
             pin_memory=self.pin_memory,
@@ -802,7 +807,10 @@ def __init__(
             vocab_size=vllm_config.model_config.get_vocab_size(),
         )

-    def finish_warmup(self) -> None:
+    def complete_warmup(self) -> None:
+        super().complete_warmup()
+        # Fix for batch size 1: need to update the input_batch after the warmup
+        self.input_batch = self.build_input_batch()
         # get the number or pages from the actual Spyre card after the warmup
         # and set it accordingly in the model runner and the kv cache size
         n_blocks_avail = self._get_num_blocks_available()
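For readers skimming the diff: the runner now handles the batch-size-1 special case itself. build_input_batch clamps the batch capacity to at least 2 while warmup_mode is set (warmup needs room for two requests even when the user configured max_num_seqs=1), and complete_warmup rebuilds the batch at the configured size. Below is a minimal, self-contained sketch of that pattern, not the vllm_spyre source; SimpleInputBatch and Runner are hypothetical stand-ins for SamplingInputBatch and the continuous-batching model runner.

    from dataclasses import dataclass


    @dataclass
    class SimpleInputBatch:
        """Stand-in for SamplingInputBatch: only the capacity matters here."""
        max_num_reqs: int


    class Runner:
        def __init__(self, max_num_seqs: int):
            self.max_num_seqs = max_num_seqs
            self.warmup_mode = True  # the runner starts out in warmup mode
            self.input_batch = self.build_input_batch()

        def build_input_batch(self) -> SimpleInputBatch:
            # Warmup needs room for at least 2 requests, even when the user
            # configured max_num_seqs=1 (the reason for the original fix).
            min_seqs_required = 2 if self.warmup_mode else 1
            return SimpleInputBatch(
                max_num_reqs=max(min_seqs_required, self.max_num_seqs))

        def complete_warmup(self) -> None:
            # Leave warmup mode and rebuild the batch at the configured size.
            self.warmup_mode = False
            self.input_batch = self.build_input_batch()


    runner = Runner(max_num_seqs=1)
    assert runner.input_batch.max_num_reqs == 2  # oversized for warmup
    runner.complete_warmup()
    assert runner.input_batch.max_num_reqs == 1  # back to the configured size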

vllm_spyre/v1/worker/spyre_worker.py

Lines changed: 4 additions & 28 deletions
@@ -29,7 +29,6 @@
 import vllm_spyre.perf_metrics as perf_metrics
 from vllm_spyre.model_executor.model_loader import spyre_setup
 from vllm_spyre.platform import SpyrePlatform
-from vllm_spyre.v1.worker.spyre_input_batch import SamplingInputBatch
 from vllm_spyre.v1.worker.spyre_model_runner import (
     ContinuousBatchingSpyreModelRunner, SpyrePoolingModelRunner,
     StaticBatchingSpyreModelRunner, SupportedTask)
@@ -110,6 +109,9 @@ def compile_or_warm_up_model(self) -> None:
                 prompt_len, num_decode_tokens, batch_size)
             self._warmup_spyre_fixed_size(prompt_len, num_decode_tokens,
                                           self.restricted_tokens, batch_size)
+
+        self.model_runner.complete_warmup()
+
         all_warmup_end_t = time.time()
         all_warmup_total_t = all_warmup_end_t - all_warmup_start_t
         self.perf_metrics.log("total warmup time", all_warmup_total_t)
@@ -119,7 +121,6 @@ def compile_or_warm_up_model(self) -> None:
             "[WARMUP] All %d prompt/decode/batchsize-shape "
             "combinations finished in %.3fs", num_shape_combinations,
             all_warmup_total_t)
-        self.model_runner.complete_warmup()

     def check_health(self) -> None:
         """Basic health check (override for device-specific checks)."""
@@ -339,18 +340,6 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
         prompt_len = 42
         num_decode_tokens = 2

-        # Fix for batch size 1: set input batch to fit 2 requests for warmup
-        if model_runner.vllm_config.scheduler_config.max_num_seqs == 1:
-            model_runner.input_batch = SamplingInputBatch(
-                max_num_reqs=2,
-                max_model_len=model_runner.vllm_config.model_config.
-                max_model_len,
-                device=model_runner.device,
-                pin_memory=model_runner.pin_memory,
-                vocab_size=model_runner.vllm_config.model_config.
-                get_vocab_size(),
-            )
-
         # Sample from the valid token ids
         warmup_tokens_tensor = valid_token_ids_tensor[torch.randint(
             0, len(valid_token_ids_tensor), (batch_size + 1, prompt_len))]
@@ -398,20 +387,7 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
         self.execute_model(scheduler_output)
         self._cleanup_model_runner(request=[add_dummy_request])

-        # Fix for batch size 1: reset input batch to fit max_num_seqs requests
-        if model_runner.vllm_config.scheduler_config.max_num_seqs == 1:
-            model_runner.input_batch = SamplingInputBatch(
-                max_num_reqs=model_runner.vllm_config.scheduler_config.
-                max_num_seqs,
-                max_model_len=model_runner.vllm_config.model_config.
-                max_model_len,
-                device=model_runner.device,
-                pin_memory=model_runner.pin_memory,
-                vocab_size=model_runner.vllm_config.model_config.
-                get_vocab_size(),
-            )
-
-        model_runner.finish_warmup()
+        model_runner.complete_warmup()

         warmup_end_t = time.time()
         warmup_total_t = warmup_end_t - warmup_start_t
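With the sizing logic moved into the runner, the worker-side change reduces to deleting the duplicated SamplingInputBatch construction and calling complete_warmup() once at the end of each warmup path. A condensed, runnable sketch of the resulting call order; SketchWorker and SketchRunner are hypothetical stand-ins, and the real warmup shapes, timing, and logging are elided:

    class SketchRunner:
        def __init__(self) -> None:
            self.warmup_mode = True

        def complete_warmup(self) -> None:
            # In the real runner this also rebuilds the input batch and reads
            # the number of available KV-cache blocks from the Spyre card.
            self.warmup_mode = False


    class SketchWorker:
        def __init__(self) -> None:
            self.model_runner = SketchRunner()

        def _warmup_one_shape(self, prompt_len: int, batch_size: int) -> None:
            print(f"warmed up prompt_len={prompt_len} batch_size={batch_size}")

        def compile_or_warm_up_model(self) -> None:
            for prompt_len, batch_size in [(64, 1), (128, 2)]:
                self._warmup_one_shape(prompt_len, batch_size)
            # Called once, after all warmup shapes have run (previously this
            # happened after the summary logging; the commit moves it up).
            self.model_runner.complete_warmup()


    SketchWorker().compile_or_warm_up_model()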
