31
31
List ,
32
32
Optional ,
33
33
Tuple ,
34
+ Type ,
34
35
TypeVar ,
35
36
cast ,
36
37
overload ,
41
42
from typing_extensions import Concatenate , Literal , ParamSpec
42
43
43
44
from twisted .enterprise import adbapi
45
+ from twisted .internet .interfaces import IReactorCore
44
46
45
47
from synapse .api .errors import StoreError
46
48
from synapse .config .database import DatabaseConnectionConfig
92
94
93
95
94
96
def make_pool (
95
- reactor , db_config : DatabaseConnectionConfig , engine : BaseDatabaseEngine
97
+ reactor : IReactorCore ,
98
+ db_config : DatabaseConnectionConfig ,
99
+ engine : BaseDatabaseEngine ,
96
100
) -> adbapi .ConnectionPool :
97
101
"""Get the connection pool for the database."""
98
102
@@ -101,7 +105,7 @@ def make_pool(
101
105
db_args = dict (db_config .config .get ("args" , {}))
102
106
db_args .setdefault ("cp_reconnect" , True )
103
107
104
- def _on_new_connection (conn ) :
108
+ def _on_new_connection (conn : Connection ) -> None :
105
109
# Ensure we have a logging context so we can correctly track queries,
106
110
# etc.
107
111
with LoggingContext ("db.on_new_connection" ):
@@ -157,7 +161,11 @@ class LoggingDatabaseConnection:
157
161
default_txn_name : str
158
162
159
163
def cursor (
160
- self , * , txn_name = None , after_callbacks = None , exception_callbacks = None
164
+ self ,
165
+ * ,
166
+ txn_name : Optional [str ] = None ,
167
+ after_callbacks : Optional [List ["_CallbackListEntry" ]] = None ,
168
+ exception_callbacks : Optional [List ["_CallbackListEntry" ]] = None ,
161
169
) -> "LoggingTransaction" :
162
170
if not txn_name :
163
171
txn_name = self .default_txn_name
@@ -183,11 +191,16 @@ def __enter__(self) -> "LoggingDatabaseConnection":
183
191
self .conn .__enter__ ()
184
192
return self
185
193
186
- def __exit__ (self , exc_type , exc_value , traceback ) -> Optional [bool ]:
194
+ def __exit__ (
195
+ self ,
196
+ exc_type : Optional [Type [BaseException ]],
197
+ exc_value : Optional [BaseException ],
198
+ traceback : Optional [types .TracebackType ],
199
+ ) -> Optional [bool ]:
187
200
return self .conn .__exit__ (exc_type , exc_value , traceback )
188
201
189
202
# Proxy through any unknown lookups to the DB conn class.
190
- def __getattr__ (self , name ) :
203
+ def __getattr__ (self , name : str ) -> Any :
191
204
return getattr (self .conn , name )
192
205
193
206
@@ -391,17 +404,22 @@ def close(self) -> None:
391
404
def __enter__ (self ) -> "LoggingTransaction" :
392
405
return self
393
406
394
- def __exit__ (self , exc_type , exc_value , traceback ):
407
+ def __exit__ (
408
+ self ,
409
+ exc_type : Optional [Type [BaseException ]],
410
+ exc_value : Optional [BaseException ],
411
+ traceback : Optional [types .TracebackType ],
412
+ ) -> None :
395
413
self .close ()
396
414
397
415
398
416
class PerformanceCounters :
399
- def __init__ (self ):
400
- self .current_counters = {}
401
- self .previous_counters = {}
417
+ def __init__ (self ) -> None :
418
+ self .current_counters : Dict [ str , Tuple [ int , float ]] = {}
419
+ self .previous_counters : Dict [ str , Tuple [ int , float ]] = {}
402
420
403
421
def update (self , key : str , duration_secs : float ) -> None :
404
- count , cum_time = self .current_counters .get (key , (0 , 0 ))
422
+ count , cum_time = self .current_counters .get (key , (0 , 0.0 ))
405
423
count += 1
406
424
cum_time += duration_secs
407
425
self .current_counters [key ] = (count , cum_time )
@@ -527,7 +545,7 @@ async def _check_safe_to_upsert(self) -> None:
527
545
def start_profiling (self ) -> None :
528
546
self ._previous_loop_ts = monotonic_time ()
529
547
530
- def loop ():
548
+ def loop () -> None :
531
549
curr = self ._current_txn_total_time
532
550
prev = self ._previous_txn_total_time
533
551
self ._previous_txn_total_time = curr
@@ -1186,7 +1204,7 @@ def simple_upsert_txn_emulated(
1186
1204
if lock :
1187
1205
self .engine .lock_table (txn , table )
1188
1206
1189
- def _getwhere (key ) :
1207
+ def _getwhere (key : str ) -> str :
1190
1208
# If the value we're passing in is None (aka NULL), we need to use
1191
1209
# IS, not =, as NULL = NULL equals NULL (False).
1192
1210
if keyvalues [key ] is None :
@@ -2258,7 +2276,7 @@ async def simple_search_list(
2258
2276
term : Optional [str ],
2259
2277
col : str ,
2260
2278
retcols : Collection [str ],
2261
- desc = "simple_search_list" ,
2279
+ desc : str = "simple_search_list" ,
2262
2280
) -> Optional [List [Dict [str , Any ]]]:
2263
2281
"""Executes a SELECT query on the named table, which may return zero or
2264
2282
more rows, returning the result as a list of dicts.
0 commit comments