Skip to content

Fix hanging #126

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/art/local/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,18 @@ def __init__(self, *, in_process: bool = False, path: str | None = None) -> None
self._tokenizers: dict[str, "PreTrainedTokenizerBase"] = {}
self._wandb_runs: dict[str, Run] = {}

def __enter__(self):
    """Enter a ``with`` block; the object bound by ``as`` is this instance."""
    return self

def __exit__(self, *excinfo) -> None:
    """Leave a ``with`` block by calling ``close()``.

    ``excinfo`` (the exception type, value and traceback supplied by the
    ``with`` statement) is accepted but not inspected. Nothing is returned,
    so an exception raised inside the ``with`` body is not suppressed.
    """
    self.close()

def close(self) -> None:
    """Close every registered service that exposes a callable ``close``.

    Services without a ``close`` attribute, or whose ``close`` attribute is
    not callable, are skipped. Every remaining service is closed even if an
    earlier one fails, so one faulty service cannot leave the others open.

    Raises:
        Exception: the first exception raised by a service's ``close()``,
            re-raised after all services have been attempted.
    """
    first_error: Exception | None = None
    for service in self._services.values():
        close_method = getattr(service, "close", None)
        if not callable(close_method):
            continue
        try:
            close_method()
        except Exception as exc:
            # Keep going so the remaining services are still released;
            # only the first failure is reported to the caller.
            if first_error is None:
                first_error = exc
    if first_error is not None:
        raise first_error

async def register(
self,
model: Model,
Expand Down
67 changes: 58 additions & 9 deletions src/mp_actors/move.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from tblib import pickling_support
from typing import Any, AsyncGenerator, cast, TypeVar
import uuid
from concurrent.futures import ThreadPoolExecutor

from .traceback import streamline_tracebacks

Expand All @@ -20,6 +21,8 @@

T = TypeVar("T")

# Sentinel response ID used to signal shutdown: a Response carrying this ID
# is not a reply to any request; the response-handling loop exits when it
# reads one.
_SHUTDOWN_ID = "__shutdown__"

def move_to_child_process(
obj: T, log_file: str | None = None, process_name: str | None = None
Expand Down Expand Up @@ -72,15 +75,24 @@ def __init__(
args=(obj, self._requests, self._responses, log_file, process_name),
)
self._process.start()
# dedicated executor for queue.get calls
self._executor = ThreadPoolExecutor()
self._futures: dict[str, asyncio.Future] = {}
self._handle_responses_task = asyncio.create_task(self._handle_responses())

async def _handle_responses(self) -> None:
loop = asyncio.get_event_loop()
while True:
response: Response = await asyncio.get_event_loop().run_in_executor(
None, self._responses.get
response: Response = await loop.run_in_executor(
self._executor, self._responses.get
)
future = self._futures.pop(response.id)
# check for shutdown signal
if response.id == _SHUTDOWN_ID:
break
# normal processing
future = self._futures.pop(response.id, None)
if future is None:
continue
if response.exception:
future.set_exception(response.exception)
else:
Expand Down Expand Up @@ -136,18 +148,52 @@ async def async_method_wrapper(*args: Any, **kwargs: Any) -> Any:
# Return a regular function wrapper
@streamline_tracebacks()
def method_wrapper(*args: Any, **kwargs: Any) -> Any:
return asyncio.run(get_response(args, kwargs))
loop = asyncio.get_event_loop()
if loop.is_running():
fut = asyncio.run_coroutine_threadsafe(get_response(args, kwargs), loop)
return fut.result()
else:
return asyncio.run(get_response(args, kwargs))

return method_wrapper
else:
# For non-callable attributes, get them directly
return asyncio.run(get_response(tuple(), dict()))
return asyncio.run(get_response(tuple(), {}))

def close(self):
    """Stop the response loop, the child process and the executor.

    Steps, in order:

    1. Put a shutdown ``Response`` on the responses queue so the
       response-handling loop exits.
    2. Wait, best effort, for the response-handling task to finish.
    3. Terminate the child process, force-killing it if it is still alive
       after a one-second join.
    4. Shut down the executor. This always runs, even when step 2 or 3
       raises, so its worker thread is not leaked.
    """
    # Signal the response loop to exit. This stays outside the try/finally:
    # if the sentinel cannot be enqueued, the worker blocked in
    # ``self._responses.get`` would never return and
    # ``shutdown(wait=True)`` below would block on it.
    self._responses.put_nowait(Response(_SHUTDOWN_ID, None, None))
    try:
        # wait for the handler to finish
        if hasattr(self, "_handle_responses_task"):
            try:
                asyncio.get_event_loop().run_until_complete(
                    self._handle_responses_task
                )
            except Exception:
                # Deliberately best effort: for example,
                # run_until_complete() raises RuntimeError when the loop is
                # already running. Teardown must continue regardless.
                pass

        # terminate child process and force kill if needed
        if hasattr(self, "_process"):
            self._process.terminate()
            try:
                self._process.join(timeout=1)
            except Exception:
                # A failed join is not fatal; the liveness check below
                # decides whether a force kill is needed.
                pass
            if self._process.is_alive():
                # Python 3.7+: force kill
                try:
                    self._process.kill()
                except AttributeError:
                    # fallback for interpreters without Process.kill()
                    os.kill(self._process.pid, 9)
                self._process.join()
    finally:
        # shutdown executor cleanly, whatever happened above
        self._executor.shutdown(wait=True)

def __del__(self) -> None:
    """Release the response task, child process and queues on finalization."""
    # NOTE(review): every attribute below is read unconditionally; if
    # __init__ did not get far enough to set them, this presumably raises
    # AttributeError during finalization — confirm that is acceptable.
    # Cancel the background task that reads responses.
    self._handle_responses_task.cancel()
    # Ask the child process to terminate.
    self._process.terminate()
    # close and cancel queue feeder threads
    self._responses.close()
    self._responses.cancel_join_thread()
    self._requests.close()
    self._requests.cancel_join_thread()


def _target(
Expand Down Expand Up @@ -199,6 +245,9 @@ async def _handle_request(
else:
result = result_or_callable
response = Response(request.id, result, None)
except StopAsyncIteration:
generators.pop(request.id, None)
return
except Exception as e:
pickling_support.install(e)
response = Response(request.id, None, e)
Expand Down