+ import dataclasses
+ import weakref
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union

...

from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS, make_tensor_with_pad
from vllm.worker.model_runner_base import (
-     ModelRunnerBase, ModelRunnerInputBase,
+     ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
    _add_attn_metadata_broadcastable_dict,
    _add_sampling_metadata_broadcastable_dict,
    _init_attn_metadata_from_tensor_dict,

...

@dataclass(frozen=True)
- class CPUModelInput(ModelRunnerInputBase):
+ class ModelInputForCPU(ModelRunnerInputBase):
    """
-     Used by the CPUModelRunner.
+     Base class contains metadata needed for the base model forward pass on CPU
    """
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    attn_metadata: Optional["AttentionMetadata"] = None
-     sampling_metadata: Optional["SamplingMetadata"] = None
    multi_modal_kwargs: Optional[BatchedTensorInputs] = None
    virtual_engine: Optional[int] = None
+     seq_lens: Optional[List[int]] = None
+     query_lens: Optional[List[int]] = None

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
@@ -51,88 +54,96 @@ def as_broadcastable_tensor_dict(
            "multi_modal_kwargs": self.multi_modal_kwargs,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
-         _add_sampling_metadata_broadcastable_dict(tensor_dict,
-                                                   self.sampling_metadata)
+
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
-             cls: Type["CPUModelInput"],
-             tensor_dict: Dict[str, Any],
-             attn_backend: Optional["AttentionBackend"] = None
-     ) -> "CPUModelInput":
-         tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
+             cls: Type["ModelInputForCPU"],
+             tensor_dict: Dict[str, Any],
+             attn_backend: Optional["AttentionBackend"] = None
+     ) -> "ModelInputForCPU":
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)


- class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
+ @dataclass(frozen=True)
+ class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU):
+     """
+     Used by the ModelRunner.
+     """
+     sampling_metadata: Optional["SamplingMetadata"] = None

-     def __init__(
-         self,
-         model_config: ModelConfig,
-         parallel_config: ParallelConfig,
-         scheduler_config: SchedulerConfig,
-         device_config: DeviceConfig,
-         cache_config: CacheConfig,
-         load_config: LoadConfig,
-         lora_config: Optional[LoRAConfig],
-         kv_cache_dtype: Optional[str] = "auto",
-         prompt_adapter_config: Optional[PromptAdapterConfig] = None,
-         is_driver_worker: bool = False,
-         *args,
-         **kwargs,
-     ):
-         self.model_config = model_config
-         self.parallel_config = parallel_config
-         self.scheduler_config = scheduler_config
-         # Currently, CPU worker doesn't support chunked prefill.
-         assert self.scheduler_config.chunked_prefill_enabled is False
-         self.device_config = device_config
-         self.cache_config = cache_config
-         self.lora_config = lora_config
-         self.prompt_adapter_config = prompt_adapter_config
-         self.load_config = load_config
-         self.is_driver_worker = is_driver_worker
+     def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
+         tensor_dict = {
+             "input_tokens": self.input_tokens,
+             "input_positions": self.input_positions,
+         }
+         _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
+         _add_sampling_metadata_broadcastable_dict(tensor_dict,
+                                                   self.sampling_metadata)
+         return tensor_dict
-
-         self.device = self.device_config.device
+
+     @classmethod
+     def from_broadcasted_tensor_dict(
+             cls,
+             tensor_dict: Dict[str, Any],
+             attn_backend: Optional["AttentionBackend"] = None,
+     ) -> "ModelInputForCPUWithSamplingMetadata":
+         tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
+         if attn_backend is not None:
+             tensor_dict = _init_attn_metadata_from_tensor_dict(
+                 attn_backend, tensor_dict)
+         return cls(**tensor_dict)
-
-         self.kv_cache_dtype = kv_cache_dtype
-         self.sliding_window = model_config.get_sliding_window()
-         self.block_size = cache_config.block_size
-         self.attn_backend = get_attn_backend(
-             self.model_config.get_num_attention_heads(self.parallel_config),
-             self.model_config.get_head_size(),
-             self.model_config.get_num_kv_heads(self.parallel_config),
-             self.model_config.get_sliding_window(),
-             self.model_config.dtype,
-             self.kv_cache_dtype,
-             self.block_size,
-         )
-
-         # Multi-modal data support
-         self.mm_registry = MULTIMODAL_REGISTRY
-         self.multi_modal_input_mapper = self.mm_registry \
-             .create_input_mapper(self.model_config)
-         self.mm_registry.init_mm_limits_per_prompt(self.model_config)
+ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
+
-
-         # Lazy initialization.
-         self.model: nn.Module  # Set after init_Model
+     def __init__(self,
+                  runner: "CPUModelRunner",
+                  finished_requests_ids: Optional[List[str]] = None) -> None:
+         super().__init__()
+         self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
+         self.runner = runner
+         self.model_input_cls = self.runner._model_input_cls
+         self.attn_backend = self.runner.attn_backend
+         self.sliding_window = self.runner.sliding_window
+         self.block_size = self.runner.block_size
+         self.device = self.runner.device
+         self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
+
-         if self.model_config.is_encoder_decoder_model:
-             raise NotImplementedError(
-                 STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU'])
+     def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
+         self.seq_group_metadata_list.append(seq_group_metadata)
-
-     def load_model(self) -> None:
-         self.model = get_model(model_config=self.model_config,
-                                load_config=self.load_config,
-                                device_config=self.device_config,
-                                lora_config=self.lora_config,
-                                parallel_config=self.parallel_config,
-                                scheduler_config=self.scheduler_config,
-                                cache_config=self.cache_config)
+
+     def build(self) -> ModelInputForCPU:
+         multi_modal_kwargs = None
+         # NOTE: We assume that all sequences in the group are all prompts or
+         # all decodes.
+         is_prompt = self.seq_group_metadata_list[0].is_prompt
+         # Prepare input tensors.
+         if is_prompt:
+             (input_tokens, input_positions, attn_metadata, seq_lens,
+              multi_modal_kwargs) = self._prepare_prompt(
+                  self.seq_group_metadata_list)
+         else:
+             (input_tokens, input_positions,
+              attn_metadata) = self._prepare_decode(
+                  self.seq_group_metadata_list)
+             seq_lens = []
+
+         return self.model_input_cls(
+             input_tokens=input_tokens,
+             input_positions=input_positions,
+             attn_metadata=attn_metadata,
+             multi_modal_kwargs=multi_modal_kwargs,
+             # query_lens is not needed if chunked prefill is not
+             # supported. Since CPU worker doesn't support chunked prefill
+             # just use seq_lens instead.
+             seq_lens=seq_lens,
+             query_lens=seq_lens,
+         )

    def _prepare_prompt(
        self,
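The hunks above split the old CPUModelInput into a frozen ModelInputForCPU base class (broadcast fields only) plus a ModelInputForCPUWithSamplingMetadata subclass, and move tensor preparation into ModelInputForCPUBuilder. As a rough illustration of the driver/worker broadcast round-trip those as_broadcastable_tensor_dict / from_broadcasted_tensor_dict methods are meant to support, here is a minimal sketch; it assumes this module is importable from an installed vLLM build, uses placeholder tensors, and is not part of the commit:

    import torch

    from vllm.worker.cpu_model_runner import ModelInputForCPUWithSamplingMetadata

    # Driver side: flatten the (frozen) model input into broadcastable tensors.
    model_input = ModelInputForCPUWithSamplingMetadata(
        input_tokens=torch.tensor([1, 2, 3]),
        input_positions=torch.tensor([0, 1, 2]),
    )
    tensor_dict = model_input.as_broadcastable_tensor_dict()

    # Worker side: rebuild the dataclass from the broadcast dict. Passing
    # attn_backend=None simply skips attention-metadata reconstruction here.
    rebuilt = ModelInputForCPUWithSamplingMetadata.from_broadcasted_tensor_dict(
        tensor_dict)
    assert torch.equal(rebuilt.input_tokens, model_input.input_tokens)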
@@ -165,8 +176,7 @@ def _prepare_prompt(
            # is always the first token in the sequence.
            input_positions.extend(list(range(computed_len, seq_len)))

-             mm_data = seq_group_metadata.multi_modal_data
-             if mm_data:
+             if (mm_data := seq_group_metadata.multi_modal_data):
                mm_kwargs = self.multi_modal_input_mapper(mm_data)
                multi_modal_inputs_list.append(mm_kwargs)

@@ -302,56 +312,130 @@ def _prepare_decode(
            attn_metadata,
        )

+
+ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
+     _model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = (
+         ModelInputForCPUWithSamplingMetadata)
+     _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder
+
+     def __init__(
+         self,
+         model_config: ModelConfig,
+         parallel_config: ParallelConfig,
+         scheduler_config: SchedulerConfig,
+         device_config: DeviceConfig,
+         cache_config: CacheConfig,
+         load_config: LoadConfig,
+         lora_config: Optional[LoRAConfig],
+         kv_cache_dtype: Optional[str] = "auto",
+         prompt_adapter_config: Optional[PromptAdapterConfig] = None,
+         is_driver_worker: bool = False,
+         *args,
+         **kwargs,
+     ):
+         self.model_config = model_config
+         self.parallel_config = parallel_config
+         self.scheduler_config = scheduler_config
+         # Currently, CPU worker doesn't support chunked prefill.
+         assert self.scheduler_config.chunked_prefill_enabled is False
+         self.device_config = device_config
+         self.cache_config = cache_config
+         self.lora_config = lora_config
+         self.prompt_adapter_config = prompt_adapter_config
+         self.load_config = load_config
+         self.is_driver_worker = is_driver_worker
+
+         self.device = self.device_config.device
+
+         self.kv_cache_dtype = kv_cache_dtype
+         self.sliding_window = model_config.get_sliding_window()
+         self.block_size = cache_config.block_size
+         self.attn_backend = get_attn_backend(
+             self.model_config.get_num_attention_heads(self.parallel_config),
+             self.model_config.get_head_size(),
+             self.model_config.get_num_kv_heads(self.parallel_config),
+             self.model_config.get_sliding_window(),
+             self.model_config.dtype,
+             self.kv_cache_dtype,
+             self.block_size,
+         )
+
+         # Multi-modal data support
+         self.mm_registry = MULTIMODAL_REGISTRY
+         self.multi_modal_input_mapper = self.mm_registry \
+             .create_input_mapper(self.model_config)
+         self.mm_registry.init_mm_limits_per_prompt(self.model_config)
+
+         # Lazy initialization.
+         self.model: nn.Module  # Set after init_Model
+
+         if self.model_config.is_encoder_decoder_model:
+             raise NotImplementedError(
+                 STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU'])
+
+     def load_model(self) -> None:
+         self.model = get_model(model_config=self.model_config,
+                                load_config=self.load_config,
+                                device_config=self.device_config,
+                                lora_config=self.lora_config,
+                                parallel_config=self.parallel_config,
+                                scheduler_config=self.scheduler_config,
+                                cache_config=self.cache_config)
+
    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
-     ) -> CPUModelInput:
-         return CPUModelInput.from_broadcasted_tensor_dict(
+     ) -> ModelInputForCPU:
+         return ModelInputForCPU.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )

+     def _prepare_model_input_tensors(
+         self,
+         seq_group_metadata_list: List[SequenceGroupMetadata],
+         finished_requests_ids: Optional[List[str]] = None
+     ) -> ModelInputForCPUWithSamplingMetadata:
+         """Helper method to prepare the model input based on a given sequence
+         group. Prepares metadata needed for the base model forward pass but not
+         metadata for possible additional steps, e.g., sampling.
+
+         """
+         builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
+         for seq_group_metadata in seq_group_metadata_list:
+             builder.add_seq_group(seq_group_metadata)
+
+         return builder.build()  # type: ignore
+
    def prepare_model_input(
-         self,
-         seq_group_metadata_list: List[SequenceGroupMetadata],
-         virtual_engine: int = 0,
-         finished_requests_ids: Optional[List[str]] = None
-     ) -> CPUModelInput:
-         multi_modal_kwargs = None
-         # NOTE: We assume that all sequences in the group are all prompts or
-         # all decodes.
-         is_prompt = seq_group_metadata_list[0].is_prompt
-         # Prepare input tensors.
-         if is_prompt:
-             (input_tokens, input_positions, attn_metadata, seq_lens,
-              multi_modal_kwargs
-              ) = self._prepare_prompt(seq_group_metadata_list)
-         else:
-             (input_tokens, input_positions,
-              attn_metadata) = self._prepare_decode(seq_group_metadata_list)
-             seq_lens = []
-         sampling_metadata = SamplingMetadata.prepare(
-             seq_group_metadata_list,
-             seq_lens,
-             # query_lens is not needed if chunked prefill is not
-             # supported. Since CPU worker doesn't support chunked prefill
-             # just use seq_lens instead.
-             seq_lens,
-             self.device,
-             pin_memory=False,
-             generators=self.get_generators(finished_requests_ids))
-         return CPUModelInput(
-             input_tokens=input_tokens,
-             input_positions=input_positions,
-             attn_metadata=attn_metadata,
-             sampling_metadata=sampling_metadata,
-             multi_modal_kwargs=multi_modal_kwargs,
-         )
+         self,
+         seq_group_metadata_list: List[SequenceGroupMetadata],
+         virtual_engine: int = 0,
+         finished_requests_ids: Optional[List[str]] = None
+     ) -> ModelInputForCPUWithSamplingMetadata:
+         """Prepare the model input based on a given sequence group, including
+         metadata for the sampling step.
+
+         """
+         model_input = self._prepare_model_input_tensors(
+             seq_group_metadata_list, finished_requests_ids)
+         # Sampling metadata is only required for the final pp group
+         generators = self.get_generators(finished_requests_ids)
+         sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
+                                                      model_input.seq_lens,
+                                                      model_input.query_lens,
+                                                      self.device,
+                                                      pin_memory=False,
+                                                      generators=generators)
+
+         return dataclasses.replace(model_input,
+                                    sampling_metadata=sampling_metadata,
+                                    virtual_engine=virtual_engine)

    @torch.no_grad()
    def execute_model(
        self,
-         model_input: CPUModelInput,
+         model_input: ModelInputForCPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
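With this change, prepare_model_input is a two-step pipeline: _prepare_model_input_tensors drives the builder through a weakref.proxy to the runner and returns a frozen ModelInputForCPU, and sampling metadata is then attached with dataclasses.replace, since a frozen dataclass cannot be mutated in place. The self-contained toy sketch below mirrors that pattern; ToyModelInput, ToyBuilder, and ToyRunner are illustrative names, not vLLM APIs:

    import dataclasses
    import weakref
    from dataclasses import dataclass
    from typing import List, Optional


    @dataclass(frozen=True)
    class ToyModelInput:
        input_tokens: Optional[List[int]] = None
        sampling_metadata: Optional[dict] = None


    class ToyBuilder:

        def __init__(self, runner: "ToyRunner") -> None:
            # The real builder receives weakref.proxy(runner); reading
            # attributes through the proxy works while the runner is alive.
            self.block_size = runner.block_size
            self.seq_groups: List[List[int]] = []

        def add_seq_group(self, tokens: List[int]) -> None:
            self.seq_groups.append(tokens)

        def build(self) -> ToyModelInput:
            flat = [tok for group in self.seq_groups for tok in group]
            return ToyModelInput(input_tokens=flat)


    class ToyRunner:
        block_size = 16

        def prepare_model_input(self,
                                seq_groups: List[List[int]]) -> ToyModelInput:
            builder = ToyBuilder(weakref.proxy(self))
            for group in seq_groups:
                builder.add_seq_group(group)
            model_input = builder.build()
            # The dataclass is frozen, so sampling metadata is attached by
            # creating a modified copy rather than mutating the instance.
            return dataclasses.replace(
                model_input,
                sampling_metadata={"num_tokens": len(model_input.input_tokens)})


    if __name__ == "__main__":
        print(ToyRunner().prepare_model_input([[1, 2], [3]]))

Keeping the builder behind a weak proxy avoids a reference cycle between the runner and its per-batch builder, which matches how the diff wires ModelInputForCPUBuilder into CPUModelRunner.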