Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/18231.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add an access token introspection cache to make Matrix Authentication Service integration (MSC3861) more efficient.
113 changes: 96 additions & 17 deletions synapse/api/auth/msc3861_delegated.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#
#
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
from urllib.parse import urlencode

Expand Down Expand Up @@ -47,6 +48,7 @@
from synapse.types import Requester, UserID, create_requester
from synapse.util import json_decoder
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
from synapse.util.caches.response_cache import ResponseCache

if TYPE_CHECKING:
from synapse.rest.admin.experimental_features import ExperimentalFeature
Expand Down Expand Up @@ -76,6 +78,61 @@ def scope_to_list(scope: str) -> List[str]:
return scope.strip().split(" ")


@dataclass
class IntrospectionResult:
    """An introspection token together with the time it was retrieved.

    Offers typed accessors over the wrapped token so that callers do not
    have to validate the raw claim values themselves.
    """

    # The wrapped introspection token whose claims the accessors below read.
    _inner: IntrospectionToken

    # Time at which this token was retrieved,
    # in milliseconds since the Unix epoch.
    retrieved_at_ms: int

    def _optional_str(self, claim: str) -> Optional[str]:
        """Return the value of `claim` if it is a string, otherwise None."""
        candidate = self._inner.get(claim)
        return candidate if isinstance(candidate, str) else None

    def is_active(self, now_ms: int) -> bool:
        """Decide whether the token is still usable at time `now_ms`.

        Args:
            now_ms: the current time, in milliseconds since the Unix epoch.

        Returns:
            False if the `active` claim is missing or falsy. Otherwise True
            when there is no `expires_in` claim, or when `now_ms` is strictly
            before `retrieved_at_ms` plus `expires_in` seconds.

        Raises:
            InvalidClientTokenError: if the token is marked active but its
                `expires_in` claim is present and is not an int.
        """
        if self._inner.get("active"):
            lifetime_s = self._inner.get("expires_in")
            if lifetime_s is None:
                # No expiry was supplied: rely on the `active` claim alone.
                return True
            if isinstance(lifetime_s, int):
                # `expires_in` is in seconds, relative to the retrieval time.
                return now_ms < self.retrieved_at_ms + lifetime_s * 1000
            raise InvalidClientTokenError("token `expires_in` is not an int")
        return False

    def get_scope_list(self) -> List[str]:
        """Return the `scope` claim as a list, or an empty list if it is not a string."""
        raw_scope = self._optional_str("scope")
        return [] if raw_scope is None else scope_to_list(raw_scope)

    def get_sub(self) -> Optional[str]:
        """Return the `sub` claim, or None if it is absent or not a string."""
        return self._optional_str("sub")

    def get_username(self) -> Optional[str]:
        """Return the `username` claim, or None if it is absent or not a string."""
        return self._optional_str("username")

    def get_name(self) -> Optional[str]:
        """Return the `name` claim, or None if it is absent or not a string."""
        return self._optional_str("name")

    def get_device_id(self) -> Optional[str]:
        """Return the `device_id` claim, or None if it is absent.

        Raises:
            AuthError: (500) if the claim is present but is not a string.
        """
        device_id = self._inner.get("device_id")
        if device_id is None or isinstance(device_id, str):
            return device_id
        raise AuthError(
            500,
            "Invalid device ID in introspection result",
        )


class PrivateKeyJWTWithKid(PrivateKeyJWT): # type: ignore[misc]
"""An implementation of the private_key_jwt client auth method that includes a kid header.

Expand Down Expand Up @@ -121,6 +178,31 @@ def __init__(self, hs: "HomeServer"):
self._hostname = hs.hostname
self._admin_token: Callable[[], Optional[str]] = self._config.admin_token

# # Token Introspection Cache
# This remembers what users/devices are represented by which access tokens,
# in order to reduce overall system load:
# - on Synapse (as requests are relatively expensive)
# - on the network
# - on MAS
#
# Since there is no invalidation mechanism currently,
# the entries expire after 2 minutes.
# This does mean tokens can be treated as valid by Synapse
# for longer than reality.
#
# Ideally, tokens should logically be invalidated in the following circumstances:
# - If a session logout happens.
# In this case, MAS will delete the device within Synapse
# anyway and this is good enough as an invalidation.
# - If the client refreshes their token in MAS.
# In this case, the device still exists and it's not the end of the world for
# the old access token to continue working for a short time.
self._introspection_cache: ResponseCache[str] = ResponseCache(
self._clock,
"token_introspection",
timeout_ms=120_000,
)

self._issuer_metadata = RetryOnExceptionCachedCall[OpenIDProviderMetadata](
self._load_metadata
)
Expand Down Expand Up @@ -193,7 +275,7 @@ async def _introspection_endpoint(self) -> str:
metadata = await self._issuer_metadata.get()
return metadata.get("introspection_endpoint")

async def _introspect_token(self, token: str) -> IntrospectionToken:
async def _introspect_token(self, token: str) -> IntrospectionResult:
"""
Send a token to the introspection endpoint and return the introspection response

