Skip to content

Commit 1007849

Browse files
committed
Move _tag_var to Scope
This moves _tag_var to Scope, so that we have a single ContextVar rather than multiple.
1 parent d04f127 commit 1007849

File tree

2 files changed

+18
-15
lines changed

2 files changed

+18
-15
lines changed

python/cog/server/scope.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
import warnings
22
from contextlib import contextmanager
33
from contextvars import ContextVar
4-
from typing import Callable, Generator, Optional, Union
4+
from typing import Any, Callable, Generator, Optional, Union
55

6-
from attrs import frozen
6+
from attrs import evolve, frozen
77

88
from ..types import ExperimentalFeatureWarning
99

1010

1111
@frozen
1212
class Scope:
1313
record_metric: Callable[[str, Union[float, int]], None]
14+
_tag: Optional[str] = None
1415

1516

1617
_current_scope: ContextVar[Optional[Scope]] = ContextVar("scope", default=None)
@@ -37,6 +38,16 @@ def scope(sc: Scope) -> Generator[None, None, None]:
3738
_current_scope.reset(s)
3839

3940

41+
@contextmanager
42+
def evolve_scope(**kwargs: Any) -> Generator[None, None, None]:
43+
new_scope = evolve(current_scope(), **kwargs)
44+
s = _current_scope.set(new_scope)
45+
try:
46+
yield
47+
finally:
48+
_current_scope.reset(s)
49+
50+
4051
def emit_metric(name: str, value: Union[float, int]) -> None:
4152
"""
4253
DEPRECATED: This function will be removed in a future version of cog.

python/cog/server/worker.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import asyncio
22
import contextlib
3-
import contextvars
43
import inspect
54
import multiprocessing
65
import os
@@ -58,15 +57,12 @@
5857
InvalidStateException,
5958
)
6059
from .helpers import SimpleStreamRedirector, StreamRedirector
61-
from .scope import Scope, scope
60+
from .scope import Scope, current_scope, evolve_scope, scope
6261

6362
if PYDANTIC_V2:
6463
from .helpers import unwrap_pydantic_serialization_iterators
6564

6665
_spawn = multiprocessing.get_context("spawn")
67-
_tag_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
68-
"tag", default=None
69-
)
7066

7167
_PublicEventType = Union[Done, Log, PredictionOutput, PredictionOutputType]
7268

@@ -407,7 +403,7 @@ def __init__(
407403
self._cancelable = False
408404
self._max_concurrency = max_concurrency
409405

410-
# for synchronous predictors only! async predictors use _tag_var instead
406+
# for synchronous predictors only! async predictors use current_scope()._tag instead
411407
self._sync_tag: Optional[str] = None
412408
self._has_async_predictor = is_async
413409

@@ -483,10 +479,8 @@ def record_metric(self, name: str, value: Union[float, int]) -> None:
483479

484480
@property
485481
def _current_tag(self) -> Optional[str]:
486-
# if _tag_var is set, use that (only applies within _apredict())
487-
tag = _tag_var.get()
488-
if tag:
489-
return tag
482+
if self._has_async_predictor:
483+
return current_scope()._tag
490484
return self._sync_tag
491485

492486
def _load_predictor(self) -> Optional[BasePredictor]:
@@ -687,9 +681,7 @@ async def _apredict(
687681
predict: Callable[..., Any],
688682
redirector: SimpleStreamRedirector,
689683
) -> None:
690-
_tag_var.set(tag)
691-
692-
with self._handle_predict_error(redirector, tag=tag):
684+
with evolve_scope(tag=tag), self._handle_predict_error(redirector, tag=tag):
693685
future_result = predict(**payload)
694686

695687
if future_result:

0 commit comments

Comments
 (0)