Skip to content

Commit 864f7cf

Browse files
Add tests; maybe fix?
1 parent 721ab44 commit 864f7cf

File tree

7 files changed

+960
-5
lines changed

7 files changed

+960
-5
lines changed

src/snowflake/connector/aio/_connection.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -804,7 +804,11 @@ async def authenticate_with_retry(self, auth_instance) -> None:
804804
# SSO if it has expired
805805
await self._reauthenticate()
806806
else:
807-
await self._authenticate(auth_instance)
807+
# TODO pczajka: check if this is correct
808+
# For OAuth and other auth types, call their reauthenticate method
809+
await auth_instance.reauthenticate(conn=self)
810+
# The reauthenticate method will call authenticate_with_retry internally,
811+
# so we don't need to call _authenticate again here
808812

809813
async def autocommit(self, mode) -> None:
810814
"""Sets autocommit mode to True, or False. Defaults to True."""

src/snowflake/connector/aio/auth/_auth.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
ACCEPT_TYPE_APPLICATION_SNOWFLAKE,
3131
CONTENT_TYPE_APPLICATION_JSON,
3232
ID_TOKEN_INVALID_LOGIN_REQUEST_GS_CODE,
33+
OAUTH_ACCESS_TOKEN_EXPIRED_GS_CODE,
3334
PYTHON_CONNECTOR_USER_AGENT,
3435
ReauthenticationRequest,
3536
)
@@ -282,6 +283,15 @@ async def post_request_wrapper(self, url, headers, body) -> None:
282283
sqlstate=SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED,
283284
)
284285
)
286+
elif errno == OAUTH_ACCESS_TOKEN_EXPIRED_GS_CODE:
287+
raise ReauthenticationRequest(
288+
ProgrammingError(
289+
msg=ret["message"],
290+
errno=int(errno),
291+
sqlstate=SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED,
292+
)
293+
)
294+
285295
from . import AuthByKeyPair
286296

287297
if isinstance(auth_instance, AuthByKeyPair):

src/snowflake/connector/aio/auth/_oauth_code.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@
33
from __future__ import annotations
44

55
import logging
6-
from typing import Any
6+
from typing import TYPE_CHECKING, Any
77

88
from ...auth.oauth_code import AuthByOauthCode as AuthByOauthCodeSync
99
from ...token_cache import TokenCache
1010
from ._by_plugin import AuthByPlugin as AuthByPluginAsync
1111

12+
if TYPE_CHECKING:
13+
from .. import SnowflakeConnection
14+
1215
logger = logging.getLogger(__name__)
1316

1417

@@ -57,7 +60,47 @@ async def prepare(self, **kwargs: Any) -> None:
5760
AuthByOauthCodeSync.prepare(self, **kwargs)
5861

5962
async def reauthenticate(self, **kwargs: Any) -> dict[str, bool]:
60-
return AuthByOauthCodeSync.reauthenticate(self, **kwargs)
63+
"""Override to use async connection properly."""
64+
# TODO pczajka: check if this is correct
65+
66+
# Call the sync reset logic but handle the connection retry ourselves
67+
self._reset_access_token()
68+
if self._pop_cached_refresh_token():
69+
logger.debug(
70+
"OAuth refresh token is available, try to use it and get a new access token"
71+
)
72+
self._do_refresh_token(conn=kwargs.get("conn"))
73+
# Use async authenticate_with_retry
74+
if "conn" in kwargs:
75+
await kwargs["conn"].authenticate_with_retry(self)
76+
return {"success": True}
6177

6278
async def update_body(self, body: dict[Any, Any]) -> None:
6379
AuthByOauthCodeSync.update_body(self, body)
80+
81+
def _handle_failure(
82+
self,
83+
*,
84+
conn: SnowflakeConnection,
85+
ret: dict[Any, Any],
86+
**kwargs: Any,
87+
) -> None:
88+
"""Override to ensure proper error handling in async context."""
89+
# Use sync error handling directly to avoid async/sync mismatch
90+
from ...errors import DatabaseError, Error
91+
from ...sqlstate import SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED
92+
93+
Error.errorhandler_wrapper(
94+
conn,
95+
None,
96+
DatabaseError,
97+
{
98+
"msg": "Failed to connect to DB: {host}:{port}, {message}".format(
99+
host=conn._rest._host,
100+
port=conn._rest._port,
101+
message=ret["message"],
102+
),
103+
"errno": int(ret.get("code", -1)),
104+
"sqlstate": SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED,
105+
},
106+
)

src/snowflake/connector/aio/auth/_oauth_credentials.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,17 @@
33
from __future__ import annotations
44

55
import logging
6-
from typing import Any
6+
from typing import TYPE_CHECKING, Any
77

88
from ...auth.oauth_credentials import (
99
AuthByOauthCredentials as AuthByOauthCredentialsSync,
1010
)
1111
from ...token_cache import TokenCache
1212
from ._by_plugin import AuthByPlugin as AuthByPluginAsync
1313

14+
if TYPE_CHECKING:
15+
from .. import SnowflakeConnection
16+
1417
logger = logging.getLogger(__name__)
1518

1619

@@ -51,7 +54,47 @@ async def prepare(self, **kwargs: Any) -> None:
5154
AuthByOauthCredentialsSync.prepare(self, **kwargs)
5255