Expand Down Expand Up @@ -266,7 +348,9 @@ async def _introspect_token(self, token: str) -> IntrospectionToken:
"The introspection endpoint returned an invalid JSON response."
)

return IntrospectionToken(**resp)
return IntrospectionResult(
IntrospectionToken(**resp), retrieved_at_ms=self._clock.time_msec()
)

async def is_server_admin(self, requester: Requester) -> bool:
    """Report whether the requester's scope contains the Synapse admin scope."""
    admin_scope = "urn:synapse:admin:*"
    return admin_scope in requester.scope
Expand Down Expand Up @@ -344,7 +428,9 @@ async def get_user_by_access_token(
)

try:
introspection_result = await self._introspect_token(token)
introspection_result = await self._introspection_cache.wrap(
token, self._introspect_token, token
)
except Exception:
logger.exception("Failed to introspect token")
raise SynapseError(503, "Unable to introspect the access token")
Expand All @@ -353,11 +439,11 @@ async def get_user_by_access_token(

# TODO: introspection verification should be more extensive, especially:
# - verify the audience
if not introspection_result.get("active"):
if not introspection_result.is_active(self._clock.time_msec()):
raise InvalidClientTokenError("Token is not active")

# Let's look at the scope
scope: List[str] = scope_to_list(introspection_result.get("scope", ""))
scope: List[str] = introspection_result.get_scope_list()

# Determine type of user based on presence of particular scopes
has_user_scope = SCOPE_MATRIX_API in scope
Expand All @@ -367,7 +453,7 @@ async def get_user_by_access_token(
raise InvalidClientTokenError("No scope in token granting user rights")

# Match via the sub claim
sub: Optional[str] = introspection_result.get("sub")
sub: Optional[str] = introspection_result.get_sub()
if sub is None:
raise InvalidClientTokenError(
"Invalid sub claim in the introspection result"
Expand All @@ -381,7 +467,7 @@ async def get_user_by_access_token(
# or the external_id was never recorded

# TODO: claim mapping should be configurable
username: Optional[str] = introspection_result.get("username")
username: Optional[str] = introspection_result.get_username()
if username is None or not isinstance(username, str):
raise AuthError(
500,
Expand All @@ -399,7 +485,7 @@ async def get_user_by_access_token(

# TODO: claim mapping should be configurable
# If present, use the name claim as the displayname
name: Optional[str] = introspection_result.get("name")
name: Optional[str] = introspection_result.get_name()

await self.store.register_user(
user_id=user_id.to_string(), create_profile_with_displayname=name
Expand All @@ -414,15 +500,8 @@ async def get_user_by_access_token(

# MAS 0.15+ will give us the device ID as an explicit value for compatibility sessions
# If present, we get it from here; if not, we get it from the scope
device_id = introspection_result.get("device_id")
if device_id is not None:
# We got the device ID explicitly, just sanity check that it's a string
if not isinstance(device_id, str):
raise AuthError(
500,
"Invalid device ID in introspection result",
)
else:
device_id = introspection_result.get_device_id()
if device_id is None:
# Find device_ids in scope
# We only allow a single device_id in the scope, so we find them all in the
# scope list, and raise if there are more than one. The OIDC server should be
Expand Down
38 changes: 38 additions & 0 deletions tests/handlers/test_oauth_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,44 @@ def test_unavailable_introspection_endpoint(self) -> None:
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
self.assertEqual(error.value.code, 503)

def test_cached_expired_introspection(self) -> None:
    """A cached introspection response that carries an expiry must be rejected
    once that expiry has passed, without a new introspection request being sent."""

    # Introspection payload for an active token that expires after 60 seconds.
    granted_scope = " ".join(
        [
            MATRIX_USER_SCOPE,
            f"{MATRIX_DEVICE_SCOPE_PREFIX}AABBCC",
        ]
    )
    introspection_payload = {
        "active": True,
        "sub": SUBJECT,
        "scope": granted_scope,
        "username": USERNAME,
        "expires_in": 60,
    }
    mocked_introspection = AsyncMock(
        return_value=FakeResponse.json(code=200, payload=introspection_payload)
    )
    self.http_client.request = mocked_introspection

    request = Mock(args={b"access_token": [b"mockAccessToken"]})
    request.requestHeaders.getRawHeaders = mock_getRawHeaders()

    # First CS-API request: introspection succeeds and is performed exactly once.
    self.get_success(self.auth.get_user_by_req(request))
    self.assertEqual(mocked_introspection.call_count, 1)

    # Let 60 seconds pass so that the token's lifetime runs out.
    self.reactor.advance(60.0)

    # Second CS-API request: rejected because the token has expired...
    self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
    # ...and no additional introspection request was made.
    self.assertEqual(mocked_introspection.call_count, 1)

def make_device_keys(self, user_id: str, device_id: str) -> JsonDict:
# We only generate a master key to simplify the test.
master_signing_key = generate_signing_key(device_id)
Expand Down
Loading