Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit c2fb882

Browse files
author
Sean Quah
committed
Update delay_cancellation to accept any awaitable
This will mainly be useful when dealing with module callbacks, which are all typed as returning `Awaitable`s instead of coroutines or `Deferred`s. Signed-off-by: Sean Quah <[email protected]>
1 parent e3a49f4 commit c2fb882

File tree

4 files changed

+71
-13
lines changed

4 files changed

+71
-13
lines changed

changelog.d/12468.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Update `delay_cancellation` to accept any awaitable, rather than just `Deferred`s.

synapse/storage/database.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -794,7 +794,7 @@ async def _runInteraction() -> R:
794794
# We also wait until everything above is done before releasing the
795795
# `CancelledError`, so that logging contexts won't get used after they have been
796796
# finished.
797-
return await delay_cancellation(defer.ensureDeferred(_runInteraction()))
797+
return await delay_cancellation(_runInteraction())
798798

799799
async def runWithConnection(
800800
self,

synapse/util/async_helpers.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
Awaitable,
2626
Callable,
2727
Collection,
28+
Coroutine,
2829
Dict,
2930
Generic,
3031
Hashable,
@@ -701,27 +702,54 @@ def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
701702
return new_deferred
702703

703704

704-
def delay_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
705-
"""Delay cancellation of a `Deferred` until it resolves.
705+
@overload
706+
def delay_cancellation(awaitable: "defer.Deferred[T]") -> "defer.Deferred[T]":
707+
...
708+
709+
710+
@overload
711+
def delay_cancellation(awaitable: Coroutine[Any, Any, T]) -> "defer.Deferred[T]":
712+
...
713+
714+
715+
@overload
716+
def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]:
717+
...
718+
719+
720+
def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]:
721+
"""Delay cancellation of a coroutine or `Deferred` awaitable until it resolves.
706722
707723
Has the same effect as `stop_cancellation`, but the returned `Deferred` will not
708-
resolve with a `CancelledError` until the original `Deferred` resolves.
724+
resolve with a `CancelledError` until the original awaitable resolves.
709725
710726
Args:
711-
deferred: The `Deferred` to protect against cancellation. May optionally follow
712-
the Synapse logcontext rules.
727+
deferred: The coroutine or `Deferred` to protect against cancellation. May
728+
optionally follow the Synapse logcontext rules.
713729
714730
Returns:
715-
A new `Deferred`, which will contain the result of the original `Deferred`.
716-
The new `Deferred` will not propagate cancellation through to the original.
717-
When cancelled, the new `Deferred` will wait until the original `Deferred`
718-
resolves before failing with a `CancelledError`.
731+
A new `Deferred`, which will contain the result of the original coroutine or
732+
`Deferred`. The new `Deferred` will not propagate cancellation through to the
733+
original coroutine or `Deferred`.
719734
720-
The new `Deferred` will follow the Synapse logcontext rules if `deferred`
735+
When cancelled, the new `Deferred` will wait until the original coroutine or
736+
`Deferred` resolves before failing with a `CancelledError`.
737+
738+
The new `Deferred` will follow the Synapse logcontext rules if `awaitable`
721739
follows the Synapse logcontext rules. Otherwise the new `Deferred` should be
722740
wrapped with `make_deferred_yieldable`.
723741
"""
724742

743+
# First, convert the awaitable into a `Deferred`.
744+
if isinstance(awaitable, defer.Deferred):
745+
deferred = awaitable
746+
elif isinstance(awaitable, Coroutine):
747+
deferred = defer.ensureDeferred(awaitable)
748+
else:
749+
# We have no idea what to do with this awaitable.
750+
# Let the caller `await` it normally.
751+
return awaitable
752+
725753
def handle_cancel(new_deferred: "defer.Deferred[T]") -> None:
726754
# before the new deferred is cancelled, we `pause` it to stop the cancellation
727755
# propagating. we then `unpause` it once the wrapped deferred completes, to

tests/util/test_async_helpers.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ def test_cancellation(self):
382382
class DelayCancellationTests(TestCase):
383383
"""Tests for the `delay_cancellation` function."""
384384

385-
def test_cancellation(self):
385+
def test_deferred_cancellation(self):
386386
"""Test that cancellation of the new `Deferred` waits for the original."""
387387
deferred: "Deferred[str]" = Deferred()
388388
wrapper_deferred = delay_cancellation(deferred)
@@ -403,6 +403,35 @@ def test_cancellation(self):
403403
# Now that the original `Deferred` has failed, we should get a `CancelledError`.
404404
self.failureResultOf(wrapper_deferred, CancelledError)
405405

406+
def test_coroutine_cancellation(self):
407+
"""Test that cancellation of the new `Deferred` waits for the original."""
408+
blocking_deferred: "Deferred[None]" = Deferred()
409+
completion_deferred: "Deferred[None]" = Deferred()
410+
411+
async def task():
412+
await blocking_deferred
413+
completion_deferred.callback(None)
414+
# Raise an exception. Twisted should consume it, otherwise unwanted
415+
# tracebacks will be printed in logs.
416+
raise ValueError("abc")
417+
418+
wrapper_deferred = delay_cancellation(task())
419+
420+
# Cancel the new `Deferred`.
421+
wrapper_deferred.cancel()
422+
self.assertNoResult(wrapper_deferred)
423+
self.assertFalse(
424+
blocking_deferred.called, "Cancellation was propagated too deep"
425+
)
426+
self.assertFalse(completion_deferred.called)
427+
428+
# Unblock the task.
429+
blocking_deferred.callback(None)
430+
self.assertTrue(completion_deferred.called)
431+
432+
# Now that the original coroutine has failed, we should get a `CancelledError`.
433+
self.failureResultOf(wrapper_deferred, CancelledError)
434+
406435
def test_suppresses_second_cancellation(self):
407436
"""Test that a second cancellation is suppressed.
408437
@@ -451,7 +480,7 @@ async def inner():
451480
async def outer():
452481
with LoggingContext("c") as c:
453482
try:
454-
await delay_cancellation(defer.ensureDeferred(inner()))
483+
await delay_cancellation(inner())
455484
self.fail("`CancelledError` was not raised")
456485
except CancelledError:
457486
self.assertEqual(c, current_context())

0 commit comments

Comments
 (0)