Skip to content

Commit d0cba9a

Browse files
maxdebayser and robertgshaw2-redhat
authored and committed
[BugFix] Fix frontend multiprocessing hang (vllm-project#7217)
Signed-off-by: Max de Bayser <[email protected]> Co-authored-by: Robert Shaw <[email protected]> Signed-off-by: Alvant <[email protected]>
1 parent 032d0f3 commit d0cba9a

File tree

3 files changed

+67
-5
lines changed

3 files changed

+67
-5
lines changed
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from typing import Any
2+
3+
import pytest
4+
5+
from vllm.engine.async_llm_engine import AsyncLLMEngine
6+
from vllm.entrypoints.openai.api_server import build_async_engine_client
7+
from vllm.entrypoints.openai.cli_args import make_arg_parser
8+
from vllm.utils import FlexibleArgumentParser
9+
10+
11+
def crashing_from_engine_args(
    cls,
    engine_args: Any = None,
    start_engine_loop: Any = None,
    usage_context: Any = None,
    stat_loggers: Any = None,
) -> "AsyncLLMEngine":
    """Stand-in for ``AsyncLLMEngine.from_engine_args`` that always fails.

    Mirrors the real classmethod's signature so it can be monkeypatched in
    place of the original, then raises immediately to simulate an engine
    that crashes during startup.
    """
    # The exact message is irrelevant; the caller only cares that startup
    # raised at all.
    raise Exception("foo")
19+
20+
21+
@pytest.mark.asyncio
async def test_mp_crash_detection(monkeypatch):
    """Check that a dead engine process is reported instead of hanging.

    ``AsyncLLMEngine.from_engine_args`` is patched to raise immediately, so
    the RPC server process dies on startup; building the async engine client
    must then fail with the readiness-probe RuntimeError rather than waiting
    forever.
    """
    with monkeypatch.context() as patcher:
        with pytest.raises(RuntimeError) as excinfo:
            patcher.setattr(AsyncLLMEngine, "from_engine_args",
                            crashing_from_engine_args)

            # Build the default CLI args for the OpenAI-compatible server.
            arg_parser = make_arg_parser(
                FlexibleArgumentParser(
                    description="vLLM's remote OpenAI server."))
            server_args = arg_parser.parse_args([])

            async with build_async_engine_client(server_args):
                pass

    message = str(excinfo.value)
    assert ("The server process died before responding to the readiness probe"
            in message)

vllm/entrypoints/openai/api_server.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,18 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
120120

121121
# Build RPCClient, which conforms to AsyncEngineClient Protocol.
122122
async_engine_client = AsyncEngineRPCClient(rpc_path)
123-
await async_engine_client.setup()
124123

125124
try:
125+
while True:
126+
try:
127+
await async_engine_client.setup()
128+
break
129+
except TimeoutError as e:
130+
if not rpc_server_process.is_alive():
131+
raise RuntimeError(
132+
"The server process died before "
133+
"responding to the readiness probe") from e
134+
126135
yield async_engine_client
127136
finally:
128137
# Ensure rpc server process was terminated

vllm/entrypoints/openai/rpc/client.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
from vllm.sampling_params import SamplingParams
1919
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
2020

21+
# Time to wait before checking if the server process is alive.
22+
SERVER_START_TIMEOUT_MS = 1000
23+
2124

2225
class AsyncEngineRPCClient:
2326

@@ -61,7 +64,16 @@ def socket(self):
6164
socket.connect(self.rpc_path)
6265
yield socket
6366
finally:
64-
socket.close()
67+
# linger == 0 means discard unsent messages
68+
# when the socket is closed. This is necessary
69+
# because otherwise self.context.destroy() will
70+
# wait for 30 seconds until unsent messages are
71+
# received, which is impossible if the server
72+
# crashed. In the absence of a server crash we
73+
# always expect a response before closing the
74+
# socket anyway.
75+
# Reference: http://api.zeromq.org/4-2:zmq-setsockopt#toc24
76+
socket.close(linger=0)
6577

6678
async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
6779
expected_type: Any,
@@ -85,14 +97,19 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
8597

8698
return data
8799

88-
async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE,
89-
error_message: str):
100+
async def _send_one_way_rpc_request(self,
101+
request: RPC_REQUEST_TYPE,
102+
error_message: str,
103+
timeout: Optional[int] = None):
90104
"""Send one-way RPC request to trigger an action."""
91105
with self.socket() as socket:
92106
# Ping RPC Server with request.
93107
await socket.send(cloudpickle.dumps(request))
94108

95109
# Await acknowledgement from RPCServer.
110+
if timeout is not None and await socket.poll(timeout=timeout) == 0:
111+
raise TimeoutError(f"server didn't reply within {timeout} ms")
112+
96113
response = cloudpickle.loads(await socket.recv())
97114

98115
if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR:
@@ -117,7 +134,8 @@ async def wait_for_server(self):
117134

118135
await self._send_one_way_rpc_request(
119136
request=RPCUtilityRequest.IS_SERVER_READY,
120-
error_message="Unable to start RPC Server.")
137+
error_message="Unable to start RPC Server.",
138+
timeout=SERVER_START_TIMEOUT_MS)
121139

122140
async def _get_model_config_rpc(self) -> ModelConfig:
123141
"""Get the ModelConfig object from the RPC Server"""

0 commit comments

Comments
 (0)