Skip to content

Commit 18a3e90

Browse files
committed
Capture standard output when loading the predictor
This commit fixes a bug where we failed to flush the StreamRedirector when catching an exception during the loading of the predictor module. We now use the existing `_handle_setup_error` function to ensure that the streams are flushed. I've kept the naming of the context manager the same because this all happens as part of the model setup. Two regression tests have been added to reproduce the issue and verify that it has been fixed, in both normal and concurrent/async modes.
1 parent cca8874 commit 18a3e90

File tree

3 files changed

+45
-24
lines changed

3 files changed

+45
-24
lines changed

python/cog/server/worker.py

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -430,9 +430,11 @@ def run(self) -> None:
430430
)
431431

432432
with scope(Scope(record_metric=self.record_metric)), redirector:
433-
self._predictor = self._load_predictor()
433+
with self._handle_setup_error(redirector):
434+
wait_for_env()
435+
self._predictor = load_predictor_from_ref(self._predictor_ref)
434436

435-
# If _load_predictor hasn't returned a predictor instance then
437+
# If load_predictor_from_ref hasn't returned a predictor instance then
436438
# it has sent an error Done event and we're done here.
437439
if not self._predictor:
438440
return
@@ -483,27 +485,6 @@ def _current_tag(self) -> Optional[str]:
483485
return _get_current_scope()._tag
484486
return self._sync_tag
485487

486-
def _load_predictor(self) -> Optional[BasePredictor]:
487-
done = Done()
488-
wait_for_env()
489-
try:
490-
return load_predictor_from_ref(self._predictor_ref)
491-
except Exception as e: # pylint: disable=broad-exception-caught
492-
traceback.print_exc()
493-
done.error = True
494-
done.error_detail = str(e)
495-
self._events.send(Envelope(event=done))
496-
except BaseException as e:
497-
# For SystemExit and friends we attempt to add some useful context
498-
# to the logs, but reraise to ensure the process dies.
499-
traceback.print_exc()
500-
done.error = True
501-
done.error_detail = str(e)
502-
self._events.send(Envelope(event=done))
503-
raise
504-
505-
return None
506-
507488
def _validate_predictor(
508489
self,
509490
redirector: Union[StreamRedirector, SimpleStreamRedirector],
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import sys
2+
3+
sys.stdout.write("writing to stdout at import time\n")
4+
sys.stderr.write("writing to stderr at import time\n")
5+
6+
import missing_module
7+
8+
9+
class Predictor:
10+
def setup(self):
11+
pass
12+
13+
def predict(self):
14+
print("did predict")

python/tests/server/test_worker.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,7 @@ def test_output(worker, payloads, output_generator, data):
487487
SETUP_LOGS_FIXTURES,
488488
indirect=["worker"],
489489
)
490-
def test_setup_logging(worker, expected_stdout, expected_stderr):
490+
def test_setup_logging(worker: Worker, expected_stdout, expected_stderr):
491491
"""
492492
We should get the logs we expect from predictors that generate logs during
493493
setup.
@@ -499,6 +499,32 @@ def test_setup_logging(worker, expected_stdout, expected_stderr):
499499
assert result.stderr == expected_stderr
500500

501501

502+
@uses_worker_configs(
503+
[
504+
WorkerConfig("import_err", setup=False),
505+
WorkerConfig("import_err", setup=False, min_python=(3, 11), is_async=True),
506+
]
507+
)
508+
def test_predictor_load_error_logging(worker: Worker):
509+
"""
510+
This test ensures that we capture standard output that occurs when the predictor
511+
errors while it is being loaded, before setup or predict are even run.
512+
"""
513+
result = _process(worker, worker.setup, swallow_exceptions=True)
514+
515+
assert result.done.error
516+
assert result.done.error_detail == "No module named 'missing_module'"
517+
518+
assert result.stdout == "writing to stdout at import time\n"
519+
stderr_lines = result.stderr.splitlines(keepends=True)
520+
assert stderr_lines[0] == "writing to stderr at import time\n"
521+
522+
assert "python/tests/server/fixtures/import_err.py" in stderr_lines[-3]
523+
assert "line 6" in stderr_lines[-3]
524+
assert "import missing_module" in stderr_lines[-2]
525+
assert stderr_lines[-1] == "ModuleNotFoundError: No module named 'missing_module'\n"
526+
527+
502528
@pytest.mark.parametrize(
503529
"worker,expected_stdout,expected_stderr",
504530
PREDICT_LOGS_FIXTURES,

0 commit comments

Comments
 (0)