@@ -58,24 +58,28 @@ class PostgreSQLOnlineStore(OnlineStore):
58
58
_conn_pool_async : Optional [AsyncConnectionPool ] = None
59
59
60
60
@contextlib .contextmanager
61
- def _get_conn (self , config : RepoConfig ) -> Generator [Connection , Any , Any ]:
61
+ def _get_conn (
62
+ self , config : RepoConfig , autocommit : bool = False
63
+ ) -> Generator [Connection , Any , Any ]:
62
64
assert config .online_store .type == "postgres"
63
65
64
66
if config .online_store .conn_type == ConnectionType .pool :
65
67
if not self ._conn_pool :
66
68
self ._conn_pool = _get_connection_pool (config .online_store )
67
69
self ._conn_pool .open ()
68
70
connection = self ._conn_pool .getconn ()
71
+ connection .set_autocommit (autocommit )
69
72
yield connection
70
73
self ._conn_pool .putconn (connection )
71
74
else :
72
75
if not self ._conn :
73
76
self ._conn = _get_conn (config .online_store )
77
+ self ._conn .set_autocommit (autocommit )
74
78
yield self ._conn
75
79
76
80
@contextlib .asynccontextmanager
77
81
async def _get_conn_async (
78
- self , config : RepoConfig
82
+ self , config : RepoConfig , autocommit : bool = False
79
83
) -> AsyncGenerator [AsyncConnection , Any ]:
80
84
if config .online_store .conn_type == ConnectionType .pool :
81
85
if not self ._conn_pool_async :
@@ -84,11 +88,13 @@ async def _get_conn_async(
84
88
)
85
89
await self ._conn_pool_async .open ()
86
90
connection = await self ._conn_pool_async .getconn ()
91
+ await connection .set_autocommit (autocommit )
87
92
yield connection
88
93
await self ._conn_pool_async .putconn (connection )
89
94
else :
90
95
if not self ._conn_async :
91
96
self ._conn_async = await _get_conn_async (config .online_store )
97
+ await self ._conn_async .set_autocommit (autocommit )
92
98
yield self ._conn_async
93
99
94
100
def online_write_batch (
@@ -161,7 +167,7 @@ def online_read(
161
167
config , table , keys , requested_features
162
168
)
163
169
164
- with self ._get_conn (config ) as conn , conn .cursor () as cur :
170
+ with self ._get_conn (config , autocommit = True ) as conn , conn .cursor () as cur :
165
171
cur .execute (query , params )
166
172
rows = cur .fetchall ()
167
173
@@ -179,7 +185,7 @@ async def online_read_async(
179
185
config , table , keys , requested_features
180
186
)
181
187
182
- async with self ._get_conn_async (config ) as conn :
188
+ async with self ._get_conn_async (config , autocommit = True ) as conn :
183
189
async with conn .cursor () as cur :
184
190
await cur .execute (query , params )
185
191
rows = await cur .fetchall ()
@@ -339,6 +345,7 @@ def teardown(
339
345
for table in tables :
340
346
table_name = _table_id (project , table )
341
347
cur .execute (_drop_table_and_index (table_name ))
348
+ conn .commit ()
342
349
except Exception :
343
350
logging .exception ("Teardown failed" )
344
351
raise
@@ -398,7 +405,7 @@ def retrieve_online_documents(
398
405
Optional [ValueProto ],
399
406
]
400
407
] = []
401
- with self ._get_conn (config ) as conn , conn .cursor () as cur :
408
+ with self ._get_conn (config , autocommit = True ) as conn , conn .cursor () as cur :
402
409
table_name = _table_id (project , table )
403
410
404
411
# Search query template to find the top k items that are closest to the given embedding
0 commit comments