
Commit aa3874e

♻️ fix vllm:main (#341)
It seems `model_config.task` is deprecated; from what I understand we can use `model_config.supported_tasks` instead, which is initialized when an LLM engine is instantiated:
https://github.com/vllm-project/vllm/pull/21470/files#diff-7eaad0b7dee0626bf29d10081b0f0c5e3ea15a4af97e7b182a4e0d35f8346953R705-R706

Maintaining backward compatibility is a bit tricky, since:

- Earlier versions had `model_config.task` pointing to the resolved task and `model_config.supported_tasks` holding all supported tasks, which could contain more than one task:

  ```
  model_config.task : generate
  model_config.supported_tasks : {'embed', 'reward', 'generate', 'classify'}
  ```

- Latest `main` now populates `model_config.supported_tasks` with the only task the model supports:

  ```
  model_config.task : None
  model_config.supported_tasks : ['generate']
  ```

---------

Signed-off-by: Prashant Gupta <[email protected]>
Signed-off-by: Max de Bayser <[email protected]>
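For illustration, here is a minimal standalone sketch of the compatibility rule described above. The `supports_task` helper and the `SimpleNamespace` configs are hypothetical stand-ins, not part of vLLM or this change; the patch itself inlines the same logic in `platform.py` and wraps it in properties in the two Spyre workers.

```python
from types import SimpleNamespace

def supports_task(model_config, task: str) -> bool:
    # Old behavior: `task` holds the resolved task, so compare it directly.
    if getattr(model_config, "task", None):
        return model_config.task == task
    # New behavior: `task` is None and `supported_tasks` lists what the model supports.
    return task in model_config.supported_tasks

# Old-style config (before vllm-project/vllm#21470)
old_cfg = SimpleNamespace(task="generate",
                          supported_tasks={'embed', 'reward', 'generate', 'classify'})
# New-style config (current vllm:main)
new_cfg = SimpleNamespace(task=None, supported_tasks=['generate'])

assert supports_task(old_cfg, "generate") and supports_task(new_cfg, "generate")
assert not supports_task(new_cfg, "embed")
```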
1 parent 4d98151 commit aa3874e

File tree: 3 files changed (+44, -11 lines)

vllm_spyre/platform.py

Lines changed: 15 additions & 6 deletions
```diff
@@ -1,3 +1,4 @@
+import inspect
 import sys
 
 # When running this plugin on a Mac, we assume it's for local development
@@ -80,8 +81,13 @@ class SpyrePlatform(Platform):
     def device_type(cls):
         # TODO: temporary hack while BertModels
         # inherit SupportsV0Only in vllm upstream.
+        import vllm.model_executor.models as me_models
         from vllm.config import ModelConfig
-        ModelConfig.is_v1_compatible = is_v1_compatible
+
+        # no need to patch after the model_config change
+        if 'model_config' not in \
+                inspect.getfullargspec(me_models.ModelRegistry.is_v1_compatible).args:
+            ModelConfig.is_v1_compatible = is_v1_compatible
         return cls._device_type
 
     @classmethod
@@ -106,11 +112,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if scheduler_config.is_multi_step:
             raise NotImplementedError
 
-        is_decoder = model_config.task == "generate"
-        is_pooling = model_config.task == "embed"
-        if model_config.task == "auto":
-            is_pooling = "embed" in model_config.supported_tasks
-            is_decoder = "generate" in model_config.supported_tasks
+        # Can be simplified after the model_config change from vllm:main
+        is_decoder = model_config.task == "generate" \
+            if model_config.task \
+            else "generate" in model_config.supported_tasks
+
+        is_pooling = model_config.task == "embed" \
+            if model_config.task \
+            else "embed" in model_config.supported_tasks
 
         if is_decoder and not envs.VLLM_USE_V1:
             raise ValueError("Decoder models are only supported on v1")
```

vllm_spyre/v1/worker/spyre_worker.py

Lines changed: 15 additions & 3 deletions
```diff
@@ -57,6 +57,18 @@ class SpyreWorker(WorkerBaseV1):
     """A worker class that executes the model on a group of Spyre cores.
     """
 
+    @property
+    def is_pooling(self) -> bool:
+        return self.model_config.task == "embed" \
+            if self.model_config.task else \
+            "embed" in self.model_config.supported_tasks
+
+    @property
+    def is_decoder(self) -> bool:
+        return self.model_config.task == "generate" \
+            if self.model_config.task else \
+            "generate" in self.model_config.supported_tasks
+
     def get_kv_cache_spec(self) -> KVCacheSpec:
         """Get specifications for KV cache implementation.
 
@@ -85,7 +97,7 @@ def compile_or_warm_up_model(self) -> None:
                 (s["prompt_length"], s["new_tokens"], s["batch_size"])
                 for s in self.spyre_warmup_shapes
         ]):
-            if self.model_config.task != "embed":
+            if not self.is_pooling:
                 # TODO: remove if spyre supports
                 # lower number of output tokens
                 assert num_decode_tokens >= 2, (
@@ -168,7 +180,7 @@ def __init__(
         self.model_runner: \
             Union[StaticBatchingSpyreModelRunner,
                   ContinuousBatchingSpyreModelRunner, SpyrePoolingModelRunner]
-        if self.model_config.task == "embed":
+        if self.is_pooling:
             self.model_runner = SpyrePoolingModelRunner(
                 self.vllm_config, self.is_driver_worker)
             self.spyre_warmup_shapes = SpyrePlatform.get_warmup_shapes(
@@ -457,7 +469,7 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
             0, len(valid_token_ids_tensor), (batch_size, prompt_len))]
 
         sampling_params, pooling_params = None, None
-        if self.model_config.task != "embed":
+        if not self.is_pooling:
             sampling_params = SamplingParams(max_tokens=num_decode_tokens)
         else:
             pooling_params = PoolingParams()
```
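Both workers centralize the task check behind `is_pooling`/`is_decoder` properties, so call sites read `if not self.is_pooling:` instead of repeating string comparisons. A compact sketch of that pattern with a stand-in worker class (not vLLM's actual `SpyreWorker`), checked against both config shapes from the commit message:

```python
from types import SimpleNamespace


class WorkerLike:
    """Illustrative stand-in showing the property-based task check."""

    def __init__(self, model_config):
        self.model_config = model_config

    @property
    def is_pooling(self) -> bool:
        return self.model_config.task == "embed" \
            if self.model_config.task else \
            "embed" in self.model_config.supported_tasks

    @property
    def is_decoder(self) -> bool:
        return self.model_config.task == "generate" \
            if self.model_config.task else \
            "generate" in self.model_config.supported_tasks


# Works for both the old and the new shape of model_config:
old_worker = WorkerLike(SimpleNamespace(task="embed",
                                        supported_tasks={'embed', 'generate'}))
new_worker = WorkerLike(SimpleNamespace(task=None, supported_tasks=['generate']))
assert old_worker.is_pooling and not old_worker.is_decoder
assert new_worker.is_decoder and not new_worker.is_pooling
```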

vllm_spyre/worker/spyre_worker.py

Lines changed: 14 additions & 2 deletions
```diff
@@ -42,6 +42,18 @@ class SpyreWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
     """A worker class that executes the model on a group of Spyre cores.
     """
 
+    @property
+    def is_pooling(self) -> bool:
+        return self.model_config.task == "embed" \
+            if self.model_config.task else \
+            "embed" in self.model_config.supported_tasks
+
+    @property
+    def is_decoder(self) -> bool:
+        return self.model_config.task == "generate" \
+            if self.model_config.task else \
+            "generate" in self.model_config.supported_tasks
+
     def __init__(
         self,
         vllm_config: VllmConfig,
@@ -64,7 +76,7 @@ def __init__(
             from vllm.utils import init_cached_hf_modules
             init_cached_hf_modules()
 
-        if self.model_config.task == "embed":
+        if self.is_pooling:
             self.model_runner: SpyreModelRunner = SpyreEmbeddingModelRunner(
                 self.model_config, self.parallel_config, self.scheduler_config,
                 self.device_config, self.is_driver_worker)
@@ -205,7 +217,7 @@ def load_model(self):
                 (s["prompt_length"], s["new_tokens"], s["batch_size"])
                 for s in self.spyre_warmup_shapes
         ]):
-            if self.model_config.task != "embed":
+            if not self.is_pooling:
                 # TODO: remove if spyre supports
                 # lower number of output tokens
                 assert num_decode_tokens >= 2, (
```
