Skip to content

Commit 9a94ca4

Browse files
authored
[Bugfix] fix OpenAI API server startup with --disable-frontend-multiprocessing (#8537)
1 parent cfba685 commit 9a94ca4

File tree

2 files changed

+63
-5
lines changed

2 files changed

+63
-5
lines changed

tests/entrypoints/openai/test_basic.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from http import HTTPStatus
2+
from typing import List
23

34
import openai
45
import pytest
@@ -12,8 +13,44 @@
1213
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
1314

1415

16+
@pytest.fixture(scope='module')
def server_args(request: pytest.FixtureRequest) -> List[str]:
    """Provide extra arguments to the server via indirect parametrization.

    Usage:

    >>> @pytest.mark.parametrize(
    >>>     "server_args",
    >>>     [
    >>>         ["--disable-frontend-multiprocessing"],
    >>>         [
    >>>             "--model=NousResearch/Hermes-3-Llama-3.1-70B",
    >>>             "--enable-auto-tool-choice",
    >>>         ],
    >>>     ],
    >>>     indirect=True,
    >>> )
    >>> def test_foo(server, client):
    >>>     ...

    This will run `test_foo` twice, against servers started with:
    - `--disable-frontend-multiprocessing`
    - `--model=NousResearch/Hermes-3-Llama-3.1-70B --enable-auto-tool-choice`.

    Returns:
        A new list of extra command-line arguments. The list is empty when
        the requesting test does not parametrize this fixture, or when the
        parameter is ``None``. A single string parameter is wrapped in a
        one-element list.
    """
    # `request.param` is only present when the fixture is parametrized, so
    # fall back to None instead of raising AttributeError.
    param = getattr(request, "param", None)
    if param is None:
        return []

    # A bare string must be wrapped; iterating it would split it into
    # single characters.
    if isinstance(param, str):
        return [param]

    # Copy into a list so the return value always matches the annotation
    # (e.g. for tuple params) and never aliases the parametrize data.
    return list(param)
50+
51+
1552
@pytest.fixture(scope="module")
16-
def server():
53+
def server(server_args):
1754
args = [
1855
# use half precision for speed and memory savings in CI environment
1956
"--dtype",
@@ -23,6 +60,7 @@ def server():
2360
"--enforce-eager",
2461
"--max-num-seqs",
2562
"128",
63+
*server_args,
2664
]
2765

2866
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -35,6 +73,15 @@ async def client(server):
3573
yield async_client
3674

3775

76+
@pytest.mark.parametrize(
77+
"server_args",
78+
[
79+
pytest.param([], id="default-frontend-multiprocessing"),
80+
pytest.param(["--disable-frontend-multiprocessing"],
81+
id="disable-frontend-multiprocessing")
82+
],
83+
indirect=True,
84+
)
3885
@pytest.mark.asyncio
3986
async def test_show_version(client: openai.AsyncOpenAI):
4087
base_url = str(client.base_url)[:-3].strip("/")
@@ -45,6 +92,15 @@ async def test_show_version(client: openai.AsyncOpenAI):
4592
assert response.json() == {"version": VLLM_VERSION}
4693

4794

95+
@pytest.mark.parametrize(
96+
"server_args",
97+
[
98+
pytest.param([], id="default-frontend-multiprocessing"),
99+
pytest.param(["--disable-frontend-multiprocessing"],
100+
id="disable-frontend-multiprocessing")
101+
],
102+
indirect=True,
103+
)
48104
@pytest.mark.asyncio
49105
async def test_check_health(client: openai.AsyncOpenAI):
50106
base_url = str(client.base_url)[:-3].strip("/")

vllm/entrypoints/openai/api_server.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -537,8 +537,11 @@ async def run_server(args, **uvicorn_kwargs) -> None:
537537
raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
538538
f"(chose from {{ {','.join(valide_tool_parses)} }})")
539539

540-
temp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
541-
temp_socket.bind(("", args.port))
540+
# workaround to make sure that we bind the port before the engine is set up.
541+
# This avoids race conditions with ray.
542+
# see https://github.com/vllm-project/vllm/issues/8204
543+
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
544+
sock.bind(("", args.port))
542545

543546
def signal_handler(*_) -> None:
544547
# Interrupt server on sigterm while initializing
@@ -552,8 +555,6 @@ def signal_handler(*_) -> None:
552555
model_config = await engine_client.get_model_config()
553556
init_app_state(engine_client, model_config, app.state, args)
554557

555-
temp_socket.close()
556-
557558
shutdown_task = await serve_http(
558559
app,
559560
host=args.host,
@@ -564,6 +565,7 @@ def signal_handler(*_) -> None:
564565
ssl_certfile=args.ssl_certfile,
565566
ssl_ca_certs=args.ssl_ca_certs,
566567
ssl_cert_reqs=args.ssl_cert_reqs,
568+
fd=sock.fileno(),
567569
**uvicorn_kwargs,
568570
)
569571

0 commit comments

Comments
 (0)