Skip to content

Commit 962abf1

Browse files
[Priority merge] NewRequestData parameter introduced in vllm upstream (#245)
Temporary hack until the parameter makes it to a new release version. Needs to be merged first for the tests on the other PRs to pass. (PS: this was actually the error after fixing the merge conflict in PR #240, which had nothing to do with the conflict) --------- Signed-off-by: Sophie du Couédic <[email protected]> Co-authored-by: Yannick Schnider <[email protected]>
1 parent 2c295c8 commit 962abf1

File tree

3 files changed

+57
-43
lines changed

3 files changed

+57
-43
lines changed

tests/spyre_util.py

Lines changed: 18 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import inspect
21
import math
32
import os
43
import subprocess
@@ -540,35 +539,24 @@ def create_random_request(
540539
request_id: int, num_tokens: int,
541540
sampling_params: SamplingParams) -> EngineCoreRequest:
542541

543-
# Temporary until 'data_parallel_rank' parameter makes it to
544-
# a release version in vllm
545-
if "data_parallel_rank" in [
546-
x[0] for x in inspect.getmembers(EngineCoreRequest)
547-
]:
548-
return EngineCoreRequest(
549-
request_id=str(request_id),
550-
prompt_token_ids=[request_id] * num_tokens,
551-
mm_inputs=None,
552-
mm_hashes=None,
553-
mm_placeholders=None,
554-
sampling_params=sampling_params,
555-
eos_token_id=None,
556-
arrival_time=0,
557-
lora_request=None,
558-
cache_salt=None,
559-
data_parallel_rank=None,
560-
)
561-
else:
562-
return EngineCoreRequest(request_id=str(request_id),
563-
prompt_token_ids=[request_id] * num_tokens,
564-
mm_inputs=None,
565-
mm_hashes=None,
566-
mm_placeholders=None,
567-
sampling_params=sampling_params,
568-
eos_token_id=None,
569-
arrival_time=0,
570-
lora_request=None,
571-
cache_salt=None)
542+
# Temporary until these parameters make it to a release version in vllm
543+
extra_kwargs: dict[str, Any] = {}
544+
if "data_parallel_rank" in EngineCoreRequest.__annotations__:
545+
extra_kwargs["data_parallel_rank"] = None
546+
if "pooling_params" in EngineCoreRequest.__annotations__:
547+
extra_kwargs["pooling_params"] = None
548+
549+
return EngineCoreRequest(request_id=str(request_id),
550+
prompt_token_ids=[request_id] * num_tokens,
551+
mm_inputs=None,
552+
mm_hashes=None,
553+
mm_placeholders=None,
554+
sampling_params=sampling_params,
555+
eos_token_id=None,
556+
arrival_time=0,
557+
lora_request=None,
558+
cache_salt=None,
559+
**extra_kwargs)
572560

573561

574562
def skip_unsupported_tp_size(size: int):

vllm_spyre/v1/worker/spyre_model_runner.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from collections import deque
33
from collections.abc import Iterable
44
from dataclasses import dataclass
5-
from typing import TYPE_CHECKING, Optional
5+
from typing import TYPE_CHECKING, Any, Optional
66

77
import torch
88
from torch import nn
@@ -430,6 +430,12 @@ def execute_model(
430430

431431
t0 = time.time()
432432

433+
# TODO temporary until 'pooler_output' makes it to a release version
434+
# in vllm
435+
extra_kwargs: dict[str, Any] = {}
436+
if "pooler_output" in ModelRunnerOutput.__dataclass_fields__:
437+
extra_kwargs["pooler_output"] = None
438+
433439
# TODO: change to EMPTY_MODEL_RUNNER_OUTPUT, right now this
434440
# will be a breaking change, or clumsy to make retrocompatible
435441
# with conditional import
@@ -442,6 +448,7 @@ def execute_model(
442448
spec_token_ids=None,
443449
logprobs=None,
444450
prompt_logprobs_dict={},
451+
**extra_kwargs,
445452
)
446453

447454
self._update_states(scheduler_output)
@@ -490,6 +497,7 @@ def execute_model(
490497
req_id: None
491498
for req_id in self.input_batch.req_id_to_index
492499
}, # TODO(wallas?): prompt logprobs too
500+
**extra_kwargs,
493501
)
494502
return model_output
495503

@@ -937,21 +945,27 @@ def execute_model(
937945

938946
t0 = time.time()
939947

948+
# TODO temporary until 'pooler_output' makes it to a release version
949+
# in vllm
950+
extra_kwargs: dict[str, Any] = {}
951+
if "pooler_output" in CBSpyreModelRunnerOutput.__dataclass_fields__:
952+
extra_kwargs["pooler_output"] = None
953+
940954
self._update_states(scheduler_output)
941955
# TODO: change to EMPTY_MODEL_RUNNER_OUTPUT, right now this
942956
# will be a breaking change, or clumsy to make retrocompatible
943957
# with conditional import
944958
if not scheduler_output.total_num_scheduled_tokens:
959+
945960
# Return empty ModelRunnerOutput if there's no work to do.
946-
return CBSpyreModelRunnerOutput(
947-
req_ids=[],
948-
req_id_to_index={},
949-
sampled_token_ids=[],
950-
spec_token_ids=None,
951-
logprobs=None,
952-
prompt_logprobs_dict={},
953-
tkv=0,
954-
)
961+
return CBSpyreModelRunnerOutput(req_ids=[],
962+
req_id_to_index={},
963+
sampled_token_ids=[],
964+
spec_token_ids=None,
965+
logprobs=None,
966+
prompt_logprobs_dict={},
967+
tkv=0,
968+
**extra_kwargs)
955969

956970
model_input = self.prepare_model_input(scheduler_output)
957971

@@ -1037,5 +1051,6 @@ def execute_model(
10371051
for req_id in req_ids
10381052
}, # TODO(wallas?): prompt logprobs too
10391053
tkv=self.tkv,
1054+
**extra_kwargs,
10401055
)
10411056
return model_output

vllm_spyre/v1/worker/spyre_worker.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import platform
66
import signal
77
import time
8-
from typing import Optional, Union, cast
8+
from typing import Any, Optional, Union, cast
99

1010
import torch
1111
import torch.distributed as dist
@@ -322,6 +322,11 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
322322
warmup_tokens_tensor = valid_token_ids_tensor[torch.randint(
323323
0, len(valid_token_ids_tensor), (batch_size, prompt_len))]
324324

325+
# TODO temporary until 'pooling_params' makes it to a release version
326+
# in vllm
327+
extra_kwargs: dict[str, Any] = {}
328+
if "pooling_params" in NewRequestData.__dataclass_fields__:
329+
extra_kwargs["pooling_params"] = None
325330
dummy_requests = [
326331
NewRequestData(
327332
req_id="warmup-%d" % (i),
@@ -333,7 +338,7 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
333338
block_ids=[0], # not actually used
334339
num_computed_tokens=0,
335340
lora_request=None,
336-
) for i in range(batch_size)
341+
**extra_kwargs) for i in range(batch_size)
337342
]
338343

339344
for i, req in enumerate(dummy_requests):
@@ -487,6 +492,12 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
487492
warmup_tokens_tensor = valid_token_ids_tensor[torch.randint(
488493
0, len(valid_token_ids_tensor), (batch_size, prompt_len))]
489494

495+
# TODO temporary until 'pooling_params' makes it to a release version
496+
# in vllm
497+
extra_kwargs: dict[str, Any] = {}
498+
if "pooling_params" in NewRequestData.__dataclass_fields__:
499+
extra_kwargs["pooling_params"] = None
500+
490501
# Set up dummy requests for prefill steps
491502
dummy_requests = [
492503
NewRequestData(
@@ -499,7 +510,7 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
499510
block_ids=[0],
500511
num_computed_tokens=0,
501512
lora_request=None,
502-
) for i in range(batch_size)
513+
**extra_kwargs) for i in range(batch_size)
503514
]
504515

505516
# Set up dummy cached_requests for decode steps

0 commit comments

Comments (0)