1
- import inspect
2
1
from typing import (TYPE_CHECKING , ClassVar , Dict , List , Literal , Optional ,
3
2
Protocol , Type , Union , overload , runtime_checkable )
4
3
5
4
import torch
6
5
from typing_extensions import TypeIs
7
6
8
7
from vllm .logger import init_logger
8
+ from vllm .utils import supports_kw
9
9
10
10
if TYPE_CHECKING :
11
- from vllm .attention import AttentionMetadata
12
11
from vllm .config import LoRAConfig , MultiModalConfig , SchedulerConfig
13
12
from vllm .sequence import IntermediateTensors
14
13
@@ -142,9 +141,7 @@ def supports_lora(
142
141
return result
143
142
144
143
145
- def _supports_lora (
146
- model : Union [Type [object ], object ],
147
- ) -> Union [TypeIs [Type [SupportsLoRA ]], TypeIs [SupportsLoRA ]]:
144
+ def _supports_lora (model : Union [Type [object ], object ]) -> bool :
148
145
if isinstance (model , type ):
149
146
return isinstance (model , _SupportsLoRAType )
150
147
@@ -175,10 +172,7 @@ def make_empty_intermediate_tensors(
175
172
176
173
def forward (
177
174
self ,
178
- input_ids : torch .Tensor ,
179
- position_ids : torch .Tensor ,
180
- kv_caches : List [torch .Tensor ],
181
- attn_metadata : "AttentionMetadata" ,
175
+ * ,
182
176
intermediate_tensors : Optional ["IntermediateTensors" ],
183
177
) -> Union [torch .Tensor , "IntermediateTensors" ]:
184
178
"""
@@ -205,10 +199,7 @@ def make_empty_intermediate_tensors(
205
199
206
200
def forward (
207
201
self ,
208
- input_ids : torch .Tensor ,
209
- position_ids : torch .Tensor ,
210
- kv_caches : List [torch .Tensor ],
211
- attn_metadata : "AttentionMetadata" ,
202
+ * ,
212
203
intermediate_tensors : Optional ["IntermediateTensors" ],
213
204
) -> Union [torch .Tensor , "IntermediateTensors" ]:
214
205
...
@@ -257,24 +248,19 @@ def supports_pp(
257
248
return supports_attributes and supports_inspect
258
249
259
250
260
- def _supports_pp_attributes (
261
- model : Union [Type [object ], object ],
262
- ) -> Union [bool , TypeIs [Type [SupportsPP ]], TypeIs [SupportsPP ]]:
251
+ def _supports_pp_attributes (model : Union [Type [object ], object ]) -> bool :
263
252
if isinstance (model , type ):
264
253
return isinstance (model , _SupportsPPType )
265
254
266
255
return isinstance (model , SupportsPP )
267
256
268
257
269
- def _supports_pp_inspect (
270
- model : Union [Type [object ], object ],
271
- ) -> Union [bool , TypeIs [Type [SupportsPP ]], TypeIs [SupportsPP ]]:
258
+ def _supports_pp_inspect (model : Union [Type [object ], object ]) -> bool :
272
259
model_forward = getattr (model , "forward" , None )
273
260
if not callable (model_forward ):
274
261
return False
275
262
276
- forward_params = inspect .signature (model_forward ).parameters
277
- return "intermediate_tensors" in forward_params
263
+ return supports_kw (model_forward , "intermediate_tensors" )
278
264
279
265
280
266
@runtime_checkable
0 commit comments