@@ -161,7 +161,7 @@ static THE_REGISTRY_SET: Once = Once::new();
161161/// Starts the worker threads (if that has not already happened). If
162162/// initialization has not already occurred, use the default
163163/// configuration.
164- pub ( super ) fn global_registry ( ) -> & ' static Arc < Registry > {
164+ fn global_registry ( ) -> & ' static Arc < Registry > {
165165 set_global_registry ( default_global_registry)
166166 . or_else ( |err| unsafe { THE_REGISTRY . as_ref ( ) . ok_or ( err) } )
167167 . expect ( "The global thread pool has not been initialized." )
@@ -217,6 +217,36 @@ fn default_global_registry() -> Result<Arc<Registry>, ThreadPoolBuildError> {
217217 result
218218}
219219
220+ // This is used to temporarily overwrite the current registry.
221+ //
222+ // This either null, a pointer to the global registry if it was
223+ // ever used to access the global registry or a pointer to a
224+ // registry which is temporarily made current because the current
225+ // thread is not a worker thread but is running a scope associated
226+ // to a specific thread pool.
227+ thread_local ! {
228+ static CURRENT_REGISTRY : Cell <* const Arc <Registry >> = const { Cell :: new( ptr:: null( ) ) } ;
229+ }
230+
231+ #[ cold]
232+ fn set_current_registry_to_global_registry ( ) -> * const Arc < Registry > {
233+ let global = global_registry ( ) ;
234+
235+ CURRENT_REGISTRY . with ( |current_registry| current_registry. set ( global) ) ;
236+
237+ global
238+ }
239+
240+ pub ( super ) fn current_registry ( ) -> * const Arc < Registry > {
241+ let mut current = CURRENT_REGISTRY . with ( Cell :: get) ;
242+
243+ if current. is_null ( ) {
244+ current = set_current_registry_to_global_registry ( ) ;
245+ }
246+
247+ current
248+ }
249+
220250struct Terminator < ' a > ( & ' a Arc < Registry > ) ;
221251
222252impl < ' a > Drop for Terminator < ' a > {
@@ -315,22 +345,55 @@ impl Registry {
315345 unsafe {
316346 let worker_thread = WorkerThread :: current ( ) ;
317347 let registry = if worker_thread. is_null ( ) {
318- global_registry ( )
348+ & * current_registry ( )
319349 } else {
320350 & ( * worker_thread) . registry
321351 } ;
322352 Arc :: clone ( registry)
323353 }
324354 }
325355
356+ /// Optionally install a specific registry as the current one.
357+ ///
358+ /// This is used when a thread which is not a worker executes
359+ /// a scope which should use the specific thread pool instead of
360+ /// the global one.
361+ pub ( super ) fn with_current < F , R > ( registry : Option < & Arc < Registry > > , f : F ) -> R
362+ where
363+ F : FnOnce ( ) -> R ,
364+ {
365+ struct Guard {
366+ current : * const Arc < Registry > ,
367+ }
368+
369+ impl Guard {
370+ fn new ( registry : & Arc < Registry > ) -> Self {
371+ let current =
372+ CURRENT_REGISTRY . with ( |current_registry| current_registry. replace ( registry) ) ;
373+
374+ Self { current }
375+ }
376+ }
377+
378+ impl Drop for Guard {
379+ fn drop ( & mut self ) {
380+ CURRENT_REGISTRY . with ( |current_registry| current_registry. set ( self . current ) ) ;
381+ }
382+ }
383+
384+ let _guard = registry. map ( Guard :: new) ;
385+
386+ f ( )
387+ }
388+
326389 /// Returns the number of threads in the current registry. This
327390 /// is better than `Registry::current().num_threads()` because it
328391 /// avoids incrementing the `Arc`.
329392 pub ( super ) fn current_num_threads ( ) -> usize {
330393 unsafe {
331394 let worker_thread = WorkerThread :: current ( ) ;
332395 if worker_thread. is_null ( ) {
333- global_registry ( ) . num_threads ( )
396+ ( * current_registry ( ) ) . num_threads ( )
334397 } else {
335398 ( * worker_thread) . registry . num_threads ( )
336399 }
@@ -946,7 +1009,7 @@ where
9461009 // invalidated until we return.
9471010 op ( & * owner_thread, false )
9481011 } else {
949- global_registry ( ) . in_worker ( op)
1012+ ( * current_registry ( ) ) . in_worker ( op)
9501013 }
9511014 }
9521015}
0 commit comments