15 changes: 9 additions & 6 deletions vllm/engine/llm_engine.py
@@ -338,6 +338,13 @@ def __init__(
"vllm.llm_engine",
self.observability_config.otlp_traces_endpoint)

tokenizer_group = self.get_tokenizer_group()

# Ensure that the function doesn't contain a reference to self,
# to avoid engine GC issues
def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
return tokenizer_group.get_lora_tokenizer(sequence.lora_request)

# Create sequence output processor, e.g. for beam search or
# speculative decoding.
self.output_processor = (
Expand All @@ -346,10 +353,10 @@ def __init__(
self.detokenizer,
self.scheduler,
self.seq_counter,
self.get_tokenizer_for_seq,
get_tokenizer_for_seq,
stop_checker=StopChecker(
self.scheduler_config.max_model_len,
self.get_tokenizer_for_seq,
get_tokenizer_for_seq,
),
))

@@ -481,10 +488,6 @@ def get_tokenizer(
     ) -> AnyTokenizer:
         return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
 
-    def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
-        return self.get_tokenizer_group().get_lora_tokenizer(
-            sequence.lora_request)
-
     def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup:
         init_kwargs = dict(
             tokenizer_id=self.model_config.tokenizer,
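
Why this change matters: a bound method like `self.get_tokenizer_for_seq` holds a strong reference to its instance via `__self__`, so handing it to the long-lived output processor keeps the whole engine alive and can prevent it from ever being garbage-collected. The replacement closure captures only the tokenizer group. A minimal sketch of the difference, using a hypothetical `Engine` class rather than vLLM's:

```python
import gc
import weakref


class Engine:
    def method(self):
        return "ok"


def make_closure(engine: Engine):
    # Close over only the data we need, not the engine itself
    payload = "ok"

    def fn():
        return payload

    return fn


engine = Engine()
dead = weakref.ref(engine)
bound = engine.method       # holds a strong ref to `engine` via __self__
del engine
gc.collect()
print(dead() is None)       # False: the bound method keeps the engine alive

engine2 = Engine()
dead2 = weakref.ref(engine2)
fn = make_closure(engine2)  # closes over `payload` only
del engine2
gc.collect()
print(dead2() is None)      # True: the engine can now be collected
```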
49 changes: 28 additions & 21 deletions vllm/entrypoints/openai/api_server.py
@@ -308,28 +308,35 @@ async def run_server(args, llm_engine=None, **uvicorn_kwargs) -> None:
logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args)

server = await build_server(
args,
llm_engine,
**uvicorn_kwargs,
)

loop = asyncio.get_running_loop()

server_task = loop.create_task(server.serve())

def signal_handler() -> None:
# prevents the uvicorn signal handler to exit early
server_task.cancel()

loop.add_signal_handler(signal.SIGINT, signal_handler)
loop.add_signal_handler(signal.SIGTERM, signal_handler)

try:
await server_task
except asyncio.CancelledError:
print("Gracefully stopping http server")
await server.shutdown()
server = await build_server(
args,
llm_engine,
**uvicorn_kwargs,
)

loop = asyncio.get_running_loop()

server_task = loop.create_task(server.serve())

def signal_handler() -> None:
# prevents the uvicorn signal handler to exit early
server_task.cancel()

loop.add_signal_handler(signal.SIGINT, signal_handler)
loop.add_signal_handler(signal.SIGTERM, signal_handler)

try:
await server_task
except asyncio.CancelledError:
print("Gracefully stopping http server")
await server.shutdown()
finally:
# Clean up globals
for var in ("openai_serving_chat", "openai_serving_completion",
"openai_serving_embedding", "openai_serving_tokenization",
"engine_args", "engine"):
globals().pop(var, None)
Comment on lines +338 to +342

youkaichao (Member): does del work? globals().pop is too hacky.

Author (Member): @youkaichao we'd have to check and del each one individually, so it would be a lot more code. I want this to work whether or not each is defined, in case some error occurs after setting some and not others.

The way globals are used here is already hacky imo, and I think we'll clean it up later. I wanted to keep this change as simple as possible, as there are overlapping changes in #6883 which will be merged very soon.

youkaichao (Member): got it. This is a minor concern, and you can skip it if it is difficult to solve.

The most important part is still to make the CI pass 🙏



 if __name__ == "__main__":
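
On the `del` vs `globals().pop` question in the thread above: `del name` raises `NameError` when the name was never bound, while `globals().pop(name, None)` is a no-op, which is what lets the cleanup loop tolerate a partially initialized module. A standalone sketch (illustrative names only, not vLLM code):

```python
# Only some of the expected globals have been set, e.g. because
# startup failed partway through.
openai_serving_chat = object()

try:
    del engine  # was never defined at module level
except NameError as e:
    print("del raises:", e)

# pop with a default succeeds whether or not the name was ever set
for var in ("openai_serving_chat", "engine"):
    globals().pop(var, None)

print("openai_serving_chat" in globals())  # False
print("engine" in globals())               # False
```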
18 changes: 0 additions & 18 deletions vllm/executor/multiproc_gpu_executor.py
@@ -1,8 +1,5 @@
 import asyncio
 import os
-import signal
-import threading
-import weakref
 from functools import partial
 from typing import Any, List, Optional

@@ -118,23 +115,8 @@ def _init_executor(self) -> None:
             self.non_driver_workers.append(worker)
 
         self.worker_monitor = WorkerMonitor(self.workers, result_handler)
-        result_handler.start()
         self.worker_monitor.start()
 
-        # Set up signal handlers to shutdown the executor cleanly
-        # sometimes gc does not work well
-
-        # Use weakref to avoid holding a reference to self
-        ref = weakref.ref(self)
-
-        def shutdown(signum, frame):
-            if executor := ref():
-                executor.shutdown()
-
-        if threading.current_thread() is threading.main_thread():
-            signal.signal(signal.SIGINT, shutdown)
-            signal.signal(signal.SIGTERM, shutdown)
-
         self.driver_worker = self._create_worker(
             distributed_init_method=distributed_init_method)
         self._run_workers("init_device")
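
Context for the deletion above: CPython only permits installing signal handlers from the main thread, which is why the removed code guarded `signal.signal` with a main-thread check; the PR moves this cleanup responsibility to the excepthook added in the next file. A standalone sketch of the restriction (not vLLM code):

```python
import signal
import threading


def try_install():
    try:
        # Installing a handler is only permitted in the main thread
        signal.signal(signal.SIGTERM, lambda signum, frame: None)
        print("installed in", threading.current_thread().name)
    except ValueError as e:
        print("failed in", threading.current_thread().name, "->", e)


try_install()                          # main thread: succeeds

t = threading.Thread(target=try_install)
t.start()                              # worker thread: raises ValueError
t.join()
```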
47 changes: 36 additions & 11 deletions vllm/executor/multiproc_worker_utils.py
@@ -5,6 +5,7 @@
 import threading
 import traceback
 import uuid
+import weakref
 from dataclasses import dataclass
 from multiprocessing import Queue
 from multiprocessing.connection import wait
@@ -76,7 +77,7 @@ class ResultHandler(threading.Thread):
"""Handle results from all workers (in background thread)"""

def __init__(self) -> None:
super().__init__(daemon=True)
super().__init__(daemon=False)
self.result_queue = mp.Queue()
self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {}

@@ -100,27 +101,51 @@ class WorkerMonitor(threading.Thread):

     def __init__(self, workers: List['ProcessWorkerWrapper'],
                  result_handler: ResultHandler):
-        super().__init__(daemon=True)
+        super().__init__(daemon=False)
         self.workers = workers
         self.result_handler = result_handler
         self._close = False
 
+        # Set up a handler to ensure that the threads and worker
+        # processes are shut down in the case the interpreter exits due
+        # to an unhandled exception. GC does not appear to be reliable
+        # for this.
+        ref = weakref.ref(self)
+        old_handler = sys.excepthook
+
+        def handler(*args):
+            old_handler(*args)
+            if (monitor := ref()) is not None:
+                monitor.close()
+
+        sys.excepthook = handler
+
     def run(self) -> None:
+        # We are responsible for starting the result handler thread
+        self.result_handler.start()
+
         # Blocks until any worker exits
         dead_sentinels = wait([w.process.sentinel for w in self.workers])
         if not self._close:
             self._close = True
 
-            # Kill / cleanup all workers
-            for worker in self.workers:
-                process = worker.process
-                if process.sentinel in dead_sentinels:
-                    process.join(JOIN_TIMEOUT_S)
-                if process.exitcode is not None and process.exitcode != 0:
-                    logger.error("Worker %s pid %s died, exit code: %s",
-                                 process.name, process.pid, process.exitcode)
+            if not sys.is_finalizing():
+                # Kill / cleanup all workers
+                died_count = 0
+                for worker in self.workers:
+                    process = worker.process
+                    if process.sentinel in dead_sentinels:
+                        process.join(JOIN_TIMEOUT_S)
+                    if process.exitcode is not None and process.exitcode != 0:
+                        died_count += 1
+                        logger.error("Worker %s pid %s died, exit code: %s",
+                                     process.name, process.pid,
+                                     process.exitcode)
+                if died_count < len(self.workers):
+                    logger.info(
+                        "Killing remaining local vLLM worker processes")
 
-            # Cleanup any remaining workers
-            logger.info("Killing local vLLM worker processes")
             for worker in self.workers:
                 worker.kill_worker()
             # Must be done after worker task queues are all closed
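
The pattern added above — a weakref to the monitor captured by a replacement `sys.excepthook` — lets the hook trigger shutdown on an unhandled exception without itself keeping the monitor alive. A stripped-down sketch of the same idea, with a hypothetical `Monitor` class standing in for the vLLM implementation:

```python
import sys
import weakref


class Monitor:
    def close(self):
        print("cleaning up worker processes")


def install_cleanup_hook(monitor: Monitor) -> None:
    # Weakref so the hook doesn't keep the monitor alive
    ref = weakref.ref(monitor)
    old_hook = sys.excepthook

    def hook(*args):
        old_hook(*args)           # report the exception as usual
        if (m := ref()) is not None:
            m.close()             # then shut down, if still alive

    sys.excepthook = hook


monitor = Monitor()
install_cleanup_hook(monitor)
raise RuntimeError("boom")        # traceback prints, then cleanup runs
```

The `daemon=False` changes work together with this: non-daemon threads keep the interpreter alive until they finish, so the monitor gets a chance to reap worker processes instead of being killed abruptly at interpreter exit, while the excepthook and the `sys.is_finalizing()` guard cover the crash and late-shutdown paths.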