[Hardware][PPC64LE] Enable V1 for ppc64le and ARM #20554

Merged
15 changes: 9 additions & 6 deletions vllm/engine/arg_utils.py
@@ -36,6 +36,7 @@
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.platforms import CpuArchEnum, current_platform
from vllm.plugins import load_general_plugins
from vllm.reasoning import ReasoningParserManager
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
@@ -1059,7 +1060,6 @@ def create_engine_config(
If VLLM_USE_V1 is specified by the user but the VllmConfig
is incompatible, we raise an error.
"""
from vllm.platforms import current_platform
current_platform.pre_register_and_update()

device_config = DeviceConfig(
@@ -1086,9 +1086,16 @@
# Set default arguments for V0 or V1 Engine.
if use_v1:
self._set_default_args_v1(usage_context, model_config)
# Disable chunked prefill for POWER (ppc64le)/ARM CPUs in V1
if current_platform.is_cpu(
) and current_platform.get_cpu_architecture() in (
CpuArchEnum.POWERPC, CpuArchEnum.ARM):
logger.info(
"Chunked prefill is not supported for ARM and POWER CPUs; "
"disabling it for V1 backend.")
self.enable_chunked_prefill = False
Comment on lines +1089 to +1096
Collaborator

@Akashcodes732 Curious how the V1 scheduler supports running without chunked prefill for POWER (ppc64le)/ARM CPUs? The chunking logic in the scheduler cannot be turned off, right?
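For context, here is a minimal, hypothetical sketch of the behaviour this question is about (not vLLM's actual V1 scheduler code; the function and parameter names below are invented for illustration). With chunked prefill disabled, a prompt can only be scheduled when it fits the per-step token budget in full, so in practice the token budget has to be at least as large as the longest prompt.

```python
def schedulable_prompt_tokens(prompt_len: int, token_budget: int,
                              chunked_prefill_enabled: bool) -> int:
    """How many prompt tokens could be scheduled in one step (illustrative)."""
    if chunked_prefill_enabled:
        # With chunking, any prefix that fits the budget may be scheduled.
        return min(prompt_len, token_budget)
    # Without chunking, the prompt is all-or-nothing.
    return prompt_len if prompt_len <= token_budget else 0


# A 3000-token prompt against a 2048-token budget:
assert schedulable_prompt_tokens(3000, 2048, chunked_prefill_enabled=True) == 2048
assert schedulable_prompt_tokens(3000, 2048, chunked_prefill_enabled=False) == 0
```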

else:
self._set_default_args_v0(model_config)

assert self.enable_chunked_prefill is not None

if envs.VLLM_ATTENTION_BACKEND in [STR_DUAL_CHUNK_FLASH_ATTN_VAL]:
@@ -1205,7 +1212,6 @@ def create_engine_config(
if self.enable_chunked_prefill and self.pipeline_parallel_size > 1:
raise ValueError("Multi-Step Chunked-Prefill is not supported "
"for pipeline-parallel-size > 1")
from vllm.platforms import current_platform
if current_platform.is_cpu():
logger.warning("Multi-Step (--num-scheduler-steps > 1) is "
"currently not supported for CPUs and has been "
@@ -1354,7 +1360,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
# Skip this check if we are running on a non-GPU platform,
# or if the device capability is not available
# (e.g. in a Ray actor without GPUs).
from vllm.platforms import current_platform
if (current_platform.is_cuda()
and current_platform.get_device_capability()
and current_platform.get_device_capability().major < 8):
@@ -1615,7 +1620,6 @@ def _set_default_args_v1(self, usage_context: UsageContext,
# as the platform that vLLM is running on (e.g. the case of scaling
# vLLM with Ray) and has no GPUs. In this case we use the default
# values for non-H100/H200 GPUs.
from vllm.platforms import current_platform
try:
device_memory = current_platform.get_device_total_memory()
device_name = current_platform.get_device_name().lower()
@@ -1718,7 +1722,6 @@ def add_cli_args(parser: FlexibleArgumentParser,
parser.add_argument('--disable-log-requests',
action='store_true',
help='Disable logging requests.')
from vllm.platforms import current_platform
current_platform.pre_register_and_update(parser)
return parser

5 changes: 3 additions & 2 deletions vllm/platforms/cpu.py
@@ -264,5 +264,6 @@ def default_v1(cls, model_config) -> bool:
"""Returns whether the current platform can use v1 by default for the
supplied model configuration.
"""
return cls.supports_v1(
model_config) and cls.get_cpu_architecture() == CpuArchEnum.X86
arch = cls.get_cpu_architecture()
return (cls.supports_v1(model_config) and arch
in (CpuArchEnum.X86, CpuArchEnum.POWERPC, CpuArchEnum.ARM))
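As a rough illustration of how a platform hook like `default_v1` might be consulted (a simplified sketch, not the exact vLLM call site; the stub enum and helper names are invented): when the user does not force `VLLM_USE_V1`, the platform default decides whether the V1 engine is used, and ppc64le/ARM now fall on the V1 side.

```python
import os
from enum import Enum, auto


class CpuArch(Enum):  # stand-in for vllm's CpuArchEnum
    X86 = auto()
    POWERPC = auto()
    ARM = auto()
    S390X = auto()


def default_v1(arch: CpuArch, supports_v1: bool) -> bool:
    """Mirror of the new check: V1 is the default on x86, ppc64le and ARM CPUs."""
    return supports_v1 and arch in (CpuArch.X86, CpuArch.POWERPC, CpuArch.ARM)


def use_v1(arch: CpuArch, supports_v1: bool) -> bool:
    """Simplified decision: an explicit VLLM_USE_V1 wins, else the platform default."""
    forced = os.environ.get("VLLM_USE_V1")
    if forced is not None:
        return forced == "1"
    return default_v1(arch, supports_v1)


assert default_v1(CpuArch.POWERPC, supports_v1=True)    # ppc64le now defaults to V1
assert not default_v1(CpuArch.S390X, supports_v1=True)  # other arches keep the old default
```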
6 changes: 4 additions & 2 deletions vllm/v1/attention/backends/cpu_attn.py
@@ -77,7 +77,6 @@ def __init__(self, runner: CPUModelRunner, kv_cache_spec: AttentionSpec,
block_table: BlockTable) -> None:
self.runner = runner
self.block_table = block_table

# For reorder
self.reorder_prompt_req_index_list = np.empty(self.runner.max_num_reqs,
dtype=np.int64)
@@ -162,11 +161,14 @@ def build(self, common_prefix_len: int,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
slot_mapping=slot_mapping,
# needed to ensure correct inference when chunked_prefill is disabled
seq_lens=runner.seq_lens_cpu[:num_reqs].tolist(),
seq_lens_tensor=runner.
seq_lens_cpu[num_prompt_req:num_reqs], # decode
max_decode_seq_len=max_decode_seq_len, # decode
block_tables=block_table_tensor[num_prompt_req:num_reqs], # decode
chunked_prefill=True,
chunked_prefill=self.runner.scheduler_config.
chunked_prefill_enabled,
max_query_len=max_query_len,
max_kv_len=max_prefill_seq_len,
prefill_query_start_loc=runner.
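A rough sketch of why the builder now needs the full per-request `seq_lens` (hypothetical helper, not the actual metadata builder): when chunked prefill is enabled, a prefill step may cover only a slice of the prompt, so the query length and the sequence length differ; when it is disabled, each prompt is processed in one shot and the two coincide.

```python
from typing import Tuple


def prefill_lens(prompt_len: int, already_computed: int, chunk_size: int,
                 chunked_prefill_enabled: bool) -> Tuple[int, int]:
    """Return (query_len, seq_len) for one prefill step (illustrative only)."""
    if chunked_prefill_enabled:
        query_len = min(chunk_size, prompt_len - already_computed)
    else:
        # Whole prompt in a single step.
        query_len = prompt_len - already_computed
    seq_len = already_computed + query_len
    return query_len, seq_len


# Chunked: a 1000-token prompt processed 256 tokens at a time.
assert prefill_lens(1000, 512, 256, True) == (256, 768)
# Unchunked: the same prompt is handled in one step, so seq_len == prompt_len.
assert prefill_lens(1000, 0, 256, False) == (1000, 1000)
```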
64 changes: 61 additions & 3 deletions vllm/v1/worker/cpu_worker.py
@@ -11,7 +11,7 @@
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.model_executor.utils import set_random_seed
from vllm.platforms import current_platform
from vllm.platforms import CpuArchEnum, current_platform
from vllm.sequence import IntermediateTensors
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
@@ -43,8 +43,12 @@ def init_device(self):
omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND
self.local_omp_cpuid = "all"
if omp_cpuids == "auto":
self.local_omp_cpuid = self.get_cpus_id_binding_based_on_numa_nodes(
)
if current_platform.get_cpu_architecture() == CpuArchEnum.POWERPC:
self.local_omp_cpuid = (
self.get_cpus_id_binding_based_on_numa_nodes_ppc64le())
else:
self.local_omp_cpuid = (
self.get_cpus_id_binding_based_on_numa_nodes())
else:
self.local_omp_cpuid = omp_cpuids.split("|")[self.rank]

@@ -153,3 +157,57 @@ def get_cpus_id_binding_based_on_numa_nodes(self) -> str:
"fallback to no thread-binding. To get better performance,"
"please try to manually bind threads.")
return rank_to_cpus

def get_cpus_id_binding_based_on_numa_nodes_ppc64le(self) -> str:
"""
Power (ppc64le) specific: Selects a subset of threads per core for
each NUMA node. This is robust to SMT mode (SMT-8, SMT-4, etc.)
because the OS only exposes available threads. This maximizes
performance by avoiding oversubscription of logical CPUs on Power.
"""

def select_threads_per_power_core(node_cpu_ids):
return [cpu for cpu in node_cpu_ids if cpu % 8 < 4]

rank_to_cpus = self.local_omp_cpuid
world_size = self.vllm_config.parallel_config.world_size
libnuma_found = util.find_spec("numa") is not None
psutil_found = util.find_spec("psutil") is not None
if libnuma_found and psutil_found:
import psutil
from numa import info
cpus_allow_list = psutil.Process().cpu_affinity()
numa_size = info.get_num_configured_nodes()

node_to_cpus = []
for i in range(numa_size):
node_intersect = set(
info.node_to_cpus(i)).intersection(cpus_allow_list)
if bool(node_intersect):
node_to_cpus.append(sorted(list(node_intersect)))

if world_size > len(node_to_cpus):
logger.error(
"Auto thread-binding failed due to "
"world size: %d is larger than "
"allowed NUMA nodes number: %d."
"Please try to bind threads manually.", world_size,
len(node_to_cpus))
else:
node_cpus_this_rank = node_to_cpus[self.rank]
node_cpus_this_rank = select_threads_per_power_core(
node_cpus_this_rank)
cpu_count_per_numa = len(node_cpus_this_rank)
num_of_reserved_cpu = min(envs.VLLM_CPU_NUM_OF_RESERVED_CPU,
cpu_count_per_numa // 2)
end = cpu_count_per_numa - num_of_reserved_cpu
rank_to_cpus_list = node_cpus_this_rank[:end]
rank_to_cpus = ','.join(str(x) for x in rank_to_cpus_list)
logger.info("ppc64le thread-binding list: %s", rank_to_cpus)
else:
logger.warning(
"Auto thread-binding is not supported due to "
"the lack of package numa and psutil,"
"fallback to no thread-binding. To get better performance,"
"please try to manually bind threads.")
return rank_to_cpus
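For reference, a small worked example of the `cpu % 8 < 4` filter used above (standalone; the CPU id list is made up): on SMT-8 Power cores, logical CPUs come in groups of eight, and the filter keeps the first four threads of each core.

```python
def select_threads_per_power_core(node_cpu_ids):
    # Keep the first 4 SMT threads of every 8-thread Power core.
    return [cpu for cpu in node_cpu_ids if cpu % 8 < 4]


# One NUMA node exposing two SMT-8 cores (logical CPUs 0-15):
node_cpu_ids = list(range(16))
print(select_threads_per_power_core(node_cpu_ids))
# -> [0, 1, 2, 3, 8, 9, 10, 11]
```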
Comment on lines +161 to +213
Contributor

Severity: high

The function get_cpus_id_binding_based_on_numa_nodes_ppc64le largely duplicates the logic in get_cpus_id_binding_based_on_numa_nodes. Refactor the common logic into a single function and conditionally apply the architecture-specific logic to reduce code duplication and improve maintainability.
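One way the suggested refactor could look (a sketch only, assuming the two methods really differ just in the thread-selection step; the helper names below are illustrative, not vLLM code): keep a single NUMA-binding routine and inject an optional per-architecture CPU filter.

```python
from typing import Callable, List, Optional


def pick_cpus_for_rank(node_to_cpus: List[List[int]], rank: int, reserved: int,
                       arch_filter: Optional[Callable[[List[int]],
                                                      List[int]]] = None) -> str:
    """Shared core of the binding logic; arch-specific filtering is injected."""
    cpus = node_to_cpus[rank]
    if arch_filter is not None:
        cpus = arch_filter(cpus)
    num_reserved = min(reserved, len(cpus) // 2)
    usable = cpus[:len(cpus) - num_reserved]
    return ",".join(str(c) for c in usable)


def ppc64le_filter(cpus: List[int]) -> List[int]:
    # Same rule as select_threads_per_power_core above.
    return [c for c in cpus if c % 8 < 4]


# x86: no filter; ppc64le: drop the upper SMT threads of each core.
nodes = [list(range(16))]
print(pick_cpus_for_rank(nodes, rank=0, reserved=0))                              # all 16 CPUs
print(pick_cpus_for_rank(nodes, rank=0, reserved=0, arch_filter=ppc64le_filter))  # 0-3, 8-11
```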