Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 6121056

Browse files
authored
Handle cancellation in DatabasePool.runInteraction() (#12199)
To handle cancellation, we ensure that `after_callback`s and `exception_callback`s are always run, since the transaction will complete on another thread regardless of cancellation. We also wait until everything is done before releasing the `CancelledError`, so that logging contexts won't get used after they have been finished. Signed-off-by: Sean Quah <[email protected]>
1 parent fc9bd62 commit 6121056

File tree

3 files changed

+96
-24
lines changed

3 files changed

+96
-24
lines changed

changelog.d/12199.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Handle cancellation in `DatabasePool.runInteraction()`.

synapse/storage/database.py

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from typing_extensions import Literal
4242

4343
from twisted.enterprise import adbapi
44+
from twisted.internet import defer
4445

4546
from synapse.api.errors import StoreError
4647
from synapse.config.database import DatabaseConnectionConfig
@@ -55,6 +56,7 @@
5556
from synapse.storage.background_updates import BackgroundUpdater
5657
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
5758
from synapse.storage.types import Connection, Cursor
59+
from synapse.util.async_helpers import delay_cancellation
5860
from synapse.util.iterutils import batch_iter
5961

6062
if TYPE_CHECKING:
@@ -732,34 +734,45 @@ async def runInteraction(
732734
Returns:
733735
The result of func
734736
"""
735-
after_callbacks: List[_CallbackListEntry] = []
736-
exception_callbacks: List[_CallbackListEntry] = []
737737

738-
if not current_context():
739-
logger.warning("Starting db txn '%s' from sentinel context", desc)
738+
async def _runInteraction() -> R:
739+
after_callbacks: List[_CallbackListEntry] = []
740+
exception_callbacks: List[_CallbackListEntry] = []
740741

741-
try:
742-
with opentracing.start_active_span(f"db.{desc}"):
743-
result = await self.runWithConnection(
744-
self.new_transaction,
745-
desc,
746-
after_callbacks,
747-
exception_callbacks,
748-
func,
749-
*args,
750-
db_autocommit=db_autocommit,
751-
isolation_level=isolation_level,
752-
**kwargs,
753-
)
742+
if not current_context():
743+
logger.warning("Starting db txn '%s' from sentinel context", desc)
754744

755-
for after_callback, after_args, after_kwargs in after_callbacks:
756-
after_callback(*after_args, **after_kwargs)
757-
except Exception:
758-
for after_callback, after_args, after_kwargs in exception_callbacks:
759-
after_callback(*after_args, **after_kwargs)
760-
raise
745+
try:
746+
with opentracing.start_active_span(f"db.{desc}"):
747+
result = await self.runWithConnection(
748+
self.new_transaction,
749+
desc,
750+
after_callbacks,
751+
exception_callbacks,
752+
func,
753+
*args,
754+
db_autocommit=db_autocommit,
755+
isolation_level=isolation_level,
756+
**kwargs,
757+
)
761758

762-
return cast(R, result)
759+
for after_callback, after_args, after_kwargs in after_callbacks:
760+
after_callback(*after_args, **after_kwargs)
761+
762+
return cast(R, result)
763+
except Exception:
764+
for after_callback, after_args, after_kwargs in exception_callbacks:
765+
after_callback(*after_args, **after_kwargs)
766+
raise
767+
768+
# To handle cancellation, we ensure that `after_callback`s and
769+
# `exception_callback`s are always run, since the transaction will complete
770+
# on another thread regardless of cancellation.
771+
#
772+
# We also wait until everything above is done before releasing the
773+
# `CancelledError`, so that logging contexts won't get used after they have been
774+
# finished.
775+
return await delay_cancellation(defer.ensureDeferred(_runInteraction()))
763776

764777
async def runWithConnection(
765778
self,

tests/storage/test_database.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from typing import Callable, Tuple
1616
from unittest.mock import Mock, call
1717

18+
from twisted.internet import defer
19+
from twisted.internet.defer import CancelledError, Deferred
1820
from twisted.test.proto_helpers import MemoryReactor
1921

2022
from synapse.server import HomeServer
@@ -124,3 +126,59 @@ def test_successful_retry(self) -> None:
124126
)
125127
self.assertEqual(after_callback.call_count, 2) # no additional calls
126128
exception_callback.assert_not_called()
129+
130+
131+
class CancellationTestCase(unittest.HomeserverTestCase):
132+
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
133+
self.store = hs.get_datastores().main
134+
self.db_pool: DatabasePool = self.store.db_pool
135+
136+
def test_after_callback(self) -> None:
137+
"""Test that the after callback is called when a transaction succeeds."""
138+
d: "Deferred[None]"
139+
after_callback = Mock()
140+
exception_callback = Mock()
141+
142+
def _test_txn(txn: LoggingTransaction) -> None:
143+
txn.call_after(after_callback, 123, 456, extra=789)
144+
txn.call_on_exception(exception_callback, 987, 654, extra=321)
145+
d.cancel()
146+
147+
d = defer.ensureDeferred(
148+
self.db_pool.runInteraction("test_transaction", _test_txn)
149+
)
150+
self.get_failure(d, CancelledError)
151+
152+
after_callback.assert_called_once_with(123, 456, extra=789)
153+
exception_callback.assert_not_called()
154+
155+
def test_exception_callback(self) -> None:
156+
"""Test that the exception callback is called when a transaction fails."""
157+
d: "Deferred[None]"
158+
after_callback = Mock()
159+
exception_callback = Mock()
160+
161+
def _test_txn(txn: LoggingTransaction) -> None:
162+
txn.call_after(after_callback, 123, 456, extra=789)
163+
txn.call_on_exception(exception_callback, 987, 654, extra=321)
164+
d.cancel()
165+
# Simulate a retryable failure on every attempt.
166+
raise self.db_pool.engine.module.OperationalError()
167+
168+
d = defer.ensureDeferred(
169+
self.db_pool.runInteraction("test_transaction", _test_txn)
170+
)
171+
self.get_failure(d, CancelledError)
172+
173+
after_callback.assert_not_called()
174+
exception_callback.assert_has_calls(
175+
[
176+
call(987, 654, extra=321),
177+
call(987, 654, extra=321),
178+
call(987, 654, extra=321),
179+
call(987, 654, extra=321),
180+
call(987, 654, extra=321),
181+
call(987, 654, extra=321),
182+
]
183+
)
184+
self.assertEqual(exception_callback.call_count, 6) # no additional calls

0 commit comments

Comments
 (0)