29
29
import vllm_spyre .perf_metrics as perf_metrics
30
30
from vllm_spyre .model_executor .model_loader import spyre_setup
31
31
from vllm_spyre .platform import SpyrePlatform
32
- from vllm_spyre .v1 .worker .spyre_input_batch import SamplingInputBatch
33
32
from vllm_spyre .v1 .worker .spyre_model_runner import (
34
33
ContinuousBatchingSpyreModelRunner , SpyrePoolingModelRunner ,
35
34
StaticBatchingSpyreModelRunner , SupportedTask )
@@ -110,6 +109,9 @@ def compile_or_warm_up_model(self) -> None:
110
109
prompt_len , num_decode_tokens , batch_size )
111
110
self ._warmup_spyre_fixed_size (prompt_len , num_decode_tokens ,
112
111
self .restricted_tokens , batch_size )
112
+
113
+ self .model_runner .complete_warmup ()
114
+
113
115
all_warmup_end_t = time .time ()
114
116
all_warmup_total_t = all_warmup_end_t - all_warmup_start_t
115
117
self .perf_metrics .log ("total warmup time" , all_warmup_total_t )
@@ -119,7 +121,6 @@ def compile_or_warm_up_model(self) -> None:
119
121
"[WARMUP] All %d prompt/decode/batchsize-shape "
120
122
"combinations finished in %.3fs" , num_shape_combinations ,
121
123
all_warmup_total_t )
122
- self .model_runner .complete_warmup ()
123
124
124
125
def check_health (self ) -> None :
125
126
"""Basic health check (override for device-specific checks)."""
@@ -339,18 +340,6 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
339
340
prompt_len = 42
340
341
num_decode_tokens = 2
341
342
342
- # Fix for batch size 1: set input batch to fit 2 requests for warmup
343
- if model_runner .vllm_config .scheduler_config .max_num_seqs == 1 :
344
- model_runner .input_batch = SamplingInputBatch (
345
- max_num_reqs = 2 ,
346
- max_model_len = model_runner .vllm_config .model_config .
347
- max_model_len ,
348
- device = model_runner .device ,
349
- pin_memory = model_runner .pin_memory ,
350
- vocab_size = model_runner .vllm_config .model_config .
351
- get_vocab_size (),
352
- )
353
-
354
343
# Sample from the valid token ids
355
344
warmup_tokens_tensor = valid_token_ids_tensor [torch .randint (
356
345
0 , len (valid_token_ids_tensor ), (batch_size + 1 , prompt_len ))]
@@ -398,20 +387,7 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
398
387
self .execute_model (scheduler_output )
399
388
self ._cleanup_model_runner (request = [add_dummy_request ])
400
389
401
- # Fix for batch size 1: reset input batch to fit max_num_seqs requests
402
- if model_runner .vllm_config .scheduler_config .max_num_seqs == 1 :
403
- model_runner .input_batch = SamplingInputBatch (
404
- max_num_reqs = model_runner .vllm_config .scheduler_config .
405
- max_num_seqs ,
406
- max_model_len = model_runner .vllm_config .model_config .
407
- max_model_len ,
408
- device = model_runner .device ,
409
- pin_memory = model_runner .pin_memory ,
410
- vocab_size = model_runner .vllm_config .model_config .
411
- get_vocab_size (),
412
- )
413
-
414
- model_runner .finish_warmup ()
390
+ model_runner .complete_warmup ()
415
391
416
392
warmup_end_t = time .time ()
417
393
warmup_total_t = warmup_end_t - warmup_start_t
0 commit comments