Skip to content

Commit dca4007

Browse files
authored
Merge pull request #29 from Kestrer/raw-iter
2 parents 7ee722e + 33ad405 commit dca4007

File tree

1 file changed

+82
-102
lines changed

1 file changed

+82
-102
lines changed

src/lib.rs

Lines changed: 82 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ pub use cached::{CachedIntoIter, CachedIterMut, CachedThreadLocal};
7676
use std::cell::UnsafeCell;
7777
use std::fmt;
7878
use std::iter::FusedIterator;
79-
use std::marker::PhantomData;
8079
use std::mem;
8180
use std::mem::MaybeUninit;
8281
use std::panic::UnwindSafe;
@@ -274,20 +273,7 @@ impl<T: Send> ThreadLocal<T> {
274273
{
275274
Iter {
276275
thread_local: self,
277-
yielded: 0,
278-
bucket: 0,
279-
bucket_size: 1,
280-
index: 0,
281-
}
282-
}
283-
284-
fn raw_iter_mut(&mut self) -> RawIterMut<T> {
285-
RawIterMut {
286-
remaining: *self.values.get_mut(),
287-
buckets: unsafe { *(&self.buckets as *const _ as *const [*mut Entry<T>; BUCKETS]) },
288-
bucket: 0,
289-
bucket_size: 1,
290-
index: 0,
276+
raw: RawIter::new(),
291277
}
292278
}
293279

@@ -299,8 +285,8 @@ impl<T: Send> ThreadLocal<T> {
299285
/// threads are currently accessing their associated values.
300286
pub fn iter_mut(&mut self) -> IterMut<T> {
301287
IterMut {
302-
raw: self.raw_iter_mut(),
303-
marker: PhantomData,
288+
thread_local: self,
289+
raw: RawIter::new(),
304290
}
305291
}
306292

@@ -319,10 +305,10 @@ impl<T: Send> IntoIterator for ThreadLocal<T> {
319305
type Item = T;
320306
type IntoIter = IntoIter<T>;
321307

322-
fn into_iter(mut self) -> IntoIter<T> {
308+
fn into_iter(self) -> IntoIter<T> {
323309
IntoIter {
324-
raw: self.raw_iter_mut(),
325-
_thread_local: self,
310+
thread_local: self,
311+
raw: RawIter::new(),
326312
}
327313
}
328314
}
@@ -361,22 +347,27 @@ impl<T: Send + fmt::Debug> fmt::Debug for ThreadLocal<T> {
361347

362348
impl<T: Send + UnwindSafe> UnwindSafe for ThreadLocal<T> {}
363349

364-
/// Iterator over the contents of a `ThreadLocal`.
365350
#[derive(Debug)]
366-
pub struct Iter<'a, T: Send + Sync> {
367-
thread_local: &'a ThreadLocal<T>,
351+
struct RawIter {
368352
yielded: usize,
369353
bucket: usize,
370354
bucket_size: usize,
371355
index: usize,
372356
}
357+
impl RawIter {
358+
#[inline]
359+
fn new() -> Self {
360+
Self {
361+
yielded: 0,
362+
bucket: 0,
363+
bucket_size: 1,
364+
index: 0,
365+
}
366+
}
373367

374-
impl<'a, T: Send + Sync> Iterator for Iter<'a, T> {
375-
type Item = &'a T;
376-
377-
fn next(&mut self) -> Option<Self::Item> {
368+
fn next<'a, T: Send + Sync>(&mut self, thread_local: &'a ThreadLocal<T>) -> Option<&'a T> {
378369
while self.bucket < BUCKETS {
379-
let bucket = unsafe { self.thread_local.buckets.get_unchecked(self.bucket) };
370+
let bucket = unsafe { thread_local.buckets.get_unchecked(self.bucket) };
380371
let bucket = bucket.load(Ordering::Relaxed);
381372

382373
if !bucket.is_null() {
@@ -390,140 +381,129 @@ impl<'a, T: Send + Sync> Iterator for Iter<'a, T> {
390381
}
391382
}
392383

393-
if self.bucket != 0 {
394-
self.bucket_size <<= 1;
395-
}
396-
self.bucket += 1;
397-
398-
self.index = 0;
384+
self.next_bucket();
399385
}
400386
None
401387
}
402-
403-
fn size_hint(&self) -> (usize, Option<usize>) {
404-
let total = self.thread_local.values.load(Ordering::Acquire);
405-
(total - self.yielded, None)
406-
}
407-
}
408-
impl<T: Send + Sync> FusedIterator for Iter<'_, T> {}
409-
410-
struct RawIterMut<T: Send> {
411-
remaining: usize,
412-
buckets: [*mut Entry<T>; BUCKETS],
413-
bucket: usize,
414-
bucket_size: usize,
415-
index: usize,
416-
}
417-
418-
impl<T: Send> Iterator for RawIterMut<T> {
419-
type Item = *mut MaybeUninit<T>;
420-
421-
fn next(&mut self) -> Option<Self::Item> {
422-
if self.remaining == 0 {
388+
fn next_mut<'a, T: Send>(
389+
&mut self,
390+
thread_local: &'a mut ThreadLocal<T>,
391+
) -> Option<&'a mut Entry<T>> {
392+
if *thread_local.values.get_mut() == self.yielded {
423393
return None;
424394
}
425395

426396
loop {
427-
let bucket = unsafe { *self.buckets.get_unchecked(self.bucket) };
397+
let bucket = unsafe { thread_local.buckets.get_unchecked_mut(self.bucket) };
398+
let bucket = *bucket.get_mut();
428399

429400
if !bucket.is_null() {
430401
while self.index < self.bucket_size {
431402
let entry = unsafe { &mut *bucket.add(self.index) };
432403
self.index += 1;
433404
if *entry.present.get_mut() {
434-
self.remaining -= 1;
435-
return Some(entry.value.get());
405+
self.yielded += 1;
406+
return Some(entry);
436407
}
437408
}
438409
}
439410

440-
if self.bucket != 0 {
441-
self.bucket_size <<= 1;
442-
}
443-
self.bucket += 1;
411+
self.next_bucket();
412+
}
413+
}
444414

445-
self.index = 0;
415+
#[inline]
416+
fn next_bucket(&mut self) {
417+
if self.bucket != 0 {
418+
self.bucket_size <<= 1;
446419
}
420+
self.bucket += 1;
421+
self.index = 0;
447422
}
448423

449-
fn size_hint(&self) -> (usize, Option<usize>) {
450-
(self.remaining, Some(self.remaining))
424+
fn size_hint<T: Send>(&self, thread_local: &ThreadLocal<T>) -> (usize, Option<usize>) {
425+
let total = thread_local.values.load(Ordering::Acquire);
426+
(total - self.yielded, None)
427+
}
428+
fn size_hint_frozen<T: Send>(&self, thread_local: &ThreadLocal<T>) -> (usize, Option<usize>) {
429+
let total = unsafe { *(&thread_local.values as *const AtomicUsize as *const usize) };
430+
let remaining = total - self.yielded;
431+
(remaining, Some(remaining))
451432
}
452433
}
453434

454-
unsafe impl<T: Send> Send for RawIterMut<T> {}
455-
unsafe impl<T: Send + Sync> Sync for RawIterMut<T> {}
435+
/// Iterator over the contents of a `ThreadLocal`.
436+
#[derive(Debug)]
437+
pub struct Iter<'a, T: Send + Sync> {
438+
thread_local: &'a ThreadLocal<T>,
439+
raw: RawIter,
440+
}
441+
442+
impl<'a, T: Send + Sync> Iterator for Iter<'a, T> {
443+
type Item = &'a T;
444+
fn next(&mut self) -> Option<Self::Item> {
445+
self.raw.next(self.thread_local)
446+
}
447+
fn size_hint(&self) -> (usize, Option<usize>) {
448+
self.raw.size_hint(self.thread_local)
449+
}
450+
}
451+
impl<T: Send + Sync> FusedIterator for Iter<'_, T> {}
456452

457453
/// Mutable iterator over the contents of a `ThreadLocal`.
458454
pub struct IterMut<'a, T: Send> {
459-
raw: RawIterMut<T>,
460-
marker: PhantomData<&'a mut ThreadLocal<T>>,
455+
thread_local: &'a mut ThreadLocal<T>,
456+
raw: RawIter,
461457
}
462458

463459
impl<'a, T: Send> Iterator for IterMut<'a, T> {
464460
type Item = &'a mut T;
465-
466461
fn next(&mut self) -> Option<&'a mut T> {
467462
self.raw
468-
.next()
469-
.map(|x| unsafe { &mut *(&mut *x).as_mut_ptr() })
463+
.next_mut(self.thread_local)
464+
.map(|entry| unsafe { &mut *(&mut *entry.value.get()).as_mut_ptr() })
470465
}
471-
472466
fn size_hint(&self) -> (usize, Option<usize>) {
473-
self.raw.size_hint()
467+
self.raw.size_hint_frozen(self.thread_local)
474468
}
475469
}
476470

477471
impl<T: Send> ExactSizeIterator for IterMut<'_, T> {}
478472
impl<T: Send> FusedIterator for IterMut<'_, T> {}
479473

480-
// The Debug bound is technically unnecessary but makes the API more consistent and future-proof.
481-
impl<T: Send + fmt::Debug> fmt::Debug for IterMut<'_, T> {
474+
// Manual impl so we don't call Debug on the ThreadLocal, as doing so would create a reference to
475+
// this thread's value that potentially aliases with a mutable reference we have given out.
476+
impl<'a, T: Send + fmt::Debug> fmt::Debug for IterMut<'a, T> {
482477
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
483-
f.debug_struct("IterMut")
484-
.field("remaining", &self.raw.remaining)
485-
.field("bucket", &self.raw.bucket)
486-
.field("bucket_size", &self.raw.bucket_size)
487-
.field("index", &self.raw.index)
488-
.finish()
478+
f.debug_struct("IterMut").field("raw", &self.raw).finish()
489479
}
490480
}
491481

492482
/// An iterator that moves out of a `ThreadLocal`.
483+
#[derive(Debug)]
493484
pub struct IntoIter<T: Send> {
494-
raw: RawIterMut<T>,
495-
_thread_local: ThreadLocal<T>,
485+
thread_local: ThreadLocal<T>,
486+
raw: RawIter,
496487
}
497488

498489
impl<T: Send> Iterator for IntoIter<T> {
499490
type Item = T;
500-
501491
fn next(&mut self) -> Option<T> {
502-
self.raw
503-
.next()
504-
.map(|x| unsafe { std::mem::replace(&mut *x, MaybeUninit::uninit()).assume_init() })
492+
self.raw.next_mut(&mut self.thread_local).map(|entry| {
493+
*entry.present.get_mut() = false;
494+
unsafe {
495+
std::mem::replace(&mut *entry.value.get(), MaybeUninit::uninit()).assume_init()
496+
}
497+
})
505498
}
506-
507499
fn size_hint(&self) -> (usize, Option<usize>) {
508-
self.raw.size_hint()
500+
self.raw.size_hint_frozen(&self.thread_local)
509501
}
510502
}
511503

512504
impl<T: Send> ExactSizeIterator for IntoIter<T> {}
513505
impl<T: Send> FusedIterator for IntoIter<T> {}
514506

515-
// The Debug bound is technically unnecessary but makes the API more consistent and future-proof.
516-
impl<T: Send + fmt::Debug> fmt::Debug for IntoIter<T> {
517-
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
518-
f.debug_struct("IntoIter")
519-
.field("remaining", &self.raw.remaining)
520-
.field("bucket", &self.raw.bucket)
521-
.field("bucket_size", &self.raw.bucket_size)
522-
.field("index", &self.raw.index)
523-
.finish()
524-
}
525-
}
526-
527507
fn allocate_bucket<T>(size: usize) -> *mut Entry<T> {
528508
Box::into_raw(
529509
(0..size)

0 commit comments

Comments
 (0)