@@ -79,9 +79,15 @@ def __init__(self, hs: "HomeServer"):
79
79
"client_keys" , self .on_federation_query_client_keys
80
80
)
81
81
82
+ # Limit the number of in-flight requests from a single device.
83
+ self ._query_devices_linearizer = Linearizer (
84
+ name = "query_devices" ,
85
+ max_count = 10 ,
86
+ )
87
+
82
88
@trace
83
89
async def query_devices (
84
- self , query_body : JsonDict , timeout : int , from_user_id : str
90
+ self , query_body : JsonDict , timeout : int , from_user_id : str , from_device_id : str
85
91
) -> JsonDict :
86
92
"""Handle a device key query from a client
87
93
@@ -105,191 +111,197 @@ async def query_devices(
105
111
from_user_id: the user making the query. This is used when
106
112
adding cross-signing signatures to limit what signatures users
107
113
can see.
114
+ from_device_id: the device making the query. This is used to limit
115
+ the number of in-flight queries at a time.
108
116
"""
109
-
110
- device_keys_query = query_body .get (
111
- "device_keys" , {}
112
- ) # type: Dict[str, Iterable[str]]
113
-
114
- # separate users by domain.
115
- # make a map from domain to user_id to device_ids
116
- local_query = {}
117
- remote_queries = {}
118
-
119
- for user_id , device_ids in device_keys_query .items ():
120
- # we use UserID.from_string to catch invalid user ids
121
- if self .is_mine (UserID .from_string (user_id )):
122
- local_query [user_id ] = device_ids
123
- else :
124
- remote_queries [user_id ] = device_ids
125
-
126
- set_tag ("local_key_query" , local_query )
127
- set_tag ("remote_key_query" , remote_queries )
128
-
129
- # First get local devices.
130
- # A map of destination -> failure response.
131
- failures = {} # type: Dict[str, JsonDict]
132
- results = {}
133
- if local_query :
134
- local_result = await self .query_local_devices (local_query )
135
- for user_id , keys in local_result .items ():
136
- if user_id in local_query :
137
- results [user_id ] = keys
138
-
139
- # Get cached cross-signing keys
140
- cross_signing_keys = await self .get_cross_signing_keys_from_cache (
141
- device_keys_query , from_user_id
142
- )
143
-
144
- # Now attempt to get any remote devices from our local cache.
145
- # A map of destination -> user ID -> device IDs.
146
- remote_queries_not_in_cache = {} # type: Dict[str, Dict[str, Iterable[str]]]
147
- if remote_queries :
148
- query_list = [] # type: List[Tuple[str, Optional[str]]]
149
- for user_id , device_ids in remote_queries .items ():
150
- if device_ids :
151
- query_list .extend ((user_id , device_id ) for device_id in device_ids )
117
+ with await self ._query_devices_linearizer .queue ((from_user_id , from_device_id )):
118
+ device_keys_query = query_body .get (
119
+ "device_keys" , {}
120
+ ) # type: Dict[str, Iterable[str]]
121
+
122
+ # separate users by domain.
123
+ # make a map from domain to user_id to device_ids
124
+ local_query = {}
125
+ remote_queries = {}
126
+
127
+ for user_id , device_ids in device_keys_query .items ():
128
+ # we use UserID.from_string to catch invalid user ids
129
+ if self .is_mine (UserID .from_string (user_id )):
130
+ local_query [user_id ] = device_ids
152
131
else :
153
- query_list .append ((user_id , None ))
154
-
155
- (
156
- user_ids_not_in_cache ,
157
- remote_results ,
158
- ) = await self .store .get_user_devices_from_cache (query_list )
159
- for user_id , devices in remote_results .items ():
160
- user_devices = results .setdefault (user_id , {})
161
- for device_id , device in devices .items ():
162
- keys = device .get ("keys" , None )
163
- device_display_name = device .get ("device_display_name" , None )
164
- if keys :
165
- result = dict (keys )
166
- unsigned = result .setdefault ("unsigned" , {})
167
- if device_display_name :
168
- unsigned ["device_display_name" ] = device_display_name
169
- user_devices [device_id ] = result
170
-
171
- # check for missing cross-signing keys.
172
- for user_id in remote_queries .keys ():
173
- cached_cross_master = user_id in cross_signing_keys ["master_keys" ]
174
- cached_cross_selfsigning = (
175
- user_id in cross_signing_keys ["self_signing_keys" ]
176
- )
177
-
178
- # check if we are missing only one of cross-signing master or
179
- # self-signing key, but the other one is cached.
180
- # as we need both, this will issue a federation request.
181
- # if we don't have any of the keys, either the user doesn't have
182
- # cross-signing set up, or the cached device list
183
- # is not (yet) updated.
184
- if cached_cross_master ^ cached_cross_selfsigning :
185
- user_ids_not_in_cache .add (user_id )
186
-
187
- # add those users to the list to fetch over federation.
188
- for user_id in user_ids_not_in_cache :
189
- domain = get_domain_from_id (user_id )
190
- r = remote_queries_not_in_cache .setdefault (domain , {})
191
- r [user_id ] = remote_queries [user_id ]
192
-
193
- # Now fetch any devices that we don't have in our cache
194
- @trace
195
- async def do_remote_query (destination ):
196
- """This is called when we are querying the device list of a user on
197
- a remote homeserver and their device list is not in the device list
198
- cache. If we share a room with this user and we're not querying for
199
- specific user we will update the cache with their device list.
200
- """
201
-
202
- destination_query = remote_queries_not_in_cache [destination ]
203
-
204
- # We first consider whether we wish to update the device list cache with
205
- # the users device list. We want to track a user's devices when the
206
- # authenticated user shares a room with the queried user and the query
207
- # has not specified a particular device.
208
- # If we update the cache for the queried user we remove them from further
209
- # queries. We use the more efficient batched query_client_keys for all
210
- # remaining users
211
- user_ids_updated = []
212
- for (user_id , device_list ) in destination_query .items ():
213
- if user_id in user_ids_updated :
214
- continue
215
-
216
- if device_list :
217
- continue
132
+ remote_queries [user_id ] = device_ids
133
+
134
+ set_tag ("local_key_query" , local_query )
135
+ set_tag ("remote_key_query" , remote_queries )
136
+
137
+ # First get local devices.
138
+ # A map of destination -> failure response.
139
+ failures = {} # type: Dict[str, JsonDict]
140
+ results = {}
141
+ if local_query :
142
+ local_result = await self .query_local_devices (local_query )
143
+ for user_id , keys in local_result .items ():
144
+ if user_id in local_query :
145
+ results [user_id ] = keys
218
146
219
- room_ids = await self .store .get_rooms_for_user (user_id )
220
- if not room_ids :
221
- continue
147
+ # Get cached cross-signing keys
148
+ cross_signing_keys = await self .get_cross_signing_keys_from_cache (
149
+ device_keys_query , from_user_id
150
+ )
222
151
223
- # We've decided we're sharing a room with this user and should
224
- # probably be tracking their device lists. However, we haven't
225
- # done an initial sync on the device list so we do it now.
226
- try :
227
- if self ._is_master :
228
- user_devices = await self .device_handler .device_list_updater .user_device_resync (
229
- user_id
152
+ # Now attempt to get any remote devices from our local cache.
153
+ # A map of destination -> user ID -> device IDs.
154
+ remote_queries_not_in_cache = (
155
+ {}
156
+ ) # type: Dict[str, Dict[str, Iterable[str]]]
157
+ if remote_queries :
158
+ query_list = [] # type: List[Tuple[str, Optional[str]]]
159
+ for user_id , device_ids in remote_queries .items ():
160
+ if device_ids :
161
+ query_list .extend (
162
+ (user_id , device_id ) for device_id in device_ids
230
163
)
231
164
else :
232
- user_devices = await self ._user_device_resync_client (
233
- user_id = user_id
234
- )
235
-
236
- user_devices = user_devices ["devices" ]
237
- user_results = results .setdefault (user_id , {})
238
- for device in user_devices :
239
- user_results [device ["device_id" ]] = device ["keys" ]
240
- user_ids_updated .append (user_id )
241
- except Exception as e :
242
- failures [destination ] = _exception_to_failure (e )
243
-
244
- if len (destination_query ) == len (user_ids_updated ):
245
- # We've updated all the users in the query and we do not need to
246
- # make any further remote calls.
247
- return
165
+ query_list .append ((user_id , None ))
248
166
249
- # Remove all the users from the query which we have updated
250
- for user_id in user_ids_updated :
251
- destination_query .pop (user_id )
167
+ (
168
+ user_ids_not_in_cache ,
169
+ remote_results ,
170
+ ) = await self .store .get_user_devices_from_cache (query_list )
171
+ for user_id , devices in remote_results .items ():
172
+ user_devices = results .setdefault (user_id , {})
173
+ for device_id , device in devices .items ():
174
+ keys = device .get ("keys" , None )
175
+ device_display_name = device .get ("device_display_name" , None )
176
+ if keys :
177
+ result = dict (keys )
178
+ unsigned = result .setdefault ("unsigned" , {})
179
+ if device_display_name :
180
+ unsigned ["device_display_name" ] = device_display_name
181
+ user_devices [device_id ] = result
182
+
183
+ # check for missing cross-signing keys.
184
+ for user_id in remote_queries .keys ():
185
+ cached_cross_master = user_id in cross_signing_keys ["master_keys" ]
186
+ cached_cross_selfsigning = (
187
+ user_id in cross_signing_keys ["self_signing_keys" ]
188
+ )
252
189
253
- try :
254
- remote_result = await self .federation .query_client_keys (
255
- destination , {"device_keys" : destination_query }, timeout = timeout
256
- )
190
+ # check if we are missing only one of cross-signing master or
191
+ # self-signing key, but the other one is cached.
192
+ # as we need both, this will issue a federation request.
193
+ # if we don't have any of the keys, either the user doesn't have
194
+ # cross-signing set up, or the cached device list
195
+ # is not (yet) updated.
196
+ if cached_cross_master ^ cached_cross_selfsigning :
197
+ user_ids_not_in_cache .add (user_id )
198
+
199
+ # add those users to the list to fetch over federation.
200
+ for user_id in user_ids_not_in_cache :
201
+ domain = get_domain_from_id (user_id )
202
+ r = remote_queries_not_in_cache .setdefault (domain , {})
203
+ r [user_id ] = remote_queries [user_id ]
204
+
205
+ # Now fetch any devices that we don't have in our cache
206
+ @trace
207
+ async def do_remote_query (destination ):
208
+ """This is called when we are querying the device list of a user on
209
+ a remote homeserver and their device list is not in the device list
210
+ cache. If we share a room with this user and we're not querying for
211
+ specific user we will update the cache with their device list.
212
+ """
213
+
214
+ destination_query = remote_queries_not_in_cache [destination ]
215
+
216
+ # We first consider whether we wish to update the device list cache with
217
+ # the users device list. We want to track a user's devices when the
218
+ # authenticated user shares a room with the queried user and the query
219
+ # has not specified a particular device.
220
+ # If we update the cache for the queried user we remove them from further
221
+ # queries. We use the more efficient batched query_client_keys for all
222
+ # remaining users
223
+ user_ids_updated = []
224
+ for (user_id , device_list ) in destination_query .items ():
225
+ if user_id in user_ids_updated :
226
+ continue
227
+
228
+ if device_list :
229
+ continue
230
+
231
+ room_ids = await self .store .get_rooms_for_user (user_id )
232
+ if not room_ids :
233
+ continue
234
+
235
+ # We've decided we're sharing a room with this user and should
236
+ # probably be tracking their device lists. However, we haven't
237
+ # done an initial sync on the device list so we do it now.
238
+ try :
239
+ if self ._is_master :
240
+ user_devices = await self .device_handler .device_list_updater .user_device_resync (
241
+ user_id
242
+ )
243
+ else :
244
+ user_devices = await self ._user_device_resync_client (
245
+ user_id = user_id
246
+ )
247
+
248
+ user_devices = user_devices ["devices" ]
249
+ user_results = results .setdefault (user_id , {})
250
+ for device in user_devices :
251
+ user_results [device ["device_id" ]] = device ["keys" ]
252
+ user_ids_updated .append (user_id )
253
+ except Exception as e :
254
+ failures [destination ] = _exception_to_failure (e )
255
+
256
+ if len (destination_query ) == len (user_ids_updated ):
257
+ # We've updated all the users in the query and we do not need to
258
+ # make any further remote calls.
259
+ return
260
+
261
+ # Remove all the users from the query which we have updated
262
+ for user_id in user_ids_updated :
263
+ destination_query .pop (user_id )
257
264
258
- for user_id , keys in remote_result ["device_keys" ].items ():
259
- if user_id in destination_query :
260
- results [user_id ] = keys
265
+ try :
266
+ remote_result = await self .federation .query_client_keys (
267
+ destination , {"device_keys" : destination_query }, timeout = timeout
268
+ )
261
269
262
- if "master_keys" in remote_result :
263
- for user_id , key in remote_result ["master_keys" ].items ():
270
+ for user_id , keys in remote_result ["device_keys" ].items ():
264
271
if user_id in destination_query :
265
- cross_signing_keys [ "master_keys" ][ user_id ] = key
272
+ results [ user_id ] = keys
266
273
267
- if "self_signing_keys " in remote_result :
268
- for user_id , key in remote_result ["self_signing_keys " ].items ():
269
- if user_id in destination_query :
270
- cross_signing_keys ["self_signing_keys " ][user_id ] = key
274
+ if "master_keys " in remote_result :
275
+ for user_id , key in remote_result ["master_keys " ].items ():
276
+ if user_id in destination_query :
277
+ cross_signing_keys ["master_keys " ][user_id ] = key
271
278
272
- except Exception as e :
273
- failure = _exception_to_failure (e )
274
- failures [destination ] = failure
275
- set_tag ("error" , True )
276
- set_tag ("reason" , failure )
279
+ if "self_signing_keys" in remote_result :
280
+ for user_id , key in remote_result ["self_signing_keys" ].items ():
281
+ if user_id in destination_query :
282
+ cross_signing_keys ["self_signing_keys" ][user_id ] = key
277
283
278
- await make_deferred_yieldable (
279
- defer .gatherResults (
280
- [
281
- run_in_background (do_remote_query , destination )
282
- for destination in remote_queries_not_in_cache
283
- ],
284
- consumeErrors = True ,
285
- ).addErrback (unwrapFirstError )
286
- )
284
+ except Exception as e :
285
+ failure = _exception_to_failure (e )
286
+ failures [destination ] = failure
287
+ set_tag ("error" , True )
288
+ set_tag ("reason" , failure )
289
+
290
+ await make_deferred_yieldable (
291
+ defer .gatherResults (
292
+ [
293
+ run_in_background (do_remote_query , destination )
294
+ for destination in remote_queries_not_in_cache
295
+ ],
296
+ consumeErrors = True ,
297
+ ).addErrback (unwrapFirstError )
298
+ )
287
299
288
- ret = {"device_keys" : results , "failures" : failures }
300
+ ret = {"device_keys" : results , "failures" : failures }
289
301
290
- ret .update (cross_signing_keys )
302
+ ret .update (cross_signing_keys )
291
303
292
- return ret
304
+ return ret
293
305
294
306
async def get_cross_signing_keys_from_cache (
295
307
self , query : Iterable [str ], from_user_id : Optional [str ]
0 commit comments