13
13
# limitations under the License.
14
14
15
15
import logging
16
- from typing import Mapping , Optional
16
+ from typing import TYPE_CHECKING , Any , Mapping , NoReturn , Optional , Tuple , cast
17
17
18
18
from synapse .storage .engines ._base import (
19
19
BaseDatabaseEngine ,
20
20
IncorrectDatabaseSetup ,
21
21
IsolationLevel ,
22
22
)
23
- from synapse .storage .types import Connection
23
+ from synapse .storage .types import Cursor
24
+
25
+ if TYPE_CHECKING :
26
+ import psycopg2 # noqa: F401
27
+
28
+ from synapse .storage .database import LoggingDatabaseConnection
29
+
24
30
25
31
logger = logging .getLogger (__name__ )
26
32
27
33
28
- class PostgresEngine (BaseDatabaseEngine ):
29
- def __init__ (self , database_module , database_config ):
30
- super ().__init__ (database_module , database_config )
31
- self .module .extensions .register_type (self .module .extensions .UNICODE )
34
+ class PostgresEngine (BaseDatabaseEngine ["psycopg2.connection" ]):
35
+ def __init__ (self , database_config : Mapping [str , Any ]):
36
+ import psycopg2 .extensions
37
+
38
+ super ().__init__ (psycopg2 , database_config )
39
+ psycopg2 .extensions .register_type (psycopg2 .extensions .UNICODE )
32
40
33
41
# Disables passing `bytes` to txn.execute, c.f. #6186. If you do
34
42
# actually want to use bytes than wrap it in `bytearray`.
35
- def _disable_bytes_adapter (_ ) :
43
+ def _disable_bytes_adapter (_ : bytes ) -> NoReturn :
36
44
raise Exception ("Passing bytes to DB is disabled." )
37
45
38
- self . module .extensions .register_adapter (bytes , _disable_bytes_adapter )
39
- self .synchronous_commit = database_config .get ("synchronous_commit" , True )
40
- self ._version = None # unknown as yet
46
+ psycopg2 .extensions .register_adapter (bytes , _disable_bytes_adapter )
47
+ self .synchronous_commit : bool = database_config .get ("synchronous_commit" , True )
48
+ self ._version : Optional [ int ] = None # unknown as yet
41
49
42
50
self .isolation_level_map : Mapping [int , int ] = {
43
- IsolationLevel .READ_COMMITTED : self . module .extensions .ISOLATION_LEVEL_READ_COMMITTED ,
44
- IsolationLevel .REPEATABLE_READ : self . module .extensions .ISOLATION_LEVEL_REPEATABLE_READ ,
45
- IsolationLevel .SERIALIZABLE : self . module .extensions .ISOLATION_LEVEL_SERIALIZABLE ,
51
+ IsolationLevel .READ_COMMITTED : psycopg2 .extensions .ISOLATION_LEVEL_READ_COMMITTED ,
52
+ IsolationLevel .REPEATABLE_READ : psycopg2 .extensions .ISOLATION_LEVEL_REPEATABLE_READ ,
53
+ IsolationLevel .SERIALIZABLE : psycopg2 .extensions .ISOLATION_LEVEL_SERIALIZABLE ,
46
54
}
47
55
self .default_isolation_level = (
48
- self . module .extensions .ISOLATION_LEVEL_REPEATABLE_READ
56
+ psycopg2 .extensions .ISOLATION_LEVEL_REPEATABLE_READ
49
57
)
50
58
self .config = database_config
51
59
52
60
@property
53
61
def single_threaded (self ) -> bool :
54
62
return False
55
63
56
- def get_db_locale (self , txn ) :
64
+ def get_db_locale (self , txn : Cursor ) -> Tuple [ str , str ] :
57
65
txn .execute (
58
66
"SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
59
67
)
60
- collation , ctype = txn .fetchone ()
68
+ collation , ctype = cast ( Tuple [ str , str ], txn .fetchone () )
61
69
return collation , ctype
62
70
63
- def check_database (self , db_conn , allow_outdated_version : bool = False ):
71
+ def check_database (
72
+ self , db_conn : "psycopg2.connection" , allow_outdated_version : bool = False
73
+ ) -> None :
64
74
# Get the version of PostgreSQL that we're using. As per the psycopg2
65
75
# docs: The number is formed by converting the major, minor, and
66
76
# revision numbers into two-decimal-digit numbers and appending them
67
77
# together. For example, version 8.1.5 will be returned as 80105
68
- self ._version = db_conn .server_version
78
+ self ._version = cast ( int , db_conn .server_version )
69
79
allow_unsafe_locale = self .config .get ("allow_unsafe_locale" , False )
70
80
71
81
# Are we on a supported PostgreSQL version?
@@ -108,7 +118,7 @@ def check_database(self, db_conn, allow_outdated_version: bool = False):
108
118
ctype ,
109
119
)
110
120
111
- def check_new_database (self , txn ) :
121
+ def check_new_database (self , txn : Cursor ) -> None :
112
122
"""Gets called when setting up a brand new database. This allows us to
113
123
apply stricter checks on new databases versus existing database.
114
124
"""
@@ -129,10 +139,10 @@ def check_new_database(self, txn):
129
139
"See docs/postgres.md for more information." % ("\n " .join (errors ))
130
140
)
131
141
132
- def convert_param_style (self , sql ) :
142
+ def convert_param_style (self , sql : str ) -> str :
133
143
return sql .replace ("?" , "%s" )
134
144
135
- def on_new_connection (self , db_conn ) :
145
+ def on_new_connection (self , db_conn : "LoggingDatabaseConnection" ) -> None :
136
146
db_conn .set_isolation_level (self .default_isolation_level )
137
147
138
148
# Set the bytea output to escape, vs the default of hex
@@ -149,14 +159,14 @@ def on_new_connection(self, db_conn):
149
159
db_conn .commit ()
150
160
151
161
@property
152
- def can_native_upsert (self ):
162
+ def can_native_upsert (self ) -> bool :
153
163
"""
154
164
Can we use native UPSERTs?
155
165
"""
156
166
return True
157
167
158
168
@property
159
- def supports_using_any_list (self ):
169
+ def supports_using_any_list (self ) -> bool :
160
170
"""Do we support using `a = ANY(?)` and passing a list"""
161
171
return True
162
172
@@ -165,27 +175,25 @@ def supports_returning(self) -> bool:
165
175
"""Do we support the `RETURNING` clause in insert/update/delete?"""
166
176
return True
167
177
168
- def is_deadlock (self , error ):
169
- if isinstance (error , self .module .DatabaseError ):
178
+ def is_deadlock (self , error : Exception ) -> bool :
179
+ import psycopg2 .extensions
180
+
181
+ if isinstance (error , psycopg2 .DatabaseError ):
170
182
# https://www.postgresql.org/docs/current/static/errcodes-appendix.html
171
183
# "40001" serialization_failure
172
184
# "40P01" deadlock_detected
173
185
return error .pgcode in ["40001" , "40P01" ]
174
186
return False
175
187
176
- def is_connection_closed (self , conn ) :
188
+ def is_connection_closed (self , conn : "psycopg2.connection" ) -> bool :
177
189
return bool (conn .closed )
178
190
179
- def lock_table (self , txn , table ) :
191
+ def lock_table (self , txn : Cursor , table : str ) -> None :
180
192
txn .execute ("LOCK TABLE %s in EXCLUSIVE MODE" % (table ,))
181
193
182
194
@property
183
- def server_version (self ):
184
- """Returns a string giving the server version. For example: '8.1.5'
185
-
186
- Returns:
187
- string
188
- """
195
+ def server_version (self ) -> str :
196
+ """Returns a string giving the server version. For example: '8.1.5'."""
189
197
# note that this is a bit of a hack because it relies on check_database
190
198
# having been called. Still, that should be a safe bet here.
191
199
numver = self ._version
@@ -197,17 +205,21 @@ def server_version(self):
197
205
else :
198
206
return "%i.%i.%i" % (numver / 10000 , (numver % 10000 ) / 100 , numver % 100 )
199
207
200
- def in_transaction (self , conn : Connection ) -> bool :
201
- return conn .status != self .module .extensions .STATUS_READY # type: ignore
208
+ def in_transaction (self , conn : "psycopg2.connection" ) -> bool :
209
+ import psycopg2 .extensions
210
+
211
+ return conn .status != psycopg2 .extensions .STATUS_READY
202
212
203
- def attempt_to_set_autocommit (self , conn : Connection , autocommit : bool ):
204
- return conn .set_session (autocommit = autocommit ) # type: ignore
213
+ def attempt_to_set_autocommit (
214
+ self , conn : "psycopg2.connection" , autocommit : bool
215
+ ) -> None :
216
+ return conn .set_session (autocommit = autocommit )
205
217
206
218
def attempt_to_set_isolation_level (
207
- self , conn : Connection , isolation_level : Optional [int ]
208
- ):
219
+ self , conn : "psycopg2.connection" , isolation_level : Optional [int ]
220
+ ) -> None :
209
221
if isolation_level is None :
210
222
isolation_level = self .default_isolation_level
211
223
else :
212
224
isolation_level = self .isolation_level_map [isolation_level ]
213
- return conn .set_isolation_level (isolation_level ) # type: ignore
225
+ return conn .set_isolation_level (isolation_level )
0 commit comments