|
3 | 3 | from __future__ import annotations
|
4 | 4 |
|
5 | 5 | import logging
|
6 |
| -from typing import Any |
| 6 | +from typing import TYPE_CHECKING, Any |
7 | 7 |
|
8 | 8 | from ...auth.oauth_credentials import (
|
9 | 9 | AuthByOauthCredentials as AuthByOauthCredentialsSync,
|
10 | 10 | )
|
11 | 11 | from ...token_cache import TokenCache
|
12 | 12 | from ._by_plugin import AuthByPlugin as AuthByPluginAsync
|
13 | 13 |
|
| 14 | +if TYPE_CHECKING: |
| 15 | + from .. import SnowflakeConnection |
| 16 | + |
14 | 17 | logger = logging.getLogger(__name__)
|
15 | 18 |
|
16 | 19 |
|
@@ -51,7 +54,47 @@ async def prepare(self, **kwargs: Any) -> None:
|
51 | 54 | AuthByOauthCredentialsSync.prepare(self, **kwargs)
|
52 | 55 |
|
53 | 56 | async def reauthenticate(self, **kwargs: Any) -> dict[str, bool]:
|
54 |
| - return AuthByOauthCredentialsSync.reauthenticate(self, **kwargs) |
| 57 | + """Override to use async connection properly.""" |
| 58 | + # TODO pczajka: check if this is correct |
| 59 | + |
| 60 | + # Call the sync reset logic but handle the connection retry ourselves |
| 61 | + self._reset_access_token() |
| 62 | + if self._pop_cached_refresh_token(): |
| 63 | + logger.debug( |
| 64 | + "OAuth refresh token is available, try to use it and get a new access token" |
| 65 | + ) |
| 66 | + self._do_refresh_token(conn=kwargs.get("conn")) |
| 67 | + # Use async authenticate_with_retry |
| 68 | + if "conn" in kwargs: |
| 69 | + await kwargs["conn"].authenticate_with_retry(self) |
| 70 | + return {"success": True} |
55 | 71 |
|
56 | 72 | async def update_body(self, body: dict[Any, Any]) -> None:
|
57 | 73 | AuthByOauthCredentialsSync.update_body(self, body)
|
| 74 | + |
| 75 | + def _handle_failure( |
| 76 | + self, |
| 77 | + *, |
| 78 | + conn: SnowflakeConnection, |
| 79 | + ret: dict[Any, Any], |
| 80 | + **kwargs: Any, |
| 81 | + ) -> None: |
| 82 | + """Override to ensure proper error handling in async context.""" |
| 83 | + # Use sync error handling directly to avoid async/sync mismatch |
| 84 | + from ...errors import DatabaseError, Error |
| 85 | + from ...sqlstate import SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED |
| 86 | + |
| 87 | + Error.errorhandler_wrapper( |
| 88 | + conn, |
| 89 | + None, |
| 90 | + DatabaseError, |
| 91 | + { |
| 92 | + "msg": "Failed to connect to DB: {host}:{port}, {message}".format( |
| 93 | + host=conn._rest._host, |
| 94 | + port=conn._rest._port, |
| 95 | + message=ret["message"], |
| 96 | + ), |
| 97 | + "errno": int(ret.get("code", -1)), |
| 98 | + "sqlstate": SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED, |
| 99 | + }, |
| 100 | + ) |
0 commit comments