Skip to content

Commit 5eb31ff

Browse files
Support async predictors (✨ again ✨) (#2025)
* Revert "Revert "Support async predictors (#2010)" (#2022)" This reverts commit 8333a83. * Use async stream redirector in setup so that the sync stream redirector context is only entered once, as this is a known source of problems associated with stdout/stderr orphaning. * Do not assert that writes from C are captured during setup * Do not wrap empty data in Log events * Exclude invalid output paths from infrastructure errors (#2030) * Exclude invalid output paths from infrastructure errors Closes PLAT-380 * Fix words given pluralization Co-authored-by: F <[email protected]> Signed-off-by: Dan Buch <[email protected]> --------- Signed-off-by: Dan Buch <[email protected]> Co-authored-by: F <[email protected]> --------- Signed-off-by: Dan Buch <[email protected]> Co-authored-by: F <[email protected]>
1 parent f8e3461 commit 5eb31ff

File tree

9 files changed

+415
-49
lines changed

9 files changed

+415
-49
lines changed

python/cog/server/connection.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
import asyncio
2+
import multiprocessing
3+
from multiprocessing.connection import Connection
4+
from typing import Any, Optional
5+
6+
from typing_extensions import Buffer
7+
8+
# Multiprocessing context that uses the "spawn" start method. LockedConnection
# (defined later in this module) creates its lock from this context.
_spawn = multiprocessing.get_context("spawn")
9+
10+
11+
class AsyncConnection:
    """
    Asyncio-friendly wrapper around a ``multiprocessing`` ``Connection``.

    The connection's file descriptor is registered as a reader on the event
    loop. Whenever the loop reports the descriptor readable, an internal
    ``asyncio.Event`` is set, which lets the ``recv*``/``poll`` coroutines
    wait for input without blocking the loop. The send methods and the final
    read after the wait delegate directly to the wrapped connection and are
    ordinary synchronous calls.
    """

    def __init__(self, connection: Connection) -> None:
        """Wrap `connection` and start watching its descriptor for input."""
        self._connection = connection
        self._event = asyncio.Event()
        # Keep a handle on the loop so close() can deregister the reader from
        # the very same loop it was registered on.
        self._loop = asyncio.get_event_loop()
        self._loop.add_reader(self._connection.fileno(), self._event.set)

    def send(self, obj: Any) -> None:
        """Send a (picklable) object"""

        self._connection.send(obj)

    async def _wait_for_input(self) -> None:
        """Wait until there is an input available to be read"""

        # Re-check poll() after every wake-up: the event may have been set by
        # an earlier readiness notification whose data was already consumed.
        while not self._connection.poll():
            await self._event.wait()
            self._event.clear()

    async def recv(self) -> Any:
        """Receive a (picklable) object"""

        await self._wait_for_input()
        return self._connection.recv()

    def fileno(self) -> int:
        """File descriptor or handle of the connection"""
        return self._connection.fileno()

    def close(self) -> None:
        """
        Close the connection.

        The reader callback installed in ``__init__`` is removed from the
        event loop first; otherwise the loop would keep a stale registration
        for a descriptor that no longer exists (and whose number the OS may
        hand out again). Calling ``close()`` more than once is safe.
        """
        if not self._connection.closed and not self._loop.is_closed():
            self._loop.remove_reader(self._connection.fileno())
        self._connection.close()

    async def poll(self, timeout: float = 0.0) -> bool:
        """
        Whether there is an input available to be read.

        Returns True immediately when data is already pending; otherwise waits
        up to `timeout` seconds for input and returns False on timeout.
        """

        if self._connection.poll():
            return True

        try:
            await asyncio.wait_for(self._wait_for_input(), timeout=timeout)
        except asyncio.TimeoutError:
            return False
        return self._connection.poll()

    def send_bytes(
        self, buf: Buffer, offset: int = 0, size: Optional[int] = None
    ) -> None:
        """Send the bytes data from a bytes-like object"""

        self._connection.send_bytes(buf, offset, size)

    async def recv_bytes(self, maxlength: Optional[int] = None) -> bytes:
        """
        Receive bytes data as a bytes object.
        """

        await self._wait_for_input()
        return self._connection.recv_bytes(maxlength)

    async def recv_bytes_into(self, buf: Buffer, offset: int = 0) -> int:
        """
        Receive bytes data into a writeable bytes-like object.
        Return the number of bytes read.
        """

        await self._wait_for_input()
        return self._connection.recv_bytes_into(buf, offset)
79+
80+
81+
class LockedConnection:
    """
    Thin wrapper around a connection whose ``send`` calls are serialised by a
    lock created from the module's "spawn" multiprocessing context.
    ``recv`` is passed straight through without taking the lock.
    """

    def __init__(self, connection: Connection) -> None:
        self._lock = _spawn.Lock()
        self.connection = connection

    def send(self, obj: Any) -> None:
        """Send `obj` over the wrapped connection while holding the lock."""
        self._lock.acquire()
        try:
            self.connection.send(obj)
        finally:
            self._lock.release()

    def recv(self) -> Any:
        """Receive and return the next object from the wrapped connection."""
        received = self.connection.recv()
        return received

python/cog/server/eventtypes.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@
55

66
# From worker parent process
77
#
8+
@define
class Cancel:
    """
    Payload-free event type, listed among the events that come from the
    worker parent process. Presumably it asks the worker to cancel the
    current prediction; confirm against the code that consumes it.
    """

    # TODO: identify which prediction!
12+
13+
814
@define
915
class PredictionInput:
1016
payload: Dict[str, Any]

python/cog/server/helpers.py

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import threading
1111
import uuid
1212
from types import TracebackType
13-
from typing import Any, Callable, Dict, List, Sequence, TextIO, Union
13+
from typing import Any, BinaryIO, Callable, Dict, List, Sequence, TextIO, Union
1414

1515
import pydantic
1616
from typing_extensions import Self
@@ -19,6 +19,45 @@
1919
from .errors import CogRuntimeError, CogTimeoutError
2020

2121

22+
class _SimpleStreamWrapper(io.TextIOWrapper):
    """
    _SimpleStreamWrapper wraps a binary I/O buffer and provides a TextIOWrapper
    interface (primarily write and flush methods) which call a provided
    callback function instead of (or, if `tee` is True, in addition to) writing
    to the underlying buffer.

    Written text is accumulated and handed to the callback as
    ``callback(stream_name, text)`` whenever a write contains a newline or
    carriage return, or when ``flush()`` is called explicitly. The callback is
    never invoked with empty text.
    """

    def __init__(
        self,
        buffer: BinaryIO,
        callback: Callable[[str, str], None],
        tee: bool = False,
    ) -> None:
        super().__init__(buffer, line_buffering=True)

        self._callback = callback
        self._tee = tee
        # Text written since the callback was last invoked.
        self._buffer: List[str] = []

    def write(self, s: str) -> int:
        """Buffer `s` (teeing it to the wrapped buffer if enabled) and return its length."""
        length = len(s)
        self._buffer.append(s)
        if self._tee:
            super().write(s)
        # Line-buffered flushing has to be triggered here in both modes. The
        # C implementation of TextIOWrapper.write flushes the underlying
        # binary buffer directly and does not dispatch to our overridden
        # flush(), so relying on it would leave the callback uncalled when
        # teeing.
        if "\n" in s or "\r" in s:
            self.flush()
        return length

    def flush(self) -> None:
        """Deliver any pending text to the callback, then flush the tee target."""
        pending = "".join(self._buffer)
        self._buffer.clear()
        # Skip the callback when nothing was written, so explicit flushes
        # (e.g. a drain) do not generate empty notifications.
        if pending:
            self._callback(self.name, pending)
        if self._tee:
            super().flush()
59+
60+
2261
class _StreamWrapper:
2362
def __init__(self, name: str, stream: TextIO) -> None:
2463
self.name = name
@@ -86,6 +125,66 @@ def original(self) -> TextIO:
86125
return self._original_fp
87126

88127

128+
# Pick the base class for AsyncStreamRedirector according to the running
# interpreter: `contextlib.AbstractContextManager` only supports subscripting
# from Python 3.9 onwards, so older versions use the unparametrised class.
if sys.version_info < (3, 9):

    class _AsyncStreamRedirectorBase(contextlib.AbstractContextManager):
        pass
else:

    # The string forward reference names the concrete subclass defined below,
    # i.e. the type produced by entering the context manager.
    class _AsyncStreamRedirectorBase(
        contextlib.AbstractContextManager["AsyncStreamRedirector"]
    ):
        pass
138+
139+
140+
class AsyncStreamRedirector(_AsyncStreamRedirectorBase):
    """
    Context manager that routes STDOUT and STDERR writes to a callback
    function, optionally (when `tee` is True) also writing them through to the
    original streams.

    In contrast to StreamRedirector, the file descriptors underneath the
    streams are left untouched. Consequently only writes issued from Python
    code are captured; output produced by native code is not.

    Also in contrast to StreamRedirector, the set of redirected streams is
    fixed: only STDOUT and STDERR can be redirected.
    """

    def __init__(
        self,
        callback: Callable[[str, str], None],
        tee: bool = False,
    ) -> None:
        self._callback = callback
        self._tee = tee

        # Each standard stream gets its own wrapper over the stream's binary
        # buffer; the redirect contexts install those wrappers on `sys`.
        self._stdout_ctx = contextlib.redirect_stdout(
            _SimpleStreamWrapper(sys.stdout.buffer, callback, tee)
        )
        self._stderr_ctx = contextlib.redirect_stderr(
            _SimpleStreamWrapper(sys.stderr.buffer, callback, tee)
        )

    def __enter__(self) -> Self:
        for redirect in (self._stdout_ctx, self._stderr_ctx):
            redirect.__enter__()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        for redirect in (self._stdout_ctx, self._stderr_ctx):
            redirect.__exit__(exc_type, exc_value, traceback)

    def drain(self, timeout: float = 0.0) -> None:
        # No data is moved between threads here, so draining amounts to
        # flushing whatever the current stdout/stderr objects have pending.
        # `timeout` is accepted but not used.
        for stream in (sys.stdout, sys.stderr):
            stream.flush()
186+
187+
89188
if sys.version_info < (3, 9):
90189

91190
class _StreamRedirectorBase(contextlib.AbstractContextManager):

python/cog/server/runner.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -408,10 +408,17 @@ def _upload_files(self, output: Any) -> Any:
408408
try:
409409
# TODO: clean up output files
410410
return self._file_uploader(output)
411+
except (FileNotFoundError, NotADirectoryError):
412+
# These error cases indicate that an output path returned by a prediction does
413+
# not actually exist, so there is no way for us to even attempt to upload it.
414+
# The error is re-raised without wrapping because this is not considered an
415+
# "infrastructure error", such as happens during an upload of a file that
416+
# **does** exist.
417+
raise
411418
except Exception as error: # pylint: disable=broad-exception-caught
412-
# If something goes wrong uploading a file, it's irrecoverable.
413-
# The re-raised exception will be caught and cause the prediction
414-
# to be failed, with a useful error message.
419+
# Any other errors that occur during file upload are irrecoverable and
420+
# considered "infrastructure errors" because there is a high likelihood that
421+
# the error happened in a layer that is outside the control of the model.
415422
raise FileUploadError("Got error trying to upload output files") from error
416423

417424
def _handle_done(self, f: "Future[Done]") -> None:

0 commit comments

Comments
 (0)