|
1 | 1 | import asyncio
|
2 | 2 | import contextlib
|
3 |
| -import contextvars |
4 | 3 | import inspect
|
5 | 4 | import multiprocessing
|
6 | 5 | import os
|
|
58 | 57 | InvalidStateException,
|
59 | 58 | )
|
60 | 59 | from .helpers import SimpleStreamRedirector, StreamRedirector
|
61 |
| -from .scope import Scope, scope |
| 60 | +from .scope import Scope, current_scope, evolve_scope, scope |
62 | 61 |
|
63 | 62 | if PYDANTIC_V2:
|
64 | 63 | from .helpers import unwrap_pydantic_serialization_iterators
|
65 | 64 |
|
66 | 65 | _spawn = multiprocessing.get_context("spawn")
|
67 |
| -_tag_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar( |
68 |
| - "tag", default=None |
69 |
| -) |
70 | 66 |
|
71 | 67 | _PublicEventType = Union[Done, Log, PredictionOutput, PredictionOutputType]
|
72 | 68 |
|
@@ -407,7 +403,7 @@ def __init__(
|
407 | 403 | self._cancelable = False
|
408 | 404 | self._max_concurrency = max_concurrency
|
409 | 405 |
|
410 |
| - # for synchronous predictors only! async predictors use _tag_var instead |
| 406 | + # for synchronous predictors only! async predictors use current_scope()._tag instead |
411 | 407 | self._sync_tag: Optional[str] = None
|
412 | 408 | self._has_async_predictor = is_async
|
413 | 409 |
|
@@ -483,10 +479,8 @@ def record_metric(self, name: str, value: Union[float, int]) -> None:
|
483 | 479 |
|
484 | 480 | @property
|
485 | 481 | def _current_tag(self) -> Optional[str]:
|
486 |
| - # if _tag_var is set, use that (only applies within _apredict()) |
487 |
| - tag = _tag_var.get() |
488 |
| - if tag: |
489 |
| - return tag |
| 482 | + if self._has_async_predictor: |
| 483 | + return current_scope()._tag |
490 | 484 | return self._sync_tag
|
491 | 485 |
|
492 | 486 | def _load_predictor(self) -> Optional[BasePredictor]:
|
@@ -687,9 +681,7 @@ async def _apredict(
|
687 | 681 | predict: Callable[..., Any],
|
688 | 682 | redirector: SimpleStreamRedirector,
|
689 | 683 | ) -> None:
|
690 |
| - _tag_var.set(tag) |
691 |
| - |
692 |
| - with self._handle_predict_error(redirector, tag=tag): |
| 684 | + with evolve_scope(tag=tag), self._handle_predict_error(redirector, tag=tag): |
693 | 685 | future_result = predict(**payload)
|
694 | 686 |
|
695 | 687 | if future_result:
|
|
0 commit comments