Commit 35105e7
[Core] Use spawn when cuda is already initialized

One condition that we know breaks the default multiprocessing start method of `fork` is when a user of vllm as a library initializes CUDA before running vllm. This change detects that case, emits a warning to the log, and forces the start method to `spawn`. Similar code exists elsewhere (for AMD and Intel) to force the use of `spawn` unconditionally on those accelerators.

We retain the default behavior when the `VLLM_WORKER_MULTIPROC_METHOD` env var is not set and CUDA is not initialized. This works fine in practice and avoids breaking code that uses vllm as a library without protecting its entrypoint under `if __name__ == "__main__"`, which `spawn` would require.

Signed-off-by: Russell Bryant <[email protected]>
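For context, a minimal sketch of the scenario this change guards against (the model name and tensor-parallel size below are illustrative placeholders, not taken from the commit):

```python
# Hypothetical library usage that would trip the fork/CUDA problem.
# Any GPU touch (e.g. allocating a tensor) initializes CUDA in the parent
# process, and forking workers from a CUDA-initialized process is unsafe.
import torch
from vllm import LLM

# This call initializes CUDA in the current process.
torch.zeros(1, device="cuda")

# With tensor_parallel_size > 1 the multiproc executor launches workers;
# after this change it detects the initialized CUDA context and switches
# VLLM_WORKER_MULTIPROC_METHOD to "spawn", logging a warning.
llm = LLM(model="facebook/opt-125m", tensor_parallel_size=2)
```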
1 parent a522c71 commit 35105e7

File tree: 2 files changed (+16, -2 lines)

2 files changed

+16
-2
lines changed

vllm/executor/multiproc_gpu_executor.py

Lines changed: 9 additions & 2 deletions

@@ -15,8 +15,8 @@
 from vllm.sequence import ExecuteModelRequest
 from vllm.triton_utils import maybe_set_triton_cache_manager
 from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless,
-                        get_distributed_init_method, get_open_port,
-                        get_vllm_instance_id, make_async,
+                        cuda_is_initialized, get_distributed_init_method,
+                        get_open_port, get_vllm_instance_id, make_async,
                         update_environment_variables)
 
 logger = init_logger(__name__)
@@ -122,6 +122,13 @@ def _check_executor_parameters(self):
             "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
         })
 
+        if (cuda_is_initialized()
+                and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"):
+            logger.warning("CUDA was previously initialized. We must use "
+                           "the `spawn` multiprocessing start method. Setting "
+                           "VLLM_WORKER_MULTIPROC_METHOD to 'spawn'.")
+            os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+
         cuda_device_count = cuda_device_count_stateless()
         # Use confusing message for more common TP-only case.
         assert tensor_parallel_size <= cuda_device_count, (
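To show why flipping the env var is sufficient, here is a minimal, self-contained sketch of env-var-driven start-method selection. It illustrates the general pattern only and is not the actual vllm worker-launch code; the `"fork"` default below is an assumption for the sketch:

```python
import multiprocessing as mp
import os

def _worker(rank: int) -> None:
    # Stand-in for the per-GPU worker loop.
    print(f"worker {rank} is running")

def launch_workers(world_size: int) -> None:
    # Pick the start method from the environment; the executor that reads
    # VLLM_WORKER_MULTIPROC_METHOD follows the same idea.
    method = os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", "fork")
    ctx = mp.get_context(method)
    procs = [ctx.Process(target=_worker, args=(rank,))
             for rank in range(world_size)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()

if __name__ == "__main__":
    launch_workers(world_size=2)
```

Because the env var is consulted before any worker process is created, setting it to "spawn" at check time is enough to change how every worker is started.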

vllm/utils.py

Lines changed: 7 additions & 0 deletions

@@ -1090,6 +1090,13 @@ def cuda_device_count_stateless() -> int:
     return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)
 
 
+def cuda_is_initialized() -> bool:
+    """Check if CUDA is initialized."""
+    if not torch.cuda._is_compiled():
+        return False
+    return torch.cuda.is_initialized()
+
+
 def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]:
     """Make an instance method that weakly references
     its associated instance and no-ops once that
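As a standalone usage sketch of the new helper (assuming current PyTorch behavior for `torch.cuda._is_compiled()` and `torch.cuda.is_initialized()`):

```python
import torch

def cuda_is_initialized() -> bool:
    """Check if CUDA is initialized."""
    # On CPU-only builds of PyTorch there is no CUDA runtime to query,
    # so report "not initialized" instead of raising.
    if not torch.cuda._is_compiled():
        return False
    return torch.cuda.is_initialized()

if __name__ == "__main__":
    print(cuda_is_initialized())        # False in a fresh process
    if torch.cuda.is_available():
        torch.zeros(1, device="cuda")   # first CUDA call creates the context
        print(cuda_is_initialized())    # True once a context exists
```

The `_is_compiled()` guard is what lets the check run safely on CPU-only installs, where calling into the CUDA APIs directly would fail.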
