Skip to content

Commit 57cc96e

Browse files
authored
feat: Add optional non blocking refresh for sync auth code (#1368)
feat: Add optional non blocking refresh for sync auth code
1 parent 0cb62ef commit 57cc96e

File tree

12 files changed

+485
-8
lines changed

12 files changed

+485
-8
lines changed
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import copy
16+
import logging
17+
import threading
18+
19+
import google.auth.exceptions as e
20+
21+
_LOGGER = logging.getLogger(__name__)
22+
23+
24+
class RefreshThreadManager:
25+
"""
26+
Organizes exactly one background job that refresh a token.
27+
"""
28+
29+
def __init__(self):
30+
"""Initializes the manager."""
31+
32+
self._worker = None
33+
self._lock = threading.Lock() # protects access to worker threads.
34+
35+
def start_refresh(self, cred, request):
36+
"""Starts a refresh thread for the given credentials.
37+
The credentials are refreshed using the request parameter.
38+
request and cred MUST not be None
39+
40+
Returns True if a background refresh was kicked off. False otherwise.
41+
42+
Args:
43+
cred: A credentials object.
44+
request: A request object.
45+
Returns:
46+
bool
47+
"""
48+
if cred is None or request is None:
49+
raise e.InvalidValue(
50+
"Unable to start refresh. cred and request must be valid and instantiated objects."
51+
)
52+
53+
with self._lock:
54+
if self._worker is not None and self._worker._error_info is not None:
55+
return False
56+
57+
if self._worker is None or not self._worker.is_alive(): # pragma: NO COVER
58+
self._worker = RefreshThread(cred=cred, request=copy.deepcopy(request))
59+
self._worker.start()
60+
return True
61+
62+
def clear_error(self):
63+
"""
64+
Removes any errors that were stored from previous background refreshes.
65+
"""
66+
with self._lock:
67+
if self._worker:
68+
self._worker._error_info = None
69+
70+
71+
class RefreshThread(threading.Thread):
72+
"""
73+
Thread that refreshes credentials.
74+
"""
75+
76+
def __init__(self, cred, request, **kwargs):
77+
"""Initializes the thread.
78+
79+
Args:
80+
cred: A Credential object to refresh.
81+
request: A Request object used to perform a credential refresh.
82+
**kwargs: Additional keyword arguments.
83+
"""
84+
85+
super().__init__(**kwargs)
86+
self._cred = cred
87+
self._request = request
88+
self._error_info = None
89+
90+
def run(self):
91+
"""
92+
Perform the credential refresh.
93+
"""
94+
try:
95+
self._cred.refresh(self._request)
96+
except Exception as err: # pragma: NO COVER
97+
_LOGGER.error(f"Background refresh failed due to: {err}")
98+
self._error_info = err

packages/google-auth/google/auth/credentials.py

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616
"""Interfaces for credentials."""
1717

1818
import abc
19+
from enum import Enum
1920
import os
2021

2122
from google.auth import _helpers, environment_vars
2223
from google.auth import exceptions
2324
from google.auth import metrics
25+
from google.auth._refresh_worker import RefreshThreadManager
2426

2527

2628
class Credentials(metaclass=abc.ABCMeta):
@@ -59,17 +61,22 @@ def __init__(self):
5961
"""Optional[str]: The universe domain value, default is googleapis.com
6062
"""
6163

64+
self._use_non_blocking_refresh = False
65+
self._refresh_worker = RefreshThreadManager()
66+
6267
@property
6368
def expired(self):
6469
"""Checks if the credentials are expired.
6570
6671
Note that credentials can be invalid but not expired because
6772
Credentials with :attr:`expiry` set to None is considered to never
6873
expire.
74+
75+
.. deprecated:: v2.24.0
76+
Prefer checking :attr:`token_state` instead.
6977
"""
7078
if not self.expiry:
7179
return False
72-
7380
# Remove some threshold from expiry to err on the side of reporting
7481
# expiration early so that we avoid the 401-refresh-retry loop.
7582
skewed_expiry = self.expiry - _helpers.REFRESH_THRESHOLD
@@ -81,9 +88,34 @@ def valid(self):
8188
8289
This is True if the credentials have a :attr:`token` and the token
8390
is not :attr:`expired`.
91+
92+
.. deprecated:: v2.24.0
93+
Prefer checking :attr:`token_state` instead.
8494
"""
8595
return self.token is not None and not self.expired
8696

97+
@property
98+
def token_state(self):
99+
"""
100+
See `:obj:`TokenState`
101+
"""
102+
if self.token is None:
103+
return TokenState.INVALID
104+
105+
# Credentials that can't expire are always treated as fresh.
106+
if self.expiry is None:
107+
return TokenState.FRESH
108+
109+
expired = _helpers.utcnow() >= self.expiry
110+
if expired:
111+
return TokenState.INVALID
112+
113+
is_stale = _helpers.utcnow() >= (self.expiry - _helpers.REFRESH_THRESHOLD)
114+
if is_stale:
115+
return TokenState.STALE
116+
117+
return TokenState.FRESH
118+
87119
@property
88120
def quota_project_id(self):
89121
"""Project to use for quota and billing purposes."""
@@ -154,6 +186,25 @@ def apply(self, headers, token=None):
154186
if self.quota_project_id:
155187
headers["x-goog-user-project"] = self.quota_project_id
156188

189+
def _blocking_refresh(self, request):
190+
if not self.valid:
191+
self.refresh(request)
192+
193+
def _non_blocking_refresh(self, request):
194+
use_blocking_refresh_fallback = False
195+
196+
if self.token_state == TokenState.STALE:
197+
use_blocking_refresh_fallback = not self._refresh_worker.start_refresh(
198+
self, request
199+
)
200+
201+
if self.token_state == TokenState.INVALID or use_blocking_refresh_fallback:
202+
self.refresh(request)
203+
# If the blocking refresh succeeds then we can clear the error info
204+
# on the background refresh worker, and perform refreshes in a
205+
# background thread.
206+
self._refresh_worker.clear_error()
207+
157208
def before_request(self, request, method, url, headers):
158209
"""Performs credential-specific before request logic.
159210
@@ -171,11 +222,17 @@ def before_request(self, request, method, url, headers):
171222
# pylint: disable=unused-argument
172223
# (Subclasses may use these arguments to ascertain information about
173224
# the http request.)
174-
if not self.valid:
175-
self.refresh(request)
225+
if self._use_non_blocking_refresh:
226+
self._non_blocking_refresh(request)
227+
else:
228+
self._blocking_refresh(request)
229+
176230
metrics.add_metric_header(headers, self._metric_header_for_usage())
177231
self.apply(headers)
178232

233+
def with_non_blocking_refresh(self):
234+
self._use_non_blocking_refresh = True
235+
179236

180237
class CredentialsWithQuotaProject(Credentials):
181238
"""Abstract base for credentials supporting ``with_quota_project`` factory"""
@@ -439,3 +496,16 @@ def signer(self):
439496
# pylint: disable=missing-raises-doc
440497
# (pylint doesn't recognize that this is abstract)
441498
raise NotImplementedError("Signer must be implemented.")
499+
500+
501+
class TokenState(Enum):
502+
"""
503+
Tracks the state of a token.
504+
FRESH: The token is valid. It is not expired or close to expired, or the token has no expiry.
505+
STALE: The token is close to expired, and should be refreshed. The token can be used normally.
506+
INVALID: The token is expired or invalid. The token cannot be used for a normal operation.
507+
"""
508+
509+
FRESH = 1
510+
STALE = 2
511+
INVALID = 3

packages/google-auth/google/auth/impersonated_credentials.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,10 @@ def _update_token(self, request):
259259
"""
260260

261261
# Refresh our source credentials if it is not valid.
262-
if not self._source_credentials.valid:
262+
if (
263+
self._source_credentials.token_state == credentials.TokenState.STALE
264+
or self._source_credentials.token_state == credentials.TokenState.INVALID
265+
):
263266
self._source_credentials.refresh(request)
264267

265268
body = {

packages/google-auth/google/oauth2/credentials.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,8 @@ def __getstate__(self):
161161
# because they need to be importable.
162162
# Instead, the refresh_handler setter should be used to repopulate this.
163163
del state_dict["_refresh_handler"]
164+
# Remove worker as it contains multiproccessing queue objects.
165+
del state_dict["_refresh_worker"]
164166
return state_dict
165167

166168
def __setstate__(self, d):
@@ -183,6 +185,8 @@ def __setstate__(self, d):
183185
self._universe_domain = d.get("_universe_domain") or _DEFAULT_UNIVERSE_DOMAIN
184186
# The refresh_handler setter should be used to repopulate this.
185187
self._refresh_handler = None
188+
self._refresh_worker = None
189+
self._use_non_blocking_refresh = d.get("_use_non_blocking_refresh")
186190

187191
@property
188192
def refresh_token(self):
0 Bytes
Binary file not shown.

packages/google-auth/tests/oauth2/test_credentials.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from google.auth import _helpers
2525
from google.auth import exceptions
2626
from google.auth import transport
27+
from google.auth.credentials import TokenState
2728
from google.oauth2 import credentials
2829

2930

@@ -61,6 +62,7 @@ def test_default_state(self):
6162
assert not credentials.expired
6263
# Scopes aren't required for these credentials
6364
assert not credentials.requires_scopes
65+
assert credentials.token_state == TokenState.INVALID
6466
# Test properties
6567
assert credentials.refresh_token == self.REFRESH_TOKEN
6668
assert credentials.token_uri == self.TOKEN_URI
@@ -911,7 +913,11 @@ def test_pickle_and_unpickle(self):
911913
assert list(creds.__dict__).sort() == list(unpickled.__dict__).sort()
912914

913915
for attr in list(creds.__dict__):
914-
assert getattr(creds, attr) == getattr(unpickled, attr)
916+
# Worker should always be None
917+
if attr == "_refresh_worker":
918+
assert getattr(unpickled, attr) is None
919+
else:
920+
assert getattr(creds, attr) == getattr(unpickled, attr)
915921

916922
def test_pickle_and_unpickle_universe_domain(self):
917923
# old version of auth lib doesn't have _universe_domain, so the pickled
@@ -945,7 +951,7 @@ def test_pickle_and_unpickle_with_refresh_handler(self):
945951
for attr in list(creds.__dict__):
946952
# For the _refresh_handler property, the unpickled creds should be
947953
# set to None.
948-
if attr == "_refresh_handler":
954+
if attr == "_refresh_handler" or attr == "_refresh_worker":
949955
assert getattr(unpickled, attr) is None
950956
else:
951957
assert getattr(creds, attr) == getattr(unpickled, attr)

0 commit comments

Comments
 (0)