Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 24 additions & 20 deletions python/cog/server/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,19 @@ def _wait(
)


class LockedConn:
def __init__(self, conn: Connection) -> None:
self.conn = conn
self._lock = _spawn.Lock()

def send(self, obj: Any) -> None:
with self._lock:
self.conn.send(obj)

def recv(self) -> Any:
return self.conn.recv()


class _ChildWorker(_spawn.Process): # type: ignore
def __init__(
self,
Expand All @@ -150,10 +163,9 @@ def __init__(
) -> None:
self._predictor_ref = predictor_ref
self._predictor: Optional[BasePredictor] = None
self._events = events
self._events = LockedConn(events)
self._tee_output = tee_output
self._cancelable = False
self._events_lock = _spawn.Lock()

super().__init__()

Expand Down Expand Up @@ -201,8 +213,7 @@ def _setup(self) -> None:
raise
finally:
self._stream_redirector.drain()
with self._events_lock:
self._events.send(done)
self._events.send(done)

def _loop(self) -> None:
while True:
Expand All @@ -223,18 +234,13 @@ def _predict(self, payload: Dict[str, Any]) -> None:
result = predict(**payload)

if result:
with self._events_lock:
if isinstance(result, types.GeneratorType):
self._events.send(PredictionOutputType(multi=True))
for r in result:
self._events.send(
PredictionOutput(payload=make_encodeable(r))
)
else:
self._events.send(PredictionOutputType(multi=False))
self._events.send(
PredictionOutput(payload=make_encodeable(result))
)
if isinstance(result, types.GeneratorType):
self._events.send(PredictionOutputType(multi=True))
for r in result:
self._events.send(PredictionOutput(payload=make_encodeable(r)))
else:
self._events.send(PredictionOutputType(multi=False))
self._events.send(PredictionOutput(payload=make_encodeable(result)))
except CancelationException:
done.canceled = True
except Exception as e:
Expand All @@ -244,8 +250,7 @@ def _predict(self, payload: Dict[str, Any]) -> None:
finally:
self._cancelable = False
self._stream_redirector.drain()
with self._events_lock:
self._events.send(done)
self._events.send(done)

def _signal_handler(self, signum: int, frame: Optional[types.FrameType]) -> None:
if signum == signal.SIGUSR1 and self._cancelable:
Expand All @@ -257,5 +262,4 @@ def _stream_write_hook(
if self._tee_output:
original_stream.write(data)
original_stream.flush()
with self._events_lock:
self._events.send(Log(data, source=stream_name))
self._events.send(Log(data, source=stream_name))