|
20 | 20 |
|
21 | 21 | from requests_kerberos.exceptions import KerberosExchangeError |
22 | 22 | from trino.client import TrinoQuery, TrinoRequest, TrinoResult |
23 | | -from trino.auth import KerberosAuthentication |
| 23 | +from trino.auth import KerberosAuthentication, _OAuth2TokenBearer |
24 | 24 | from trino import constants |
25 | 25 | import trino.exceptions |
26 | 26 |
|
@@ -254,6 +254,227 @@ def long_call(request, uri, headers): |
254 | 254 | httpretty.reset() |
255 | 255 |
|
256 | 256 |
|
| 257 | +OAUTH_SERVER_URL_NO_HEADER = "http://coordinator/no_header" |
| 258 | +OAUTH_SERVER_URL_FAIL_SERVER = "http://coordinator/fail_server" |
| 259 | +OAUTH_SERVER_URL_SERVER_DENIED = "http://coordinator/server_denied_accesss" |
| 260 | +OAUTH_SERVER_URL_SERVER_SUCCESS = "http://coordinator/statement_url_suceess" |
| 261 | +OAUTH_REDIRECT_SERVER = "https://coordinator/as/authorization.oauth2" |
| 262 | +OAUTH_SERVER_URL_LOOP = "https://coordinator/oauth2/token/loop" |
| 263 | +OAUTH_SERVER_URL_1 = "https://coordinator/oauth2/token/13b03a96-1311-43eb-ada1-a2a9746f7281" |
| 264 | +OAUTH_SERVER_URL_2 = "https://coordinator/oauth2/token/e71970b6-d1e7-447e-8d82-9325d3f6192d" |
| 265 | +OAUTH_SERVER_URL_FORCE_FAIL = "https://coordinator/oauth2/token/force_fail" |
| 266 | +OAUTH_SERVER_URL_DENY_ACCESS = "https://coordinator/oauth2/token/deny_access" |
| 267 | +OAUTH_DENY_ERROR_TEXT = '{"error": "OAuth server returned an error: error=access_denied, error_description=null, error_uri=null, state=EncodedState"}' # NOQA: E501 |
| 268 | +OAUTH_TEST_TOKEN = "FakeToken1234567890" |
| 269 | + |
| 270 | + |
| 271 | +def oauth2_test_url_handler(url): |
| 272 | + print(url, end='') |
| 273 | + |
| 274 | + |
| 275 | +class OAuthTestReq: |
| 276 | + def __init__(self, method, url): |
| 277 | + self.method = method |
| 278 | + self.url = url |
| 279 | + |
| 280 | + def __call__(self, str, callback_func): |
| 281 | + if (self.method == 'post'): |
| 282 | + callback_func(self.get_statement_post_response()) |
| 283 | + elif (self.method == 'get'): |
| 284 | + callback_func(self.get_token_url_response()) |
| 285 | + |
| 286 | + def get_statement_request(self): |
| 287 | + req = mock.Mock() |
| 288 | + req.url = self.url |
| 289 | + req.headers = requests.structures.CaseInsensitiveDict() |
| 290 | + req.register_hook = mock.Mock(side_effect=self) |
| 291 | + return req |
| 292 | + |
| 293 | + def get_token_request(self): |
| 294 | + req = mock.Mock() |
| 295 | + req.url = self.url |
| 296 | + req.headers = requests.structures.CaseInsensitiveDict() |
| 297 | + req.register_hook = mock.Mock(side_effect=self) |
| 298 | + return req |
| 299 | + |
| 300 | + def get_statement_post_response(self): |
| 301 | + statement_resp = mock.Mock() |
| 302 | + statement_resp.status_code = 401 |
| 303 | + if (self.url == OAUTH_SERVER_URL_NO_HEADER): |
| 304 | + statement_resp.headers = requests.structures.CaseInsensitiveDict() |
| 305 | + elif (self.url == OAUTH_SERVER_URL_FAIL_SERVER): |
| 306 | + statement_resp.headers = requests.structures.CaseInsensitiveDict([ |
| 307 | + ('Www-Authenticate', |
| 308 | + 'Bearer x_redirect_server=\"{OAUTH_REDIRECT_SERVER}\",' |
| 309 | + f'x_token_server=\"{OAUTH_SERVER_URL_FORCE_FAIL}\",' |
| 310 | + 'Basic realm=\"Trino\"')]) |
| 311 | + elif (self.url == OAUTH_SERVER_URL_SERVER_DENIED): |
| 312 | + statement_resp.headers = requests.structures.CaseInsensitiveDict([ |
| 313 | + ('Www-Authenticate', |
| 314 | + 'Bearer x_redirect_server=\"{OAUTH_REDIRECT_SERVER}\",' |
| 315 | + f'x_token_server=\"{OAUTH_SERVER_URL_DENY_ACCESS}\",' |
| 316 | + 'Basic realm=\"Trino\"')]) |
| 317 | + elif (self.url == OAUTH_SERVER_URL_SERVER_SUCCESS): |
| 318 | + statement_resp.status_code = 200 |
| 319 | + statement_resp.headers = requests.structures.CaseInsensitiveDict([ |
| 320 | + ('Www-Authenticate', |
| 321 | + f'Bearer x_redirect_server=\"{OAUTH_REDIRECT_SERVER}\",' |
| 322 | + f'x_token_server=\"{OAUTH_SERVER_URL_1}\",' |
| 323 | + 'Basic realm=\"Trino\"')]) |
| 324 | + else: |
| 325 | + statement_resp.headers = requests.structures.CaseInsensitiveDict([ |
| 326 | + ('Www-Authenticate', |
| 327 | + f'Bearer x_redirect_server=\"{OAUTH_REDIRECT_SERVER}\",' |
| 328 | + f'x_token_server=\"{OAUTH_SERVER_URL_1}\",' |
| 329 | + 'Basic realm=\"Trino\"')]) |
| 330 | + |
| 331 | + statement_resp.register_hook = mock.Mock(side_effect=self) |
| 332 | + statement_resp.url = self.url |
| 333 | + return statement_resp |
| 334 | + |
| 335 | + def get_token_url_response(self): |
| 336 | + token_resp = mock.Mock() |
| 337 | + token_resp.status_code = 200 |
| 338 | + |
| 339 | + # Success cases |
| 340 | + if self.url == OAUTH_SERVER_URL_1: |
| 341 | + token_resp.text = f'{{"nextUri":"{OAUTH_SERVER_URL_2}"}}' |
| 342 | + elif self.url == OAUTH_SERVER_URL_2: |
| 343 | + token_resp.text = f'{{"token":"{OAUTH_TEST_TOKEN}"}}' |
| 344 | + |
| 345 | + # Failure cases |
| 346 | + elif self.url == OAUTH_SERVER_URL_FORCE_FAIL: |
| 347 | + token_resp.status_code = 500 |
| 348 | + elif self.url == OAUTH_SERVER_URL_DENY_ACCESS: |
| 349 | + token_resp.text = OAUTH_DENY_ERROR_TEXT |
| 350 | + elif self.url == OAUTH_SERVER_URL_LOOP: |
| 351 | + token_resp.text = f'{{"nextUri":"{OAUTH_SERVER_URL_LOOP}"}}' |
| 352 | + |
| 353 | + return token_resp |
| 354 | + |
| 355 | + |
| 356 | +def call_response_hook(str, callback_func): |
| 357 | + statement_resp = mock.Mock() |
| 358 | + statement_resp.headers = requests.structures.CaseInsensitiveDict([ |
| 359 | + ('Www-Authenticate', |
| 360 | + f'Bearer x_redirect_server=\"{OAUTH_REDIRECT_SERVER}\",' |
| 361 | + f'x_token_server=\"{OAUTH_SERVER_URL_1}\",' |
| 362 | + 'Basic realm=\"Trino\"')]) |
| 363 | + statement_resp.status_code = 401 |
| 364 | + callback_func(statement_resp) |
| 365 | + |
| 366 | + |
| 367 | +@mock.patch("requests.Session.get") |
| 368 | +@mock.patch("requests.Session.post") |
| 369 | +def test_oauth2_authentication_flow(http_session_post, http_session_get, capsys): |
| 370 | + http_session = requests.Session() |
| 371 | + |
| 372 | + # set up the patched session, with the correct response |
| 373 | + oauth_test = OAuthTestReq("post", "http://coordinator/statement_url") |
| 374 | + http_session_post.return_value = oauth_test.get_statement_post_response() |
| 375 | + http_session_get.side_effect = oauth_test.get_token_url_response() |
| 376 | + oauth = _OAuth2TokenBearer(http_session, oauth2_test_url_handler) |
| 377 | + |
| 378 | + statement_req = oauth_test.get_statement_request() |
| 379 | + oauth(statement_req) |
| 380 | + |
| 381 | + oauth_test = OAuthTestReq("get", OAUTH_SERVER_URL_1) |
| 382 | + token_req = oauth_test.get_token_request() |
| 383 | + oauth(token_req) |
| 384 | + |
| 385 | + oauth_test = OAuthTestReq("get", OAUTH_SERVER_URL_2) |
| 386 | + token_req = oauth_test.get_token_request() |
| 387 | + oauth(token_req) |
| 388 | + |
| 389 | + # Finally resend the original request, and respond back with status code 200 |
| 390 | + oauth_test = OAuthTestReq("post", "http://coordinator/statement_url_suceess") |
| 391 | + # statement_req.register_hook = mock.Mock(side_effect=oauth_test) |
| 392 | + statement_req = oauth_test.get_statement_request() |
| 393 | + http_session_post.return_value = oauth_test.get_statement_post_response() |
| 394 | + oauth(statement_req) |
| 395 | + |
| 396 | + out, err = capsys.readouterr() |
| 397 | + assert out == OAUTH_REDIRECT_SERVER |
| 398 | + assert statement_req.headers['Authorization'] == "Bearer " + OAUTH_TEST_TOKEN |
| 399 | + |
| 400 | + |
| 401 | +@mock.patch("requests.Session.get") |
| 402 | +@mock.patch("requests.Session.post") |
| 403 | +def test_oauth2_exceed_max_attempts(http_session_post, http_session_get): |
| 404 | + http_session = requests.Session() |
| 405 | + |
| 406 | + # set up the patched session, with the correct response |
| 407 | + oauth_test = OAuthTestReq("post", "http://coordinator/statement_url") |
| 408 | + http_session_post.return_value = oauth_test.get_statement_post_response() |
| 409 | + http_session_get.side_effect = oauth_test.get_token_url_response() |
| 410 | + oauth = _OAuth2TokenBearer(http_session, oauth2_test_url_handler) |
| 411 | + |
| 412 | + statement_req = oauth_test.get_statement_request() |
| 413 | + oauth(statement_req) |
| 414 | + |
| 415 | + with pytest.raises(trino.exceptions.TrinoAuthError) as exp: |
| 416 | + for i in range(0, 5): |
| 417 | + oauth_test = OAuthTestReq("get", OAUTH_SERVER_URL_1) |
| 418 | + token_req = oauth_test.get_token_request() |
| 419 | + oauth(token_req) |
| 420 | + |
| 421 | + assert str(exp.value) == "Exceeded max attempts while getting the token" |
| 422 | + |
| 423 | + |
| 424 | +@mock.patch("requests.Session.post") |
| 425 | +def test_oauth2_authentication_missing_headers(http_session_post): |
| 426 | + http_session = requests.Session() |
| 427 | + oauth_test = OAuthTestReq("post", OAUTH_SERVER_URL_NO_HEADER) |
| 428 | + http_session_post.return_value = oauth_test.get_statement_post_response() |
| 429 | + oauth = _OAuth2TokenBearer(http_session, oauth2_test_url_handler) |
| 430 | + |
| 431 | + with pytest.raises(trino.exceptions.TrinoAuthError) as exp: |
| 432 | + statement_req = oauth_test.get_statement_request() |
| 433 | + oauth(statement_req) |
| 434 | + |
| 435 | + assert str(exp.value) == "Error: header WWW-Authenticate not available in the response." |
| 436 | + |
| 437 | + |
| 438 | +@mock.patch("requests.Session.get") |
| 439 | +@mock.patch("requests.Session.post") |
| 440 | +def test_oauth2_authentication_fail_token_server(http_session_post, http_session_get): |
| 441 | + http_session = requests.Session() |
| 442 | + oauth_test = OAuthTestReq("post", OAUTH_SERVER_URL_FAIL_SERVER) |
| 443 | + http_session_post.return_value = oauth_test.get_statement_post_response() |
| 444 | + oauth = _OAuth2TokenBearer(http_session, oauth2_test_url_handler) |
| 445 | + http_session_get.side_effect = oauth_test.get_token_url_response() |
| 446 | + |
| 447 | + statement_req = oauth_test.get_statement_request() |
| 448 | + oauth(statement_req) |
| 449 | + |
| 450 | + with pytest.raises(trino.exceptions.TrinoAuthError) as exp: |
| 451 | + oauth_test = OAuthTestReq("get", OAUTH_SERVER_URL_FORCE_FAIL) |
| 452 | + token_req = oauth_test.get_token_request() |
| 453 | + oauth(token_req) |
| 454 | + |
| 455 | + assert "Error while getting the token response status" in str(exp.value) |
| 456 | + |
| 457 | + |
| 458 | +@mock.patch("requests.Session.get") |
| 459 | +@mock.patch("requests.Session.post") |
| 460 | +def test_oauth2_authentication_access_denied(http_session_post, http_session_get): |
| 461 | + http_session = requests.Session() |
| 462 | + oauth_test = OAuthTestReq("post", OAUTH_SERVER_URL_SERVER_DENIED) |
| 463 | + http_session_post.return_value = oauth_test.get_statement_post_response() |
| 464 | + oauth = _OAuth2TokenBearer(http_session, oauth2_test_url_handler) |
| 465 | + http_session_get.side_effect = oauth_test.get_token_url_response() |
| 466 | + |
| 467 | + statement_req = oauth_test.get_statement_request() |
| 468 | + oauth(statement_req) |
| 469 | + |
| 470 | + with pytest.raises(trino.exceptions.TrinoAuthError) as exp: |
| 471 | + oauth_test = OAuthTestReq("get", OAUTH_SERVER_URL_FORCE_FAIL) |
| 472 | + token_req = oauth_test.get_token_request() |
| 473 | + oauth(token_req) |
| 474 | + |
| 475 | + assert "Error while getting the token" in str(exp.value) |
| 476 | + |
| 477 | + |
257 | 478 | @mock.patch("trino.client.TrinoRequest.http") |
258 | 479 | def test_trino_fetch_request(mock_requests, sample_get_response_data): |
259 | 480 | mock_requests.Response.return_value.json.return_value = sample_get_response_data |
|
0 commit comments