Commit fbbb438

Isotr0py authored and LeiWang1999 committed
[Hardware][CPU] Refactor CPU model runner (vllm-project#8729)
Signed-off-by: LeiWang1999 <[email protected]>
1 parent c2aad56 commit fbbb438

1 file changed (+193, -109 lines)


vllm/worker/cpu_model_runner.py

Lines changed: 193 additions & 109 deletions
@@ -1,3 +1,5 @@
+import dataclasses
+import weakref
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
 
@@ -17,7 +19,7 @@
 from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
 from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS, make_tensor_with_pad
 from vllm.worker.model_runner_base import (
-    ModelRunnerBase, ModelRunnerInputBase,
+    ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
     _add_attn_metadata_broadcastable_dict,
     _add_sampling_metadata_broadcastable_dict,
     _init_attn_metadata_from_tensor_dict,
@@ -32,16 +34,17 @@
 
 
 @dataclass(frozen=True)
-class CPUModelInput(ModelRunnerInputBase):
+class ModelInputForCPU(ModelRunnerInputBase):
     """
-    Used by the CPUModelRunner.
+    Base class contains metadata needed for the base model forward pass on CPU
     """
     input_tokens: Optional[torch.Tensor] = None
     input_positions: Optional[torch.Tensor] = None
     attn_metadata: Optional["AttentionMetadata"] = None
-    sampling_metadata: Optional["SamplingMetadata"] = None
     multi_modal_kwargs: Optional[BatchedTensorInputs] = None
     virtual_engine: Optional[int] = None
+    seq_lens: Optional[List[int]] = None
+    query_lens: Optional[List[int]] = None
 
     def as_broadcastable_tensor_dict(
             self) -> Dict[str, Union[int, torch.Tensor]]:
@@ -51,88 +54,96 @@ def as_broadcastable_tensor_dict(
             "multi_modal_kwargs": self.multi_modal_kwargs,
         }
         _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
-        _add_sampling_metadata_broadcastable_dict(tensor_dict,
-                                                  self.sampling_metadata)
+
         return tensor_dict
 
     @classmethod
     def from_broadcasted_tensor_dict(
-        cls: Type["CPUModelInput"],
-        tensor_dict: Dict[str, Any],
-        attn_backend: Optional["AttentionBackend"] = None
-    ) -> "CPUModelInput":
-        tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
+            cls: Type["ModelInputForCPU"],
+            tensor_dict: Dict[str, Any],
+            attn_backend: Optional["AttentionBackend"] = None
+    ) -> "ModelInputForCPU":
         if attn_backend is not None:
             tensor_dict = _init_attn_metadata_from_tensor_dict(
                 attn_backend, tensor_dict)
         return cls(**tensor_dict)
 
 
-class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
+@dataclass(frozen=True)
+class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU):
+    """
+    Used by the ModelRunner.
+    """
+    sampling_metadata: Optional["SamplingMetadata"] = None
 
-    def __init__(
-        self,
-        model_config: ModelConfig,
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
-        device_config: DeviceConfig,
-        cache_config: CacheConfig,
-        load_config: LoadConfig,
-        lora_config: Optional[LoRAConfig],
-        kv_cache_dtype: Optional[str] = "auto",
-        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
-        is_driver_worker: bool = False,
-        *args,
-        **kwargs,
-    ):
-        self.model_config = model_config
-        self.parallel_config = parallel_config
-        self.scheduler_config = scheduler_config
-        # Currently, CPU worker doesn't support chunked prefill.
-        assert self.scheduler_config.chunked_prefill_enabled is False
-        self.device_config = device_config
-        self.cache_config = cache_config
-        self.lora_config = lora_config
-        self.prompt_adapter_config = prompt_adapter_config
-        self.load_config = load_config
-        self.is_driver_worker = is_driver_worker
+    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
+        tensor_dict = {
+            "input_tokens": self.input_tokens,
+            "input_positions": self.input_positions,
+        }
+        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
+        _add_sampling_metadata_broadcastable_dict(tensor_dict,
+                                                  self.sampling_metadata)
+        return tensor_dict
 
-        self.device = self.device_config.device
+    @classmethod
+    def from_broadcasted_tensor_dict(
+        cls,
+        tensor_dict: Dict[str, Any],
+        attn_backend: Optional["AttentionBackend"] = None,
+    ) -> "ModelInputForCPUWithSamplingMetadata":
+        tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
+        if attn_backend is not None:
+            tensor_dict = _init_attn_metadata_from_tensor_dict(
+                attn_backend, tensor_dict)
+        return cls(**tensor_dict)
 
-        self.kv_cache_dtype = kv_cache_dtype
-        self.sliding_window = model_config.get_sliding_window()
-        self.block_size = cache_config.block_size
-        self.attn_backend = get_attn_backend(
-            self.model_config.get_num_attention_heads(self.parallel_config),
-            self.model_config.get_head_size(),
-            self.model_config.get_num_kv_heads(self.parallel_config),
-            self.model_config.get_sliding_window(),
-            self.model_config.dtype,
-            self.kv_cache_dtype,
-            self.block_size,
-        )
 
-        # Multi-modal data support
-        self.mm_registry = MULTIMODAL_REGISTRY
-        self.multi_modal_input_mapper = self.mm_registry \
-            .create_input_mapper(self.model_config)
-        self.mm_registry.init_mm_limits_per_prompt(self.model_config)
+class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
 
-        # Lazy initialization.
-        self.model: nn.Module  # Set after init_Model
+    def __init__(self,
+                 runner: "CPUModelRunner",
+                 finished_requests_ids: Optional[List[str]] = None) -> None:
+        super().__init__()
+        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
+        self.runner = runner
+        self.model_input_cls = self.runner._model_input_cls
+        self.attn_backend = self.runner.attn_backend
+        self.sliding_window = self.runner.sliding_window
+        self.block_size = self.runner.block_size
+        self.device = self.runner.device
+        self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
 
-        if self.model_config.is_encoder_decoder_model:
-            raise NotImplementedError(
-                STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU'])
+    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
+        self.seq_group_metadata_list.append(seq_group_metadata)
 
-    def load_model(self) -> None:
-        self.model = get_model(model_config=self.model_config,
-                               load_config=self.load_config,
-                               device_config=self.device_config,
-                               lora_config=self.lora_config,
-                               parallel_config=self.parallel_config,
-                               scheduler_config=self.scheduler_config,
-                               cache_config=self.cache_config)
+    def build(self) -> ModelInputForCPU:
+        multi_modal_kwargs = None
+        # NOTE: We assume that all sequences in the group are all prompts or
+        # all decodes.
+        is_prompt = self.seq_group_metadata_list[0].is_prompt
+        # Prepare input tensors.
+        if is_prompt:
+            (input_tokens, input_positions, attn_metadata, seq_lens,
+             multi_modal_kwargs) = self._prepare_prompt(
+                 self.seq_group_metadata_list)
+        else:
+            (input_tokens, input_positions,
+             attn_metadata) = self._prepare_decode(
+                 self.seq_group_metadata_list)
+            seq_lens = []
+
+        return self.model_input_cls(
+            input_tokens=input_tokens,
+            input_positions=input_positions,
+            attn_metadata=attn_metadata,
+            multi_modal_kwargs=multi_modal_kwargs,
+            # query_lens is not needed if chunked prefill is not
+            # supported. Since CPU worker doesn't support chunked prefill
+            # just use seq_lens instead.
+            seq_lens=seq_lens,
+            query_lens=seq_lens,
+        )
 
     def _prepare_prompt(
         self,
@@ -165,8 +176,7 @@ def _prepare_prompt(
             # is always the first token in the sequence.
             input_positions.extend(list(range(computed_len, seq_len)))
 
-            mm_data = seq_group_metadata.multi_modal_data
-            if mm_data:
+            if (mm_data := seq_group_metadata.multi_modal_data):
                 mm_kwargs = self.multi_modal_input_mapper(mm_data)
                 multi_modal_inputs_list.append(mm_kwargs)
 
@@ -302,56 +312,130 @@ def _prepare_decode(
             attn_metadata,
         )
 
+
+class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
+    _model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = (
+        ModelInputForCPUWithSamplingMetadata)
+    _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder
+
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        parallel_config: ParallelConfig,
+        scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
+        cache_config: CacheConfig,
+        load_config: LoadConfig,
+        lora_config: Optional[LoRAConfig],
+        kv_cache_dtype: Optional[str] = "auto",
+        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
+        is_driver_worker: bool = False,
+        *args,
+        **kwargs,
+    ):
+        self.model_config = model_config
+        self.parallel_config = parallel_config
+        self.scheduler_config = scheduler_config
+        # Currently, CPU worker doesn't support chunked prefill.
+        assert self.scheduler_config.chunked_prefill_enabled is False
+        self.device_config = device_config
+        self.cache_config = cache_config
+        self.lora_config = lora_config
+        self.prompt_adapter_config = prompt_adapter_config
+        self.load_config = load_config
+        self.is_driver_worker = is_driver_worker
+
+        self.device = self.device_config.device
+
+        self.kv_cache_dtype = kv_cache_dtype
+        self.sliding_window = model_config.get_sliding_window()
+        self.block_size = cache_config.block_size
+        self.attn_backend = get_attn_backend(
+            self.model_config.get_num_attention_heads(self.parallel_config),
+            self.model_config.get_head_size(),
+            self.model_config.get_num_kv_heads(self.parallel_config),
+            self.model_config.get_sliding_window(),
+            self.model_config.dtype,
+            self.kv_cache_dtype,
+            self.block_size,
+        )
+
+        # Multi-modal data support
+        self.mm_registry = MULTIMODAL_REGISTRY
+        self.multi_modal_input_mapper = self.mm_registry \
+            .create_input_mapper(self.model_config)
+        self.mm_registry.init_mm_limits_per_prompt(self.model_config)
+
+        # Lazy initialization.
+        self.model: nn.Module  # Set after init_Model
+
+        if self.model_config.is_encoder_decoder_model:
+            raise NotImplementedError(
+                STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU'])
+
+    def load_model(self) -> None:
+        self.model = get_model(model_config=self.model_config,
+                               load_config=self.load_config,
+                               device_config=self.device_config,
+                               lora_config=self.lora_config,
+                               parallel_config=self.parallel_config,
+                               scheduler_config=self.scheduler_config,
+                               cache_config=self.cache_config)
+
     def make_model_input_from_broadcasted_tensor_dict(
         self,
         tensor_dict: Dict[str, Any],
-    ) -> CPUModelInput:
-        return CPUModelInput.from_broadcasted_tensor_dict(
+    ) -> ModelInputForCPU:
+        return ModelInputForCPU.from_broadcasted_tensor_dict(
             tensor_dict,
             attn_backend=self.attn_backend,
         )
 
+    def _prepare_model_input_tensors(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        finished_requests_ids: Optional[List[str]] = None
+    ) -> ModelInputForCPUWithSamplingMetadata:
+        """Helper method to prepare the model input based on a given sequence
+        group. Prepares metadata needed for the base model forward pass but not
+        metadata for possible additional steps, e.g., sampling.
+
+        """
+        builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
+        for seq_group_metadata in seq_group_metadata_list:
+            builder.add_seq_group(seq_group_metadata)
+
+        return builder.build()  # type: ignore
+
     def prepare_model_input(
-            self,
-            seq_group_metadata_list: List[SequenceGroupMetadata],
-            virtual_engine: int = 0,
-            finished_requests_ids: Optional[List[str]] = None
-    ) -> CPUModelInput:
-        multi_modal_kwargs = None
-        # NOTE: We assume that all sequences in the group are all prompts or
-        # all decodes.
-        is_prompt = seq_group_metadata_list[0].is_prompt
-        # Prepare input tensors.
-        if is_prompt:
-            (input_tokens, input_positions, attn_metadata, seq_lens,
-             multi_modal_kwargs
-             ) = self._prepare_prompt(seq_group_metadata_list)
-        else:
-            (input_tokens, input_positions,
-             attn_metadata) = self._prepare_decode(seq_group_metadata_list)
-            seq_lens = []
-        sampling_metadata = SamplingMetadata.prepare(
-            seq_group_metadata_list,
-            seq_lens,
-            # query_lens is not needed if chunked prefill is not
-            # supported. Since CPU worker doesn't support chunked prefill
-            # just use seq_lens instead.
-            seq_lens,
-            self.device,
-            pin_memory=False,
-            generators=self.get_generators(finished_requests_ids))
-        return CPUModelInput(
-            input_tokens=input_tokens,
-            input_positions=input_positions,
-            attn_metadata=attn_metadata,
-            sampling_metadata=sampling_metadata,
-            multi_modal_kwargs=multi_modal_kwargs,
-        )
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        virtual_engine: int = 0,
+        finished_requests_ids: Optional[List[str]] = None
+    ) -> ModelInputForCPUWithSamplingMetadata:
+        """Prepare the model input based on a given sequence group, including
+        metadata for the sampling step.
+
+        """
+        model_input = self._prepare_model_input_tensors(
+            seq_group_metadata_list, finished_requests_ids)
+        # Sampling metadata is only required for the final pp group
+        generators = self.get_generators(finished_requests_ids)
+        sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
+                                                     model_input.seq_lens,
+                                                     model_input.query_lens,
+                                                     self.device,
+                                                     pin_memory=False,
+                                                     generators=generators)
+
+        return dataclasses.replace(model_input,
+                                   sampling_metadata=sampling_metadata,
+                                   virtual_engine=virtual_engine)
 
     @torch.no_grad()
     def execute_model(
         self,
-        model_input: CPUModelInput,
+        model_input: ModelInputForCPUWithSamplingMetadata,
         kv_caches: List[torch.Tensor],
         intermediate_tensors: Optional[IntermediateTensors] = None,
         num_steps: int = 1,
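
The hunks above replace the single CPUModelInput dataclass with a small hierarchy (ModelInputForCPU for the forward pass, ModelInputForCPUWithSamplingMetadata layered on top) and move input assembly into a ModelInputForCPUBuilder that CPUModelRunner drives from _prepare_model_input_tensors, attaching sampling metadata afterwards with dataclasses.replace. The following self-contained Python sketch illustrates that flow under simplified assumptions: the class names mirror the diff, but the toy fields, the dict-based sequence groups, and the placeholder sampling metadata are illustrative stand-ins, not vLLM's actual API.

import dataclasses
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass(frozen=True)
class ModelInputForCPU:
    # Metadata for the base model forward pass (mirrors the diff's base class).
    input_tokens: Optional[List[int]] = None
    seq_lens: Optional[List[int]] = None
    query_lens: Optional[List[int]] = None


@dataclass(frozen=True)
class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU):
    # Sampling metadata is layered on by the runner, not by the builder.
    sampling_metadata: Optional[Dict[str, Any]] = None


class ModelInputForCPUBuilder:
    """Accumulates sequence groups, then emits one immutable model input."""

    def __init__(self) -> None:
        self.seq_group_metadata_list: List[Dict[str, Any]] = []

    def add_seq_group(self, seq_group_metadata: Dict[str, Any]) -> None:
        self.seq_group_metadata_list.append(seq_group_metadata)

    def build(self) -> ModelInputForCPUWithSamplingMetadata:
        # Flatten the accumulated groups into flat token and length lists.
        tokens = [t for g in self.seq_group_metadata_list for t in g["tokens"]]
        seq_lens = [len(g["tokens"]) for g in self.seq_group_metadata_list]
        # Without chunked prefill, query_lens is simply seq_lens (as in the diff).
        return ModelInputForCPUWithSamplingMetadata(
            input_tokens=tokens, seq_lens=seq_lens, query_lens=seq_lens)


def prepare_model_input(
        seq_groups: List[Dict[str, Any]]) -> ModelInputForCPUWithSamplingMetadata:
    builder = ModelInputForCPUBuilder()
    for group in seq_groups:
        builder.add_seq_group(group)
    model_input = builder.build()
    # The runner attaches sampling metadata afterwards via dataclasses.replace,
    # keeping the model input frozen; the metadata here is only a placeholder.
    return dataclasses.replace(model_input,
                               sampling_metadata={"temperature": 1.0})


if __name__ == "__main__":
    out = prepare_model_input([{"tokens": [1, 2, 3]}, {"tokens": [4, 5]}])
    print(out.seq_lens, out.query_lens, out.sampling_metadata)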
