Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 11846df

Browse files
authored
Limit the number of in-flight /keys/query requests from a single device. (#10144)
1 parent 1bf83a1 commit 11846df

File tree

4 files changed

+196
-173
lines changed

4 files changed

+196
-173
lines changed

changelog.d/10144.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Limit the number of in-flight `/keys/query` requests from a single device.

synapse/handlers/e2e_keys.py

Lines changed: 181 additions & 169 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,15 @@ def __init__(self, hs: "HomeServer"):
7979
"client_keys", self.on_federation_query_client_keys
8080
)
8181

82+
# Limit the number of in-flight requests from a single device.
83+
self._query_devices_linearizer = Linearizer(
84+
name="query_devices",
85+
max_count=10,
86+
)
87+
8288
@trace
8389
async def query_devices(
84-
self, query_body: JsonDict, timeout: int, from_user_id: str
90+
self, query_body: JsonDict, timeout: int, from_user_id: str, from_device_id: str
8591
) -> JsonDict:
8692
"""Handle a device key query from a client
8793
@@ -105,191 +111,197 @@ async def query_devices(
105111
from_user_id: the user making the query. This is used when
106112
adding cross-signing signatures to limit what signatures users
107113
can see.
114+
from_device_id: the device making the query. This is used to limit
115+
the number of in-flight queries at a time.
108116
"""
109-
110-
device_keys_query = query_body.get(
111-
"device_keys", {}
112-
) # type: Dict[str, Iterable[str]]
113-
114-
# separate users by domain.
115-
# make a map from domain to user_id to device_ids
116-
local_query = {}
117-
remote_queries = {}
118-
119-
for user_id, device_ids in device_keys_query.items():
120-
# we use UserID.from_string to catch invalid user ids
121-
if self.is_mine(UserID.from_string(user_id)):
122-
local_query[user_id] = device_ids
123-
else:
124-
remote_queries[user_id] = device_ids
125-
126-
set_tag("local_key_query", local_query)
127-
set_tag("remote_key_query", remote_queries)
128-
129-
# First get local devices.
130-
# A map of destination -> failure response.
131-
failures = {} # type: Dict[str, JsonDict]
132-
results = {}
133-
if local_query:
134-
local_result = await self.query_local_devices(local_query)
135-
for user_id, keys in local_result.items():
136-
if user_id in local_query:
137-
results[user_id] = keys
138-
139-
# Get cached cross-signing keys
140-
cross_signing_keys = await self.get_cross_signing_keys_from_cache(
141-
device_keys_query, from_user_id
142-
)
143-
144-
# Now attempt to get any remote devices from our local cache.
145-
# A map of destination -> user ID -> device IDs.
146-
remote_queries_not_in_cache = {} # type: Dict[str, Dict[str, Iterable[str]]]
147-
if remote_queries:
148-
query_list = [] # type: List[Tuple[str, Optional[str]]]
149-
for user_id, device_ids in remote_queries.items():
150-
if device_ids:
151-
query_list.extend((user_id, device_id) for device_id in device_ids)
117+
with await self._query_devices_linearizer.queue((from_user_id, from_device_id)):
118+
device_keys_query = query_body.get(
119+
"device_keys", {}
120+
) # type: Dict[str, Iterable[str]]
121+
122+
# separate users by domain.
123+
# make a map from domain to user_id to device_ids
124+
local_query = {}
125+
remote_queries = {}
126+
127+
for user_id, device_ids in device_keys_query.items():
128+
# we use UserID.from_string to catch invalid user ids
129+
if self.is_mine(UserID.from_string(user_id)):
130+
local_query[user_id] = device_ids
152131
else:
153-
query_list.append((user_id, None))
154-
155-
(
156-
user_ids_not_in_cache,
157-
remote_results,
158-
) = await self.store.get_user_devices_from_cache(query_list)
159-
for user_id, devices in remote_results.items():
160-
user_devices = results.setdefault(user_id, {})
161-
for device_id, device in devices.items():
162-
keys = device.get("keys", None)
163-
device_display_name = device.get("device_display_name", None)
164-
if keys:
165-
result = dict(keys)
166-
unsigned = result.setdefault("unsigned", {})
167-
if device_display_name:
168-
unsigned["device_display_name"] = device_display_name
169-
user_devices[device_id] = result
170-
171-
# check for missing cross-signing keys.
172-
for user_id in remote_queries.keys():
173-
cached_cross_master = user_id in cross_signing_keys["master_keys"]
174-
cached_cross_selfsigning = (
175-
user_id in cross_signing_keys["self_signing_keys"]
176-
)
177-
178-
# check if we are missing only one of cross-signing master or
179-
# self-signing key, but the other one is cached.
180-
# as we need both, this will issue a federation request.
181-
# if we don't have any of the keys, either the user doesn't have
182-
# cross-signing set up, or the cached device list
183-
# is not (yet) updated.
184-
if cached_cross_master ^ cached_cross_selfsigning:
185-
user_ids_not_in_cache.add(user_id)
186-
187-
# add those users to the list to fetch over federation.
188-
for user_id in user_ids_not_in_cache:
189-
domain = get_domain_from_id(user_id)
190-
r = remote_queries_not_in_cache.setdefault(domain, {})
191-
r[user_id] = remote_queries[user_id]
192-
193-
# Now fetch any devices that we don't have in our cache
194-
@trace
195-
async def do_remote_query(destination):
196-
"""This is called when we are querying the device list of a user on
197-
a remote homeserver and their device list is not in the device list
198-
cache. If we share a room with this user and we're not querying for
199-
specific user we will update the cache with their device list.
200-
"""
201-
202-
destination_query = remote_queries_not_in_cache[destination]
203-
204-
# We first consider whether we wish to update the device list cache with
205-
# the users device list. We want to track a user's devices when the
206-
# authenticated user shares a room with the queried user and the query
207-
# has not specified a particular device.
208-
# If we update the cache for the queried user we remove them from further
209-
# queries. We use the more efficient batched query_client_keys for all
210-
# remaining users
211-
user_ids_updated = []
212-
for (user_id, device_list) in destination_query.items():
213-
if user_id in user_ids_updated:
214-
continue
215-
216-
if device_list:
217-
continue
132+
remote_queries[user_id] = device_ids
133+
134+
set_tag("local_key_query", local_query)
135+
set_tag("remote_key_query", remote_queries)
136+
137+
# First get local devices.
138+
# A map of destination -> failure response.
139+
failures = {} # type: Dict[str, JsonDict]
140+
results = {}
141+
if local_query:
142+
local_result = await self.query_local_devices(local_query)
143+
for user_id, keys in local_result.items():
144+
if user_id in local_query:
145+
results[user_id] = keys
218146

