Skip to content

Commit 46f8d7a

Browse files
fix: Ensure Postgres queries are committed or autocommit is used (#5039)
1 parent d937dcb commit 46f8d7a

File tree

1 file changed

+12
-5
lines changed
  • sdk/python/feast/infra/online_stores/postgres_online_store

1 file changed

+12
-5
lines changed

sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,24 +58,28 @@ class PostgreSQLOnlineStore(OnlineStore):
5858
_conn_pool_async: Optional[AsyncConnectionPool] = None
5959

6060
@contextlib.contextmanager
61-
def _get_conn(self, config: RepoConfig) -> Generator[Connection, Any, Any]:
61+
def _get_conn(
62+
self, config: RepoConfig, autocommit: bool = False
63+
) -> Generator[Connection, Any, Any]:
6264
assert config.online_store.type == "postgres"
6365

6466
if config.online_store.conn_type == ConnectionType.pool:
6567
if not self._conn_pool:
6668
self._conn_pool = _get_connection_pool(config.online_store)
6769
self._conn_pool.open()
6870
connection = self._conn_pool.getconn()
71+
connection.set_autocommit(autocommit)
6972
yield connection
7073
self._conn_pool.putconn(connection)
7174
else:
7275
if not self._conn:
7376
self._conn = _get_conn(config.online_store)
77+
self._conn.set_autocommit(autocommit)
7478
yield self._conn
7579

7680
@contextlib.asynccontextmanager
7781
async def _get_conn_async(
78-
self, config: RepoConfig
82+
self, config: RepoConfig, autocommit: bool = False
7983
) -> AsyncGenerator[AsyncConnection, Any]:
8084
if config.online_store.conn_type == ConnectionType.pool:
8185
if not self._conn_pool_async:
@@ -84,11 +88,13 @@ async def _get_conn_async(
8488
)
8589
await self._conn_pool_async.open()
8690
connection = await self._conn_pool_async.getconn()
91+
await connection.set_autocommit(autocommit)
8792
yield connection
8893
await self._conn_pool_async.putconn(connection)
8994
else:
9095
if not self._conn_async:
9196
self._conn_async = await _get_conn_async(config.online_store)
97+
await self._conn_async.set_autocommit(autocommit)
9298
yield self._conn_async
9399

94100
def online_write_batch(
@@ -161,7 +167,7 @@ def online_read(
161167
config, table, keys, requested_features
162168
)
163169

164-
with self._get_conn(config) as conn, conn.cursor() as cur:
170+
with self._get_conn(config, autocommit=True) as conn, conn.cursor() as cur:
165171
cur.execute(query, params)
166172
rows = cur.fetchall()
167173

@@ -179,7 +185,7 @@ async def online_read_async(
179185
config, table, keys, requested_features
180186
)
181187

182-
async with self._get_conn_async(config) as conn:
188+
async with self._get_conn_async(config, autocommit=True) as conn:
183189
async with conn.cursor() as cur:
184190
await cur.execute(query, params)
185191
rows = await cur.fetchall()
@@ -339,6 +345,7 @@ def teardown(
339345
for table in tables:
340346
table_name = _table_id(project, table)
341347
cur.execute(_drop_table_and_index(table_name))
348+
conn.commit()
342349
except Exception:
343350
logging.exception("Teardown failed")
344351
raise
@@ -398,7 +405,7 @@ def retrieve_online_documents(
398405
Optional[ValueProto],
399406
]
400407
] = []
401-
with self._get_conn(config) as conn, conn.cursor() as cur:
408+
with self._get_conn(config, autocommit=True) as conn, conn.cursor() as cur:
402409
table_name = _table_id(project, table)
403410

404411
# Search query template to find the top k items that are closest to the given embedding

0 commit comments

Comments
 (0)