5356
async def reauthenticate(self, **kwargs: Any) -> dict[str, bool]:
54-
return AuthByOauthCredentialsSync.reauthenticate(self, **kwargs)
57+
"""Override to use async connection properly."""
58+
# TODO pczajka: check if this is correct
59+
60+
# Call the sync reset logic but handle the connection retry ourselves
61+
self._reset_access_token()
62+
if self._pop_cached_refresh_token():
63+
logger.debug(
64+
"OAuth refresh token is available, try to use it and get a new access token"
65+
)
66+
self._do_refresh_token(conn=kwargs.get("conn"))
67+
# Use async authenticate_with_retry
68+
if "conn" in kwargs:
69+
await kwargs["conn"].authenticate_with_retry(self)
70+
return {"success": True}
5571

5672
async def update_body(self, body: dict[Any, Any]) -> None:
5773
AuthByOauthCredentialsSync.update_body(self, body)
74+
75+
def _handle_failure(
76+
self,
77+
*,
78+
conn: SnowflakeConnection,
79+
ret: dict[Any, Any],
80+
**kwargs: Any,
81+
) -> None:
82+
"""Override to ensure proper error handling in async context."""
83+
# Use sync error handling directly to avoid async/sync mismatch
84+
from ...errors import DatabaseError, Error
85+
from ...sqlstate import SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED
86+
87+
Error.errorhandler_wrapper(
88+
conn,
89+
None,
90+
DatabaseError,
91+
{
92+
"msg": "Failed to connect to DB: {host}:{port}, {message}".format(
93+
host=conn._rest._host,
94+
port=conn._rest._port,
95+
message=ret["message"],
96+
),
97+
"errno": int(ret.get("code", -1)),
98+
"sqlstate": SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED,
99+
},
100+
)
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#!/usr/bin/env python
2+
#
3+
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
4+
#
5+
6+
from __future__ import annotations
7+
8+
import os
9+
10+
from snowflake.connector.aio.auth import AuthByOauthCode
11+
12+
13+
async def test_auth_oauth_code():
14+
"""Simple OAuth Code test."""
15+
# Set experimental auth flag for the test
16+
os.environ["SF_ENABLE_EXPERIMENTAL_AUTHENTICATION"] = "true"
17+
18+
auth = AuthByOauthCode(
19+
application="test_app",
20+
client_id="test_client_id",
21+
client_secret="test_client_secret",
22+
authentication_url="https://example.com/auth",
23+
token_request_url="https://example.com/token",
24+
redirect_uri="http://localhost:8080/callback",
25+
scope="session:role:test_role",
26+
pkce_enabled=True,
27+
refresh_token_enabled=False,
28+
)
29+
30+
body = {"data": {}}
31+
await auth.update_body(body)
32+
33+
# Check that OAuth authenticator is set
34+
assert body["data"]["AUTHENTICATOR"] == "OAUTH", body
35+
# OAuth type should be set to authorization_code
36+
assert body["data"]["OAUTH_TYPE"] == "authorization_code", body
37+
38+
# Clean up environment variable
39+
del os.environ["SF_ENABLE_EXPERIMENTAL_AUTHENTICATION"]
40+
41+
42+
def test_mro():
43+
"""Ensure that methods from AuthByPluginAsync override those from AuthByPlugin."""
44+
from snowflake.connector.aio.auth import AuthByPlugin as AuthByPluginAsync
45+
from snowflake.connector.auth import AuthByPlugin as AuthByPluginSync
46+
47+
assert AuthByOauthCode.mro().index(AuthByPluginAsync) < AuthByOauthCode.mro().index(
48+
AuthByPluginSync
49+
)
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#!/usr/bin/env python
2+
#
3+
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
4+
#
5+
6+
from __future__ import annotations
7+
8+
import os
9+
10+
from snowflake.connector.aio.auth import AuthByOauthCredentials
11+
12+
13+
async def test_auth_oauth_credentials():
14+
"""Simple OAuth Credentials test."""
15+
# Set experimental auth flag for the test
16+
os.environ["SF_ENABLE_EXPERIMENTAL_AUTHENTICATION"] = "true"
17+
18+
auth = AuthByOauthCredentials(
19+
application="test_app",
20+
client_id="test_client_id",
21+
client_secret="test_client_secret",
22+
token_request_url="https://example.com/token",
23+
scope="session:role:test_role",
24+
refresh_token_enabled=False,
25+
)
26+
27+
body = {"data": {}}
28+
await auth.update_body(body)
29+
30+
# Check that OAuth authenticator is set
31+
assert body["data"]["AUTHENTICATOR"] == "OAUTH", body
32+
# OAuth type should be set to client_credentials
33+
assert body["data"]["OAUTH_TYPE"] == "client_credentials", body
34+
35+
# Clean up environment variable
36+
del os.environ["SF_ENABLE_EXPERIMENTAL_AUTHENTICATION"]
37+
38+
39+
def test_mro():
40+
"""Ensure that methods from AuthByPluginAsync override those from AuthByPlugin."""
41+
from snowflake.connector.aio.auth import AuthByPlugin as AuthByPluginAsync
42+
from snowflake.connector.auth import AuthByPlugin as AuthByPluginSync
43+
44+
assert AuthByOauthCredentials.mro().index(
45+
AuthByPluginAsync
46+
) < AuthByOauthCredentials.mro().index(AuthByPluginSync)

0 commit comments

Comments
 (0)