Skip to content

Commit d153703

Browse files
russellb and youkaichao authored
[Core] Improve choice of Python multiprocessing method (#8823)
Signed-off-by: Russell Bryant <[email protected]> Co-authored-by: youkaichao <[email protected]>
1 parent cc27644 commit d153703

File tree

4 files changed

+52
-9
lines changed

4 files changed

+52
-9
lines changed

vllm/executor/multiproc_gpu_executor.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -15,8 +15,8 @@
1515
from vllm.sequence import ExecuteModelRequest
1616
from vllm.triton_utils import maybe_set_triton_cache_manager
1717
from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless,
18-
get_distributed_init_method, get_open_port,
19-
get_vllm_instance_id, make_async,
18+
cuda_is_initialized, get_distributed_init_method,
19+
get_open_port, get_vllm_instance_id, make_async,
2020
update_environment_variables)
2121

2222
logger = init_logger(__name__)
@@ -122,6 +122,13 @@ def _check_executor_parameters(self):
122122
"CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
123123
})
124124

125+
if (cuda_is_initialized()
126+
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"):
127+
logger.warning("CUDA was previously initialized. We must use "
128+
"the `spawn` multiprocessing start method. Setting "
129+
"VLLM_WORKER_MULTIPROC_METHOD to 'spawn'.")
130+
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
131+
125132
cuda_device_count = cuda_device_count_stateless()
126133
# Use confusing message for more common TP-only case.
127134
assert tensor_parallel_size <= cuda_device_count, (

vllm/executor/multiproc_worker_utils.py

Lines changed: 10 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -27,9 +27,6 @@
2727

2828
JOIN_TIMEOUT_S = 2
2929

30-
mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
31-
mp = multiprocessing.get_context(mp_method)
32-
3330

3431
@dataclass
3532
class Result(Generic[T]):
@@ -77,7 +74,7 @@ class ResultHandler(threading.Thread):
7774

7875
def __init__(self) -> None:
7976
super().__init__(daemon=True)
80-
self.result_queue = mp.Queue()
77+
self.result_queue = get_mp_context().Queue()
8178
self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {}
8279

8380
def run(self):
@@ -147,10 +144,11 @@ class ProcessWorkerWrapper:
147144

148145
def __init__(self, result_handler: ResultHandler,
149146
worker_factory: Callable[[], Any]) -> None:
150-
self._task_queue = mp.Queue()
147+
self.mp = get_mp_context()
148+
self._task_queue = self.mp.Queue()
151149
self.result_queue = result_handler.result_queue
152150
self.tasks = result_handler.tasks
153-
self.process: BaseProcess = mp.Process( # type: ignore[attr-defined]
151+
self.process: BaseProcess = self.mp.Process( # type: ignore[attr-defined]
154152
target=_run_worker_process,
155153
name="VllmWorkerProcess",
156154
kwargs=dict(
@@ -204,7 +202,7 @@ def _run_worker_process(
204202
"""Worker process event loop"""
205203

206204
# Add process-specific prefix to stdout and stderr
207-
process_name = mp.current_process().name
205+
process_name = get_mp_context().current_process().name
208206
pid = os.getpid()
209207
_add_prefix(sys.stdout, process_name, pid)
210208
_add_prefix(sys.stderr, process_name, pid)
@@ -269,3 +267,8 @@ def write_with_prefix(s: str):
269267

270268
file.start_new_line = True # type: ignore[attr-defined]
271269
file.write = write_with_prefix # type: ignore[method-assign]
270+
271+
272+
def get_mp_context():
273+
mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
274+
return multiprocessing.get_context(mp_method)

vllm/scripts.py

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,8 +12,11 @@
1212
from vllm.engine.arg_utils import EngineArgs
1313
from vllm.entrypoints.openai.api_server import run_server
1414
from vllm.entrypoints.openai.cli_args import make_arg_parser
15+
from vllm.logger import init_logger
1516
from vllm.utils import FlexibleArgumentParser
1617

18+
logger = init_logger(__name__)
19+
1720

1821
def register_signal_handlers():
1922

@@ -114,7 +117,30 @@ def _add_query_options(
114117
return parser
115118

116119

120+
def env_setup():
121+
# The safest multiprocessing method is `spawn`, as the default `fork` method
122+
# is not compatible with some accelerators. The default method will be
123+
# changing in future versions of Python, so we should use it explicitly when
124+
# possible.
125+
#
126+
# We only set it here in the CLI entrypoint, because changing to `spawn`
127+
# could break some existing code using vLLM as a library. `spawn` will cause
128+
# unexpected behavior if the code is not protected by
129+
# `if __name__ == "__main__":`.
130+
#
131+
# References:
132+
# - https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
133+
# - https://pytorch.org/docs/stable/notes/multiprocessing.html#cuda-in-multiprocessing
134+
# - https://pytorch.org/docs/stable/multiprocessing.html#sharing-cuda-tensors
135+
# - https://docs.habana.ai/en/latest/PyTorch/Getting_Started_with_PyTorch_and_Gaudi/Getting_Started_with_PyTorch.html?highlight=multiprocessing#torch-multiprocessing-for-dataloaders
136+
if "VLLM_WORKER_MULTIPROC_METHOD" not in os.environ:
137+
logger.debug("Setting VLLM_WORKER_MULTIPROC_METHOD to 'spawn'")
138+
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
139+
140+
117141
def main():
142+
env_setup()
143+
118144
parser = FlexibleArgumentParser(description="vLLM CLI")
119145
subparsers = parser.add_subparsers(required=True)
120146

vllm/utils.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1091,6 +1091,13 @@ def cuda_device_count_stateless() -> int:
10911091
return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)
10921092

10931093

1094+
def cuda_is_initialized() -> bool:
1095+
"""Check if CUDA is initialized."""
1096+
if not torch.cuda._is_compiled():
1097+
return False
1098+
return torch.cuda.is_initialized()
1099+
1100+
10941101
def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]:
10951102
"""Make an instance method that weakly references
10961103
its associated instance and no-ops once that

0 commit comments

Comments
 (0)