Skip to content

Commit 5cd6523

Browse files
vishalyahashhar
authored andcommitted
Support full Oauth2 flow for authentication
- Use response hooks for the reactive flow - Send a http request to the statement url - Get the auth url - the caller can setup a urlhandler (to load the url in the browser) - default urlhandler is just going to print the url on stdout - user can copypaste the url in the browser and perform the external authentication - Get the token send in the response header and use it for authentication On how to use the new auth class for OAuth2 - https://github.com/trinodb/trino-python-client/tree/oauth_dev#oauth2-authentication For high level requirement, please refer to - #103
1 parent 5344728 commit 5cd6523

File tree

5 files changed

+374
-43
lines changed

5 files changed

+374
-43
lines changed

README.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,24 @@ cur = conn.cursor()
7676
cur.execute('SELECT * FROM system.runtime.nodes')
7777
rows = cur.fetchall()
7878
```
79+
# OAuth2 Authentication
80+
- It can be used for the Oauth2 enabled trino server https://trino.io/docs/current/security/oauth2.html
81+
- A callback to handle the redirect url can be provided via param redirect_auth_url_handler, by default it just outputs the redirect url to stdout
82+
```python
83+
import trino
84+
conn = trino.dbapi.connect(
85+
host='coordinator-url',
86+
port=8443,
87+
user='the-user',
88+
catalog='the-catalog',
89+
schema='the-schema',
90+
http_scheme='https',
91+
auth=trino.auth.OAuth2Authentication(),
92+
)
93+
cur = conn.cursor()
94+
cur.execute('SELECT * FROM system.runtime.nodes')
95+
rows = cur.fetchall()
96+
```
7997

8098
# Transactions
8199
The client runs by default in *autocommit* mode. To enable transactions, set

tests/unit/test_client.py

Lines changed: 222 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from requests_kerberos.exceptions import KerberosExchangeError
2222
from trino.client import TrinoQuery, TrinoRequest, TrinoResult
23-
from trino.auth import KerberosAuthentication
23+
from trino.auth import KerberosAuthentication, _OAuth2TokenBearer
2424
from trino import constants
2525
import trino.exceptions
2626

@@ -254,6 +254,227 @@ def long_call(request, uri, headers):
254254
httpretty.reset()
255255

256256

257+
OAUTH_SERVER_URL_NO_HEADER = "http://coordinator/no_header"
258+
OAUTH_SERVER_URL_FAIL_SERVER = "http://coordinator/fail_server"
259+
OAUTH_SERVER_URL_SERVER_DENIED = "http://coordinator/server_denied_accesss"
260+
OAUTH_SERVER_URL_SERVER_SUCCESS = "http://coordinator/statement_url_suceess"
261+
OAUTH_REDIRECT_SERVER = "https://coordinator/as/authorization.oauth2"
262+
OAUTH_SERVER_URL_LOOP = "https://coordinator/oauth2/token/loop"
263+
OAUTH_SERVER_URL_1 = "https://coordinator/oauth2/token/13b03a96-1311-43eb-ada1-a2a9746f7281"
264+
OAUTH_SERVER_URL_2 = "https://coordinator/oauth2/token/e71970b6-d1e7-447e-8d82-9325d3f6192d"
265+
OAUTH_SERVER_URL_FORCE_FAIL = "https://coordinator/oauth2/token/force_fail"
266+
OAUTH_SERVER_URL_DENY_ACCESS = "https://coordinator/oauth2/token/deny_access"
267+
OAUTH_DENY_ERROR_TEXT = '{"error": "OAuth server returned an error: error=access_denied, error_description=null, error_uri=null, state=EncodedState"}' # NOQA: E501
268+
OAUTH_TEST_TOKEN = "FakeToken1234567890"
269+
270+
271+
def oauth2_test_url_handler(url):
272+
print(url, end='')
273+
274+
275+
class OAuthTestReq:
276+
def __init__(self, method, url):
277+
self.method = method
278+
self.url = url
279+
280+
def __call__(self, str, callback_func):
281+
if (self.method == 'post'):
282+
callback_func(self.get_statement_post_response())
283+
elif (self.method == 'get'):
284+
callback_func(self.get_token_url_response())
285+
286+
def get_statement_request(self):
287+
req = mock.Mock()
288+
req.url = self.url
289+
req.headers = requests.structures.CaseInsensitiveDict()
290+
req.register_hook = mock.Mock(side_effect=self)
291+
return req
292+
293+
def get_token_request(self):
294+
req = mock.Mock()
295+
req.url = self.url
296+
req.headers = requests.structures.CaseInsensitiveDict()
297+
req.register_hook = mock.Mock(side_effect=self)
298+
return req
299+
300+
def get_statement_post_response(self):
301+
statement_resp = mock.Mock()
302+
statement_resp.status_code = 401
303+
if (self.url == OAUTH_SERVER_URL_NO_HEADER):
304+
statement_resp.headers = requests.structures.CaseInsensitiveDict()
305+
elif (self.url == OAUTH_SERVER_URL_FAIL_SERVER):
306+
statement_resp.headers = requests.structures.CaseInsensitiveDict([
307+
('Www-Authenticate',
308+
'Bearer x_redirect_server=\"{OAUTH_REDIRECT_SERVER}\",'
309+
f'x_token_server=\"{OAUTH_SERVER_URL_FORCE_FAIL}\",'
310+
'Basic realm=\"Trino\"')])
311+
elif (self.url == OAUTH_SERVER_URL_SERVER_DENIED):
312+
statement_resp.headers = requests.structures.CaseInsensitiveDict([
313+
('Www-Authenticate',
314+
'Bearer x_redirect_server=\"{OAUTH_REDIRECT_SERVER}\",'
315+
f'x_token_server=\"{OAUTH_SERVER_URL_DENY_ACCESS}\",'
316+
'Basic realm=\"Trino\"')])
317+
elif (self.url == OAUTH_SERVER_URL_SERVER_SUCCESS):
318+
statement_resp.status_code = 200
319+
statement_resp.headers = requests.structures.CaseInsensitiveDict([
320+
('Www-Authenticate',
321+
f'Bearer x_redirect_server=\"{OAUTH_REDIRECT_SERVER}\",'
322+
f'x_token_server=\"{OAUTH_SERVER_URL_1}\",'
323+
'Basic realm=\"Trino\"')])
324+
else:
325+
statement_resp.headers = requests.structures.CaseInsensitiveDict([
326+
('Www-Authenticate',
327+
f'Bearer x_redirect_server=\"{OAUTH_REDIRECT_SERVER}\",'
328+
f'x_token_server=\"{OAUTH_SERVER_URL_1}\",'
329+
'Basic realm=\"Trino\"')])
330+
331+
statement_resp.register_hook = mock.Mock(side_effect=self)
332+
statement_resp.url = self.url
333+
return statement_resp
334+
335+
def get_token_url_response(self):
336+
token_resp = mock.Mock()
337+
token_resp.status_code = 200
338+
339+
# Success cases
340+
if self.url == OAUTH_SERVER_URL_1:
341+
token_resp.text = f'{{"nextUri":"{OAUTH_SERVER_URL_2}"}}'
342+
elif self.url == OAUTH_SERVER_URL_2:
343+
token_resp.text = f'{{"token":"{OAUTH_TEST_TOKEN}"}}'
344+
345+
# Failure cases
346+
elif self.url == OAUTH_SERVER_URL_FORCE_FAIL:
347+
token_resp.status_code = 500
348+
elif self.url == OAUTH_SERVER_URL_DENY_ACCESS:
349+
token_resp.text = OAUTH_DENY_ERROR_TEXT
350+
elif self.url == OAUTH_SERVER_URL_LOOP:
351+
token_resp.text = f'{{"nextUri":"{OAUTH_SERVER_URL_LOOP}"}}'
352+
353+
return token_resp
354+
355+
356+
def call_response_hook(str, callback_func):
357+
statement_resp = mock.Mock()
358+
statement_resp.headers = requests.structures.CaseInsensitiveDict([
359+
('Www-Authenticate',
360+
f'Bearer x_redirect_server=\"{OAUTH_REDIRECT_SERVER}\",'
361+
f'x_token_server=\"{OAUTH_SERVER_URL_1}\",'
362+
'Basic realm=\"Trino\"')])
363+
statement_resp.status_code = 401
364+
callback_func(statement_resp)
365+
366+
367+
@mock.patch("requests.Session.get")
368+
@mock.patch("requests.Session.post")
369+
def test_oauth2_authentication_flow(http_session_post, http_session_get, capsys):
370+
http_session = requests.Session()
371+
372+
# set up the patched session, with the correct response
373+
oauth_test = OAuthTestReq("post", "http://coordinator/statement_url")
374+
http_session_post.return_value = oauth_test.get_statement_post_response()
375+
http_session_get.side_effect = oauth_test.get_token_url_response()
376+
oauth = _OAuth2TokenBearer(http_session, oauth2_test_url_handler)
377+
378+
statement_req = oauth_test.get_statement_request()
379+
oauth(statement_req)
380+
381+
oauth_test = OAuthTestReq("get", OAUTH_SERVER_URL_1)
382+
token_req = oauth_test.get_token_request()
383+
oauth(token_req)
384+
385+
oauth_test = OAuthTestReq("get", OAUTH_SERVER_URL_2)
386+
token_req = oauth_test.get_token_request()
387+
oauth(token_req)
388+
389+
# Finally resend the original request, and respond back with status code 200
390+
oauth_test = OAuthTestReq("post", "http://coordinator/statement_url_suceess")
391+
# statement_req.register_hook = mock.Mock(side_effect=oauth_test)
392+
statement_req = oauth_test.get_statement_request()
393+
http_session_post.return_value = oauth_test.get_statement_post_response()
394+
oauth(statement_req)
395+
396+
out, err = capsys.readouterr()
397+
assert out == OAUTH_REDIRECT_SERVER
398+
assert statement_req.headers['Authorization'] == "Bearer " + OAUTH_TEST_TOKEN
399+
400+
401+
@mock.patch("requests.Session.get")
402+
@mock.patch("requests.Session.post")
403+
def test_oauth2_exceed_max_attempts(http_session_post, http_session_get):
404+
http_session = requests.Session()
405+
406+
# set up the patched session, with the correct response
407+
oauth_test = OAuthTestReq("post", "http://coordinator/statement_url")
408+
http_session_post.return_value = oauth_test.get_statement_post_response()
409+
http_session_get.side_effect = oauth_test.get_token_url_response()
410+
oauth = _OAuth2TokenBearer(http_session, oauth2_test_url_handler)
411+
412+
statement_req = oauth_test.get_statement_request()
413+
oauth(statement_req)
414+
415+
with pytest.raises(trino.exceptions.TrinoAuthError) as exp:
416+
for i in range(0, 5):
417+
oauth_test = OAuthTestReq("get", OAUTH_SERVER_URL_1)
418+
token_req = oauth_test.get_token_request()
419+
oauth(token_req)
420+
421+
assert str(exp.value) == "Exceeded max attempts while getting the token"
422+
423+
424+
@mock.patch("requests.Session.post")
425+
def test_oauth2_authentication_missing_headers(http_session_post):
426+
http_session = requests.Session()
427+
oauth_test = OAuthTestReq("post", OAUTH_SERVER_URL_NO_HEADER)
428+
http_session_post.return_value = oauth_test.get_statement_post_response()
429+
oauth = _OAuth2TokenBearer(http_session, oauth2_test_url_handler)
430+
431+
with pytest.raises(trino.exceptions.TrinoAuthError) as exp:
432+
statement_req = oauth_test.get_statement_request()
433+
oauth(statement_req)
434+
435+
assert str(exp.value) == "Error: header WWW-Authenticate not available in the response."
436+
437+
438+
@mock.patch("requests.Session.get")
439+
@mock.patch("requests.Session.post")
440+
def test_oauth2_authentication_fail_token_server(http_session_post, http_session_get):
441+
http_session = requests.Session()
442+
oauth_test = OAuthTestReq("post", OAUTH_SERVER_URL_FAIL_SERVER)
443+
http_session_post.return_value = oauth_test.get_statement_post_response()
444+
oauth = _OAuth2TokenBearer(http_session, oauth2_test_url_handler)
445+
http_session_get.side_effect = oauth_test.get_token_url_response()
446+
447+
statement_req = oauth_test.get_statement_request()
448+
oauth(statement_req)
449+
450+
with pytest.raises(trino.exceptions.TrinoAuthError) as exp:
451+
oauth_test = OAuthTestReq("get", OAUTH_SERVER_URL_FORCE_FAIL)
452+
token_req = oauth_test.get_token_request()
453+
oauth(token_req)
454+
455+
assert "Error while getting the token response status" in str(exp.value)
456+
457+
458+
@mock.patch("requests.Session.get")
459+
@mock.patch("requests.Session.post")
460+
def test_oauth2_authentication_access_denied(http_session_post, http_session_get):
461+
http_session = requests.Session()
462+
oauth_test = OAuthTestReq("post", OAUTH_SERVER_URL_SERVER_DENIED)
463+
http_session_post.return_value = oauth_test.get_statement_post_response()
464+
oauth = _OAuth2TokenBearer(http_session, oauth2_test_url_handler)
465+
http_session_get.side_effect = oauth_test.get_token_url_response()
466+
467+
statement_req = oauth_test.get_statement_request()
468+
oauth(statement_req)
469+
470+
with pytest.raises(trino.exceptions.TrinoAuthError) as exp:
471+
oauth_test = OAuthTestReq("get", OAUTH_SERVER_URL_FORCE_FAIL)
472+
token_req = oauth_test.get_token_request()
473+
oauth(token_req)
474+
475+
assert "Error while getting the token" in str(exp.value)
476+
477+
257478
@mock.patch("trino.client.TrinoRequest.http")
258479
def test_trino_fetch_request(mock_requests, sample_get_response_data):
259480
mock_requests.Response.return_value.json.return_value = sample_get_response_data

0 commit comments

Comments
 (0)