Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 8839b6c

Browse files
authored
Add requesting user id parameter to key claim methods in TransportLayerClient (#15663)
1 parent ca5c4be commit 8839b6c

File tree

6 files changed

+39
-11
lines changed

6 files changed

+39
-11
lines changed

changelog.d/15663.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add requesting user id parameter to key claim methods in `TransportLayerClient`.

synapse/federation/federation_client.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,13 +236,15 @@ async def query_user_devices(
236236

237237
async def claim_client_keys(
238238
self,
239+
user: UserID,
239240
destination: str,
240241
query: Dict[str, Dict[str, Dict[str, int]]],
241242
timeout: Optional[int],
242243
) -> JsonDict:
243244
"""Claims one-time keys for a device hosted on a remote server.
244245
245246
Args:
247+
user: The user id of the requesting user
246248
destination: Domain name of the remote homeserver
247249
content: The query content.
248250
@@ -279,7 +281,7 @@ async def claim_client_keys(
279281
if use_unstable:
280282
try:
281283
return await self.transport_layer.claim_client_keys_unstable(
282-
destination, unstable_content, timeout
284+
user, destination, unstable_content, timeout
283285
)
284286
except HttpResponseException as e:
285287
# If an error is received that is due to an unrecognised endpoint,
@@ -295,7 +297,7 @@ async def claim_client_keys(
295297
logger.debug("Skipping unstable claim client keys API")
296298

297299
return await self.transport_layer.claim_client_keys(
298-
destination, content, timeout
300+
user, destination, content, timeout
299301
)
300302

301303
@trace

synapse/federation/transport/client.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
from synapse.federation.units import Transaction
4646
from synapse.http.matrixfederationclient import ByteParser, LegacyJsonSendParser
4747
from synapse.http.types import QueryParams
48-
from synapse.types import JsonDict
48+
from synapse.types import JsonDict, UserID
4949
from synapse.util import ExceptionBundle
5050

5151
if TYPE_CHECKING:
@@ -630,7 +630,11 @@ async def query_user_devices(
630630
)
631631

632632
async def claim_client_keys(
633-
self, destination: str, query_content: JsonDict, timeout: Optional[int]
633+
self,
634+
user: UserID,
635+
destination: str,
636+
query_content: JsonDict,
637+
timeout: Optional[int],
634638
) -> JsonDict:
635639
"""Claim one-time keys for a list of devices hosted on a remote server.
636640
@@ -655,6 +659,7 @@ async def claim_client_keys(
655659
}
656660
657661
Args:
662+
user: the user_id of the requesting user
658663
destination: The server to query.
659664
query_content: The user ids to query.
660665
Returns:
@@ -671,7 +676,11 @@ async def claim_client_keys(
671676
)
672677

673678
async def claim_client_keys_unstable(
674-
self, destination: str, query_content: JsonDict, timeout: Optional[int]
679+
self,
680+
user: UserID,
681+
destination: str,
682+
query_content: JsonDict,
683+
timeout: Optional[int],
675684
) -> JsonDict:
676685
"""Claim one-time keys for a list of devices hosted on a remote server.
677686
@@ -696,6 +705,7 @@ async def claim_client_keys_unstable(
696705
}
697706
698707
Args:
708+
user: the user_id of the requesting user
699709
destination: The server to query.
700710
query_content: The user ids to query.
701711
Returns:

synapse/handlers/e2e_keys.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,7 @@ async def claim_local_one_time_keys(
661661
async def claim_one_time_keys(
662662
self,
663663
query: Dict[str, Dict[str, Dict[str, int]]],
664+
user: UserID,
664665
timeout: Optional[int],
665666
always_include_fallback_keys: bool,
666667
) -> JsonDict:
@@ -703,7 +704,7 @@ async def claim_client_keys(destination: str) -> None:
703704
device_keys = remote_queries[destination]
704705
try:
705706
remote_result = await self.federation.claim_client_keys(
706-
destination, device_keys, timeout=timeout
707+
user, destination, device_keys, timeout=timeout
707708
)
708709
for user_id, keys in remote_result["one_time_keys"].items():
709710
if user_id in device_keys:

synapse/rest/client/keys.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ def __init__(self, hs: "HomeServer"):
287287
self.e2e_keys_handler = hs.get_e2e_keys_handler()
288288

289289
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
290-
await self.auth.get_user_by_req(request, allow_guest=True)
290+
requester = await self.auth.get_user_by_req(request, allow_guest=True)
291291
timeout = parse_integer(request, "timeout", 10 * 1000)
292292
body = parse_json_object_from_request(request)
293293

@@ -298,7 +298,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
298298
query.setdefault(user_id, {})[device_id] = {algorithm: 1}
299299

300300
result = await self.e2e_keys_handler.claim_one_time_keys(
301-
query, timeout, always_include_fallback_keys=False
301+
query, requester.user, timeout, always_include_fallback_keys=False
302302
)
303303
return 200, result
304304

@@ -335,7 +335,7 @@ def __init__(self, hs: "HomeServer"):
335335
self.e2e_keys_handler = hs.get_e2e_keys_handler()
336336

337337
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
338-
await self.auth.get_user_by_req(request, allow_guest=True)
338+
requester = await self.auth.get_user_by_req(request, allow_guest=True)
339339
timeout = parse_integer(request, "timeout", 10 * 1000)
340340
body = parse_json_object_from_request(request)
341341

@@ -346,7 +346,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
346346
query.setdefault(user_id, {})[device_id] = Counter(algorithms)
347347

348348
result = await self.e2e_keys_handler.claim_one_time_keys(
349-
query, timeout, always_include_fallback_keys=True
349+
query, requester.user, timeout, always_include_fallback_keys=True
350350
)
351351
return 200, result
352352

tests/handlers/test_e2e_keys.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from synapse.handlers.device import DeviceHandler
2828
from synapse.server import HomeServer
2929
from synapse.storage.databases.main.appservice import _make_exclusive_regex
30-
from synapse.types import JsonDict
30+
from synapse.types import JsonDict, UserID
3131
from synapse.util import Clock
3232

3333
from tests import unittest
@@ -45,6 +45,7 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
4545
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
4646
self.handler = hs.get_e2e_keys_handler()
4747
self.store = self.hs.get_datastores().main
48+
self.requester = UserID.from_string(f"@test_requester:{self.hs.hostname}")
4849

4950
def test_query_local_devices_no_devices(self) -> None:
5051
"""If the user has no devices, we expect an empty list."""
@@ -161,6 +162,7 @@ def test_claim_one_time_key(self) -> None:
161162
res2 = self.get_success(
162163
self.handler.claim_one_time_keys(
163164
{local_user: {device_id: {"alg1": 1}}},
165+
self.requester,
164166
timeout=None,
165167
always_include_fallback_keys=False,
166168
)
@@ -206,6 +208,7 @@ def test_fallback_key(self) -> None:
206208
claim_res = self.get_success(
207209
self.handler.claim_one_time_keys(
208210
{local_user: {device_id: {"alg1": 1}}},
211+
self.requester,
209212
timeout=None,
210213
always_include_fallback_keys=False,
211214
)
@@ -225,6 +228,7 @@ def test_fallback_key(self) -> None:
225228
claim_res = self.get_success(
226229
self.handler.claim_one_time_keys(
227230
{local_user: {device_id: {"alg1": 1}}},
231+
self.requester,
228232
timeout=None,
229233
always_include_fallback_keys=False,
230234
)
@@ -274,6 +278,7 @@ def test_fallback_key(self) -> None:
274278
claim_res = self.get_success(
275279
self.handler.claim_one_time_keys(
276280
{local_user: {device_id: {"alg1": 1}}},
281+
self.requester,
277282
timeout=None,
278283
always_include_fallback_keys=False,
279284
)
@@ -286,6 +291,7 @@ def test_fallback_key(self) -> None:
286291
claim_res = self.get_success(
287292
self.handler.claim_one_time_keys(
288293
{local_user: {device_id: {"alg1": 1}}},
294+
self.requester,
289295
timeout=None,
290296
always_include_fallback_keys=False,
291297
)
@@ -307,6 +313,7 @@ def test_fallback_key(self) -> None:
307313
claim_res = self.get_success(
308314
self.handler.claim_one_time_keys(
309315
{local_user: {device_id: {"alg1": 1}}},
316+
self.requester,
310317
timeout=None,
311318
always_include_fallback_keys=False,
312319
)
@@ -348,6 +355,7 @@ def test_fallback_key_always_returned(self) -> None:
348355
claim_res = self.get_success(
349356
self.handler.claim_one_time_keys(
350357
{local_user: {device_id: {"alg1": 1}}},
358+
self.requester,
351359
timeout=None,
352360
always_include_fallback_keys=True,
353361
)
@@ -370,6 +378,7 @@ def test_fallback_key_always_returned(self) -> None:
370378
claim_res = self.get_success(
371379
self.handler.claim_one_time_keys(
372380
{local_user: {device_id: {"alg1": 1}}},
381+
self.requester,
373382
timeout=None,
374383
always_include_fallback_keys=True,
375384
)
@@ -1080,6 +1089,7 @@ def test_query_appservice(self) -> None:
10801089
claim_res = self.get_success(
10811090
self.handler.claim_one_time_keys(
10821091
{local_user: {device_id_1: {"alg1": 1}, device_id_2: {"alg1": 1}}},
1092+
self.requester,
10831093
timeout=None,
10841094
always_include_fallback_keys=False,
10851095
)
@@ -1125,6 +1135,7 @@ def test_query_appservice_with_fallback(self) -> None:
11251135
claim_res = self.get_success(
11261136
self.handler.claim_one_time_keys(
11271137
{local_user: {device_id_1: {"alg1": 1}}},
1138+
self.requester,
11281139
timeout=None,
11291140
always_include_fallback_keys=True,
11301141
)
@@ -1169,6 +1180,7 @@ def test_query_appservice_with_fallback(self) -> None:
11691180
claim_res = self.get_success(
11701181
self.handler.claim_one_time_keys(
11711182
{local_user: {device_id_1: {"alg1": 1}}},
1183+
self.requester,
11721184
timeout=None,
11731185
always_include_fallback_keys=True,
11741186
)
@@ -1202,6 +1214,7 @@ def test_query_appservice_with_fallback(self) -> None:
12021214
claim_res = self.get_success(
12031215
self.handler.claim_one_time_keys(
12041216
{local_user: {device_id_1: {"alg1": 1}}},
1217+
self.requester,
12051218
timeout=None,
12061219
always_include_fallback_keys=True,
12071220
)
@@ -1229,6 +1242,7 @@ def test_query_appservice_with_fallback(self) -> None:
12291242
claim_res = self.get_success(
12301243
self.handler.claim_one_time_keys(
12311244
{local_user: {device_id_1: {"alg1": 1}}},
1245+
self.requester,
12321246
timeout=None,
12331247
always_include_fallback_keys=True,
12341248
)

0 commit comments

Comments
 (0)