Implement iterator logic in RawIter
Kestrer committed Jan 24, 2021
1 parent 7ee722e commit 810c043
Showing 1 changed file with 80 additions and 102 deletions.
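
For context, the file touched here exposes three iteration entry points (iter, iter_mut, and into_iter), which after this commit all drive their traversal through the shared RawIter helper. Below is a minimal usage sketch of those entry points, assuming the crate's public API as it appears in the diff; the scoped threads (std::thread::scope, which needs a recent Rust toolchain) and the sample values are illustrative only:

use std::thread;
use thread_local::ThreadLocal;

fn main() {
    let mut tls: ThreadLocal<u32> = ThreadLocal::new();

    thread::scope(|s| {
        for i in 1u32..=4 {
            let tls = &tls;
            s.spawn(move || {
                // Each thread lazily initializes its own slot on first access.
                tls.get_or(|| i);
            });
        }
    });

    // `iter` borrows the map shared and requires `T: Sync`.
    let _total: u32 = tls.iter().sum();
    // `iter_mut` yields `&mut T`; exclusive access rules out concurrent readers.
    for value in tls.iter_mut() {
        *value += 1;
    }
    // `into_iter` consumes the map and moves each stored value out.
    let _drained: Vec<u32> = tls.into_iter().collect();
}
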
182 changes: 80 additions & 102 deletions src/lib.rs
@@ -76,7 +76,6 @@ pub use cached::{CachedIntoIter, CachedIterMut, CachedThreadLocal};
use std::cell::UnsafeCell;
use std::fmt;
use std::iter::FusedIterator;
use std::marker::PhantomData;
use std::mem;
use std::mem::MaybeUninit;
use std::panic::UnwindSafe;
@@ -274,20 +273,7 @@ impl<T: Send> ThreadLocal<T> {
{
Iter {
thread_local: self,
yielded: 0,
bucket: 0,
bucket_size: 1,
index: 0,
}
}

fn raw_iter_mut(&mut self) -> RawIterMut<T> {
RawIterMut {
remaining: *self.values.get_mut(),
buckets: unsafe { *(&self.buckets as *const _ as *const [*mut Entry<T>; BUCKETS]) },
bucket: 0,
bucket_size: 1,
index: 0,
raw: RawIter::new(),
}
}

@@ -299,8 +285,8 @@ impl<T: Send> ThreadLocal<T> {
/// threads are currently accessing their associated values.
pub fn iter_mut(&mut self) -> IterMut<T> {
IterMut {
raw: self.raw_iter_mut(),
marker: PhantomData,
thread_local: self,
raw: RawIter::new(),
}
}

@@ -319,10 +305,10 @@ impl<T: Send> IntoIterator for ThreadLocal<T> {
type Item = T;
type IntoIter = IntoIter<T>;

fn into_iter(mut self) -> IntoIter<T> {
fn into_iter(self) -> IntoIter<T> {
IntoIter {
raw: self.raw_iter_mut(),
_thread_local: self,
thread_local: self,
raw: RawIter::new(),
}
}
}
@@ -361,22 +347,26 @@ impl<T: Send + fmt::Debug> fmt::Debug for ThreadLocal<T> {

impl<T: Send + UnwindSafe> UnwindSafe for ThreadLocal<T> {}

/// Iterator over the contents of a `ThreadLocal`.
#[derive(Debug)]
pub struct Iter<'a, T: Send + Sync> {
thread_local: &'a ThreadLocal<T>,
struct RawIter {
yielded: usize,
bucket: usize,
bucket_size: usize,
index: usize,
}
impl RawIter {
fn new() -> Self {
Self {
yielded: 0,
bucket: 0,
bucket_size: 1,
index: 0,
}
}

impl<'a, T: Send + Sync> Iterator for Iter<'a, T> {
type Item = &'a T;

fn next(&mut self) -> Option<Self::Item> {
fn next<'a, T: Send + Sync>(&mut self, thread_local: &'a ThreadLocal<T>) -> Option<&'a T> {
while self.bucket < BUCKETS {
let bucket = unsafe { self.thread_local.buckets.get_unchecked(self.bucket) };
let bucket = unsafe { thread_local.buckets.get_unchecked(self.bucket) };
let bucket = bucket.load(Ordering::Relaxed);

if !bucket.is_null() {
@@ -390,140 +380,128 @@ impl<'a, T: Send + Sync> Iterator for Iter<'a, T> {
}
}

if self.bucket != 0 {
self.bucket_size <<= 1;
}
self.bucket += 1;

self.index = 0;
self.next_bucket();
}
None
}

fn size_hint(&self) -> (usize, Option<usize>) {
let total = self.thread_local.values.load(Ordering::Acquire);
(total - self.yielded, None)
}
}
impl<T: Send + Sync> FusedIterator for Iter<'_, T> {}

struct RawIterMut<T: Send> {
remaining: usize,
buckets: [*mut Entry<T>; BUCKETS],
bucket: usize,
bucket_size: usize,
index: usize,
}

impl<T: Send> Iterator for RawIterMut<T> {
type Item = *mut MaybeUninit<T>;

fn next(&mut self) -> Option<Self::Item> {
if self.remaining == 0 {
fn next_mut<'a, T: Send>(
&mut self,
thread_local: &'a mut ThreadLocal<T>,
) -> Option<&'a mut Entry<T>> {
if *thread_local.values.get_mut() == self.yielded {
return None;
}

loop {
let bucket = unsafe { *self.buckets.get_unchecked(self.bucket) };
let bucket = unsafe { thread_local.buckets.get_unchecked_mut(self.bucket) };
let bucket = *bucket.get_mut();

if !bucket.is_null() {
while self.index < self.bucket_size {
let entry = unsafe { &mut *bucket.add(self.index) };
self.index += 1;
if *entry.present.get_mut() {
self.remaining -= 1;
return Some(entry.value.get());
self.yielded += 1;
return Some(entry);
}
}
}

if self.bucket != 0 {
self.bucket_size <<= 1;
}
self.bucket += 1;
self.next_bucket();
}
}

self.index = 0;
fn next_bucket(&mut self) {
if self.bucket != 0 {
self.bucket_size <<= 1;
}
self.bucket += 1;
self.index = 0;
}

fn size_hint(&self) -> (usize, Option<usize>) {
(self.remaining, Some(self.remaining))
fn size_hint<T: Send>(&self, thread_local: &ThreadLocal<T>) -> (usize, Option<usize>) {
let total = thread_local.values.load(Ordering::Acquire);
(total - self.yielded, None)
}
fn size_hint_frozen<T: Send>(&self, thread_local: &ThreadLocal<T>) -> (usize, Option<usize>) {
let total = unsafe { *(&thread_local.values as *const AtomicUsize as *const usize) };
let remaining = total - self.yielded;
(remaining, Some(remaining))
}
}

unsafe impl<T: Send> Send for RawIterMut<T> {}
unsafe impl<T: Send + Sync> Sync for RawIterMut<T> {}
/// Iterator over the contents of a `ThreadLocal`.
#[derive(Debug)]
pub struct Iter<'a, T: Send + Sync> {
thread_local: &'a ThreadLocal<T>,
raw: RawIter,
}

impl<'a, T: Send + Sync> Iterator for Iter<'a, T> {
type Item = &'a T;
fn next(&mut self) -> Option<Self::Item> {
self.raw.next(self.thread_local)
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.raw.size_hint(self.thread_local)
}
}
impl<T: Send + Sync> FusedIterator for Iter<'_, T> {}

/// Mutable iterator over the contents of a `ThreadLocal`.
pub struct IterMut<'a, T: Send> {
raw: RawIterMut<T>,
marker: PhantomData<&'a mut ThreadLocal<T>>,
thread_local: &'a mut ThreadLocal<T>,
raw: RawIter,
}

impl<'a, T: Send> Iterator for IterMut<'a, T> {
type Item = &'a mut T;

fn next(&mut self) -> Option<&'a mut T> {
self.raw
.next()
.map(|x| unsafe { &mut *(&mut *x).as_mut_ptr() })
.next_mut(self.thread_local)
.map(|entry| unsafe { &mut *(&mut *entry.value.get()).as_mut_ptr() })
}

fn size_hint(&self) -> (usize, Option<usize>) {
self.raw.size_hint()
self.raw.size_hint_frozen(self.thread_local)
}
}

impl<T: Send> ExactSizeIterator for IterMut<'_, T> {}
impl<T: Send> FusedIterator for IterMut<'_, T> {}

// The Debug bound is technically unnecessary but makes the API more consistent and future-proof.
impl<T: Send + fmt::Debug> fmt::Debug for IterMut<'_, T> {
// Manual impl so we don't call Debug on the ThreadLocal, as doing so would create a reference to
// this thread's value that potentially aliases with a mutable reference we have given out.
impl<'a, T: Send + fmt::Debug> fmt::Debug for IterMut<'a, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("IterMut")
.field("remaining", &self.raw.remaining)
.field("bucket", &self.raw.bucket)
.field("bucket_size", &self.raw.bucket_size)
.field("index", &self.raw.index)
.finish()
f.debug_struct("IterMut").field("raw", &self.raw).finish()
}
}

/// An iterator that moves out of a `ThreadLocal`.
#[derive(Debug)]
pub struct IntoIter<T: Send> {
raw: RawIterMut<T>,
_thread_local: ThreadLocal<T>,
thread_local: ThreadLocal<T>,
raw: RawIter,
}

impl<T: Send> Iterator for IntoIter<T> {
type Item = T;

fn next(&mut self) -> Option<T> {
self.raw
.next()
.map(|x| unsafe { std::mem::replace(&mut *x, MaybeUninit::uninit()).assume_init() })
self.raw.next_mut(&mut self.thread_local).map(|entry| {
*entry.present.get_mut() = false;
unsafe {
std::mem::replace(&mut *entry.value.get(), MaybeUninit::uninit()).assume_init()
}
})
}

fn size_hint(&self) -> (usize, Option<usize>) {
self.raw.size_hint()
self.raw.size_hint_frozen(&self.thread_local)
}
}

impl<T: Send> ExactSizeIterator for IntoIter<T> {}
impl<T: Send> FusedIterator for IntoIter<T> {}

// The Debug bound is technically unnecessary but makes the API more consistent and future-proof.
impl<T: Send + fmt::Debug> fmt::Debug for IntoIter<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("IntoIter")
.field("remaining", &self.raw.remaining)
.field("bucket", &self.raw.bucket)
.field("bucket_size", &self.raw.bucket_size)
.field("index", &self.raw.index)
.finish()
}
}

fn allocate_bucket<T>(size: usize) -> *mut Entry<T> {
Box::into_raw(
(0..size)
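
For reference, the traversal order RawIter walks follows the crate's bucket layout: the bucket size stays at 1 for the first two buckets and doubles for each bucket after that, which is exactly what the shift in next_bucket encodes. A standalone sketch of that progression (the helper below is illustrative, not part of the crate):

// Mirrors RawIter's bucket stepping: bucket_size starts at 1 and is only
// doubled when leaving a non-zero bucket, giving sizes 1, 1, 2, 4, 8, ...
fn bucket_sizes(buckets: usize) -> Vec<usize> {
    let mut sizes = Vec::with_capacity(buckets);
    let mut size = 1usize;
    for bucket in 0..buckets {
        sizes.push(size);
        if bucket != 0 {
            size <<= 1;
        }
    }
    sizes
}

fn main() {
    assert_eq!(bucket_sizes(6), vec![1, 1, 2, 4, 8, 16]);
}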
