Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit bd4919f

Browse files
authored
MSC2918 Refresh tokens implementation (#9450)
This implements refresh tokens, as defined by MSC2918 This MSC has been implemented client side in Hydrogen Web: element-hq/hydrogen-web#235 The basics of the MSC works: requesting refresh tokens on login, having the access tokens expire, and using the refresh token to get a new one. Signed-off-by: Quentin Gliech <[email protected]>
1 parent 763dba7 commit bd4919f

File tree

15 files changed

+892
-61
lines changed

15 files changed

+892
-61
lines changed

changelog.d/9450.feature

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Implement refresh tokens as specified by [MSC2918](https://github.com/matrix-org/matrix-doc/pull/2918).

scripts/synapse_port_db

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ BOOLEAN_COLUMNS = {
9393
"local_media_repository": ["safe_from_quarantine"],
9494
"users": ["shadow_banned"],
9595
"e2e_fallback_keys_json": ["used"],
96+
"access_tokens": ["used"],
9697
}
9798

9899

@@ -307,7 +308,8 @@ class Porter(object):
307308
information_schema.table_constraints AS tc
308309
INNER JOIN information_schema.constraint_column_usage AS ccu
309310
USING (table_schema, constraint_name)
310-
WHERE tc.constraint_type = 'FOREIGN KEY';
311+
WHERE tc.constraint_type = 'FOREIGN KEY'
312+
AND tc.table_name != ccu.table_name;
311313
"""
312314
txn.execute(sql)
313315

synapse/api/auth.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,11 @@ async def get_user_by_req(
245245
errcode=Codes.GUEST_ACCESS_FORBIDDEN,
246246
)
247247

248+
# Mark the token as used. This is used to invalidate old refresh
249+
# tokens after some time.
250+
if not user_info.token_used and token_id is not None:
251+
await self.store.mark_access_token_as_used(token_id)
252+
248253
requester = create_requester(
249254
user_info.user_id,
250255
token_id,

synapse/config/registration.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,27 @@ def read_config(self, config, **kwargs):
119119
session_lifetime = self.parse_duration(session_lifetime)
120120
self.session_lifetime = session_lifetime
121121

122+
# The `access_token_lifetime` applies for tokens that can be renewed
123+
# using a refresh token, as per MSC2918. If it is `None`, the refresh
124+
# token mechanism is disabled.
125+
#
126+
# Since it is incompatible with the `session_lifetime` mechanism, it is set to
127+
# `None` by default if a `session_lifetime` is set.
128+
access_token_lifetime = config.get(
129+
"access_token_lifetime", "5m" if session_lifetime is None else None
130+
)
131+
if access_token_lifetime is not None:
132+
access_token_lifetime = self.parse_duration(access_token_lifetime)
133+
self.access_token_lifetime = access_token_lifetime
134+
135+
if session_lifetime is not None and access_token_lifetime is not None:
136+
raise ConfigError(
137+
"The refresh token mechanism is incompatible with the "
138+
"`session_lifetime` option. Consider disabling the "
139+
"`session_lifetime` option or disabling the refresh token "
140+
"mechanism by removing the `access_token_lifetime` option."
141+
)
142+
122143
# The success template used during fallback auth.
123144
self.fallback_success_template = self.read_template("auth_success.html")
124145

synapse/handlers/auth.py

Lines changed: 127 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
Optional,
3131
Tuple,
3232
Union,
33+
cast,
3334
)
3435

3536
import attr
@@ -72,6 +73,7 @@
7273
from synapse.util.threepids import canonicalise_email
7374

7475
if TYPE_CHECKING:
76+
from synapse.rest.client.v1.login import LoginResponse
7577
from synapse.server import HomeServer
7678

7779
logger = logging.getLogger(__name__)
@@ -777,13 +779,116 @@ def _auth_dict_for_flows(
777779
"params": params,
778780
}
779781

782+
async def refresh_token(
783+
self,
784+
refresh_token: str,
785+
valid_until_ms: Optional[int],
786+
) -> Tuple[str, str]:
787+
"""
788+
Consumes a refresh token and generate both a new access token and a new refresh token from it.
789+
790+
The consumed refresh token is considered invalid after the first use of the new access token or the new refresh token.
791+
792+
Args:
793+
refresh_token: The token to consume.
794+
valid_until_ms: The expiration timestamp of the new access token.
795+
796+
Returns:
797+
A tuple containing the new access token and refresh token
798+
"""
799+
800+
# Verify the token signature first before looking up the token
801+
if not self._verify_refresh_token(refresh_token):
802+
raise SynapseError(401, "invalid refresh token", Codes.UNKNOWN_TOKEN)
803+
804+
existing_token = await self.store.lookup_refresh_token(refresh_token)
805+
if existing_token is None:
806+
raise SynapseError(401, "refresh token does not exist", Codes.UNKNOWN_TOKEN)
807+
808+
if (
809+
existing_token.has_next_access_token_been_used
810+
or existing_token.has_next_refresh_token_been_refreshed
811+
):
812+
raise SynapseError(
813+
403, "refresh token isn't valid anymore", Codes.FORBIDDEN
814+
)
815+
816+
(
817+
new_refresh_token,
818+
new_refresh_token_id,
819+
) = await self.get_refresh_token_for_user_id(
820+
user_id=existing_token.user_id, device_id=existing_token.device_id
821+
)
822+
access_token = await self.get_access_token_for_user_id(
823+
user_id=existing_token.user_id,
824+
device_id=existing_token.device_id,
825+
valid_until_ms=valid_until_ms,
826+
refresh_token_id=new_refresh_token_id,
827+
)
828+
await self.store.replace_refresh_token(
829+
existing_token.token_id, new_refresh_token_id
830+
)
831+
return access_token, new_refresh_token
832+
833+
def _verify_refresh_token(self, token: str) -> bool:
834+
"""
835+
Verifies the shape of a refresh token.
836+
837+
Args:
838+
token: The refresh token to verify
839+
840+
Returns:
841+
Whether the token has the right shape
842+
"""
843+
parts = token.split("_", maxsplit=4)
844+
if len(parts) != 4:
845+
return False
846+
847+
type, localpart, rand, crc = parts
848+
849+
# Refresh tokens are prefixed by "syr_", let's check that
850+
if type != "syr":
851+
return False
852+
853+
# Check the CRC
854+
base = f"{type}_{localpart}_{rand}"
855+
expected_crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
856+
if crc != expected_crc:
857+
return False
858+
859+
return True
860+
861+
async def get_refresh_token_for_user_id(
862+
self,
863+
user_id: str,
864+
device_id: str,
865+
) -> Tuple[str, int]:
866+
"""
867+
Creates a new refresh token for the user with the given user ID.
868+
869+
Args:
870+
user_id: canonical user ID
871+
device_id: the device ID to associate with the token.
872+
873+
Returns:
874+
The newly created refresh token and its ID in the database
875+
"""
876+
refresh_token = self.generate_refresh_token(UserID.from_string(user_id))
877+
refresh_token_id = await self.store.add_refresh_token_to_user(
878+
user_id=user_id,
879+
token=refresh_token,
880+
device_id=device_id,
881+
)
882+
return refresh_token, refresh_token_id
883+
780884
async def get_access_token_for_user_id(
781885
self,
782886
user_id: str,
783887
device_id: Optional[str],
784888
valid_until_ms: Optional[int],
785889
puppets_user_id: Optional[str] = None,
786890
is_appservice_ghost: bool = False,
891+
refresh_token_id: Optional[int] = None,
787892
) -> str:
788893
"""
789894
Creates a new access token for the user with the given user ID.
@@ -801,6 +906,8 @@ async def get_access_token_for_user_id(
801906
valid_until_ms: when the token is valid until. None for
802907
no expiry.
803908
is_appservice_ghost: Whether the user is an application ghost user
909+
refresh_token_id: the refresh token ID that will be associated with
910+
this access token.
804911
Returns:
805912
The access token for the user's session.
806913
Raises:
@@ -836,6 +943,7 @@ async def get_access_token_for_user_id(
836943
device_id=device_id,
837944
valid_until_ms=valid_until_ms,
838945
puppets_user_id=puppets_user_id,
946+
refresh_token_id=refresh_token_id,
839947
)
840948

841949
# the device *should* have been registered before we got here; however,
@@ -928,7 +1036,7 @@ async def validate_login(
9281036
self,
9291037
login_submission: Dict[str, Any],
9301038
ratelimit: bool = False,
931-
) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
1039+
) -> Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
9321040
"""Authenticates the user for the /login API
9331041
9341042
Also used by the user-interactive auth flow to validate auth types which don't
@@ -1073,7 +1181,7 @@ async def _validate_userid_login(
10731181
self,
10741182
username: str,
10751183
login_submission: Dict[str, Any],
1076-
) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
1184+
) -> Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
10771185
"""Helper for validate_login
10781186
10791187
Handles login, once we've mapped 3pids onto userids
@@ -1151,7 +1259,7 @@ async def _validate_userid_login(
11511259

11521260
async def check_password_provider_3pid(
11531261
self, medium: str, address: str, password: str
1154-
) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
1262+
) -> Tuple[Optional[str], Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
11551263
"""Check if a password provider is able to validate a thirdparty login
11561264
11571265
Args:
@@ -1215,6 +1323,19 @@ def generate_access_token(self, for_user: UserID) -> str:
12151323
crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
12161324
return f"{base}_{crc}"
12171325

1326+
def generate_refresh_token(self, for_user: UserID) -> str:
1327+
"""Generates an opaque string, for use as a refresh token"""
1328+
1329+
# we use the following format for refresh tokens:
1330+
# syr_<base64 local part>_<random string>_<base62 crc check>
1331+
1332+
b64local = unpaddedbase64.encode_base64(for_user.localpart.encode("utf-8"))
1333+
random_string = stringutils.random_string(20)
1334+
base = f"syr_{b64local}_{random_string}"
1335+
1336+
crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
1337+
return f"{base}_{crc}"
1338+
12181339
async def validate_short_term_login_token(
12191340
self, login_token: str
12201341
) -> LoginTokenAttributes:
@@ -1563,7 +1684,7 @@ def _complete_sso_login(
15631684
)
15641685
respond_with_html(request, 200, html)
15651686

1566-
async def _sso_login_callback(self, login_result: JsonDict) -> None:
1687+
async def _sso_login_callback(self, login_result: "LoginResponse") -> None:
15671688
"""
15681689
A login callback which might add additional attributes to the login response.
15691690
@@ -1577,7 +1698,8 @@ async def _sso_login_callback(self, login_result: JsonDict) -> None:
15771698

15781699
extra_attributes = self._extra_attributes.get(login_result["user_id"])
15791700
if extra_attributes:
1580-
login_result.update(extra_attributes.extra_attributes)
1701+
login_result_dict = cast(Dict[str, Any], login_result)
1702+
login_result_dict.update(extra_attributes.extra_attributes)
15811703

15821704
def _expire_sso_extra_attributes(self) -> None:
15831705
"""

0 commit comments

Comments
 (0)