219-
room_ids = await self.store.get_rooms_for_user(user_id)
220-
if not room_ids:
221-
continue
147+
# Get cached cross-signing keys
148+
cross_signing_keys = await self.get_cross_signing_keys_from_cache(
149+
device_keys_query, from_user_id
150+
)
222151

223-
# We've decided we're sharing a room with this user and should
224-
# probably be tracking their device lists. However, we haven't
225-
# done an initial sync on the device list so we do it now.
226-
try:
227-
if self._is_master:
228-
user_devices = await self.device_handler.device_list_updater.user_device_resync(
229-
user_id
152+
# Now attempt to get any remote devices from our local cache.
153+
# A map of destination -> user ID -> device IDs.
154+
remote_queries_not_in_cache = (
155+
{}
156+
) # type: Dict[str, Dict[str, Iterable[str]]]
157+
if remote_queries:
158+
query_list = [] # type: List[Tuple[str, Optional[str]]]
159+
for user_id, device_ids in remote_queries.items():
160+
if device_ids:
161+
query_list.extend(
162+
(user_id, device_id) for device_id in device_ids
230163
)
231164
else:
232-
user_devices = await self._user_device_resync_client(
233-
user_id=user_id
234-
)
235-
236-
user_devices = user_devices["devices"]
237-
user_results = results.setdefault(user_id, {})
238-
for device in user_devices:
239-
user_results[device["device_id"]] = device["keys"]
240-
user_ids_updated.append(user_id)
241-
except Exception as e:
242-
failures[destination] = _exception_to_failure(e)
243-
244-
if len(destination_query) == len(user_ids_updated):
245-
# We've updated all the users in the query and we do not need to
246-
# make any further remote calls.
247-
return
165+
query_list.append((user_id, None))
248166

249-
# Remove all the users from the query which we have updated
250-
for user_id in user_ids_updated:
251-
destination_query.pop(user_id)
167+
(
168+
user_ids_not_in_cache,
169+
remote_results,
170+
) = await self.store.get_user_devices_from_cache(query_list)
171+
for user_id, devices in remote_results.items():
172+
user_devices = results.setdefault(user_id, {})
173+
for device_id, device in devices.items():
174+
keys = device.get("keys", None)
175+
device_display_name = device.get("device_display_name", None)
176+
if keys:
177+
result = dict(keys)
178+
unsigned = result.setdefault("unsigned", {})
179+
if device_display_name:
180+
unsigned["device_display_name"] = device_display_name
181+
user_devices[device_id] = result
182+
183+
# check for missing cross-signing keys.
184+
for user_id in remote_queries.keys():
185+
cached_cross_master = user_id in cross_signing_keys["master_keys"]
186+
cached_cross_selfsigning = (
187+
user_id in cross_signing_keys["self_signing_keys"]
188+
)
252189

253-
try:
254-
remote_result = await self.federation.query_client_keys(
255-
destination, {"device_keys": destination_query}, timeout=timeout
256-
)
190+
# check if we are missing only one of cross-signing master or
191+
# self-signing key, but the other one is cached.
192+
# as we need both, this will issue a federation request.
193+
# if we don't have any of the keys, either the user doesn't have
194+
# cross-signing set up, or the cached device list
195+
# is not (yet) updated.
196+
if cached_cross_master ^ cached_cross_selfsigning:
197+
user_ids_not_in_cache.add(user_id)
198+
199+
# add those users to the list to fetch over federation.
200+
for user_id in user_ids_not_in_cache:
201+
domain = get_domain_from_id(user_id)
202+
r = remote_queries_not_in_cache.setdefault(domain, {})
203+
r[user_id] = remote_queries[user_id]
204+
205+
# Now fetch any devices that we don't have in our cache
206+
@trace
207+
async def do_remote_query(destination):
208+
"""This is called when we are querying the device list of a user on
209+
a remote homeserver and their device list is not in the device list
210+
cache. If we share a room with this user and we're not querying for
211+
specific user we will update the cache with their device list.
212+
"""
213+
214+
destination_query = remote_queries_not_in_cache[destination]
215+
216+
# We first consider whether we wish to update the device list cache with
217+
# the users device list. We want to track a user's devices when the
218+
# authenticated user shares a room with the queried user and the query
219+
# has not specified a particular device.
220+
# If we update the cache for the queried user we remove them from further
221+
# queries. We use the more efficient batched query_client_keys for all
222+
# remaining users
223+
user_ids_updated = []
224+
for (user_id, device_list) in destination_query.items():
225+
if user_id in user_ids_updated:
226+
continue
227+
228+
if device_list:
229+
continue
230+
231+
room_ids = await self.store.get_rooms_for_user(user_id)
232+
if not room_ids:
233+
continue
234+
235+
# We've decided we're sharing a room with this user and should
236+
# probably be tracking their device lists. However, we haven't
237+
# done an initial sync on the device list so we do it now.
238+
try:
239+
if self._is_master:
240+
user_devices = await self.device_handler.device_list_updater.user_device_resync(
241+
user_id
242+
)
243+
else:
244+
user_devices = await self._user_device_resync_client(
245+
user_id=user_id
246+
)
247+
248+
user_devices = user_devices["devices"]
249+
user_results = results.setdefault(user_id, {})
250+
for device in user_devices:
251+
user_results[device["device_id"]] = device["keys"]
252+
user_ids_updated.append(user_id)
253+
except Exception as e:
254+
failures[destination] = _exception_to_failure(e)
255+
256+
if len(destination_query) == len(user_ids_updated):
257+
# We've updated all the users in the query and we do not need to
258+
# make any further remote calls.
259+
return
260+
261+
# Remove all the users from the query which we have updated
262+
for user_id in user_ids_updated:
263+
destination_query.pop(user_id)
257264

258-
for user_id, keys in remote_result["device_keys"].items():
259-
if user_id in destination_query:
260-
results[user_id] = keys
265+
try:
266+
remote_result = await self.federation.query_client_keys(
267+
destination, {"device_keys": destination_query}, timeout=timeout
268+
)
261269

262-
if "master_keys" in remote_result:
263-
for user_id, key in remote_result["master_keys"].items():
270+
for user_id, keys in remote_result["device_keys"].items():
264271
if user_id in destination_query:
265-
cross_signing_keys["master_keys"][user_id] = key
272+
results[user_id] = keys
266273

267-
if "self_signing_keys" in remote_result:
268-
for user_id, key in remote_result["self_signing_keys"].items():
269-
if user_id in destination_query:
270-
cross_signing_keys["self_signing_keys"][user_id] = key
274+
if "master_keys" in remote_result:
275+
for user_id, key in remote_result["master_keys"].items():
276+
if user_id in destination_query:
277+
cross_signing_keys["master_keys"][user_id] = key
271278

272-
except Exception as e:
273-
failure = _exception_to_failure(e)
274-
failures[destination] = failure
275-
set_tag("error", True)
276-
set_tag("reason", failure)
279+
if "self_signing_keys" in remote_result:
280+
for user_id, key in remote_result["self_signing_keys"].items():
281+
if user_id in destination_query:
282+
cross_signing_keys["self_signing_keys"][user_id] = key
277283

278-
await make_deferred_yieldable(
279-
defer.gatherResults(
280-
[
281-
run_in_background(do_remote_query, destination)
282-
for destination in remote_queries_not_in_cache
283-
],
284-
consumeErrors=True,
285-
).addErrback(unwrapFirstError)
286-
)
284+
except Exception as e:
285+
failure = _exception_to_failure(e)
286+
failures[destination] = failure
287+
set_tag("error", True)
288+
set_tag("reason", failure)
289+
290+
await make_deferred_yieldable(
291+
defer.gatherResults(
292+
[
293+
run_in_background(do_remote_query, destination)
294+
for destination in remote_queries_not_in_cache
295+
],
296+
consumeErrors=True,
297+
).addErrback(unwrapFirstError)
298+
)
287299

288-
ret = {"device_keys": results, "failures": failures}
300+
ret = {"device_keys": results, "failures": failures}
289301

290-
ret.update(cross_signing_keys)
302+
ret.update(cross_signing_keys)
291303

292-
return ret
304+
return ret
293305

294306
async def get_cross_signing_keys_from_cache(
295307
self, query: Iterable[str], from_user_id: Optional[str]

synapse/rest/client/v2_alpha/keys.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,12 @@ def __init__(self, hs):
160160
async def on_POST(self, request):
161161
requester = await self.auth.get_user_by_req(request, allow_guest=True)
162162
user_id = requester.user.to_string()
163+
device_id = requester.device_id
163164
timeout = parse_integer(request, "timeout", 10 * 1000)
164165
body = parse_json_object_from_request(request)
165-
result = await self.e2e_keys_handler.query_devices(body, timeout, user_id)
166+
result = await self.e2e_keys_handler.query_devices(
167+
body, timeout, user_id, device_id
168+
)
166169
return 200, result
167170

168171

0 commit comments

Comments
 (0)