From ea093d4a932567241ab2c04b4ea87a03712f873d Mon Sep 17 00:00:00 2001 From: Zachary Dremann Date: Sat, 6 Apr 2024 21:31:24 -0400 Subject: [PATCH] Add for_each to Bitmap too --- croaring/benches/benches.rs | 54 +++++++++++++++++++++++++--------- croaring/src/bitmap/imp.rs | 48 +++++++++++++++++++++++++++++-- croaring/src/bitmap64/imp.rs | 43 ++++----------------------- croaring/src/callback.rs | 56 ++++++++++++++++++++++++++++++++++++ croaring/src/lib.rs | 1 + 5 files changed, 149 insertions(+), 53 deletions(-) create mode 100644 croaring/src/callback.rs diff --git a/croaring/benches/benches.rs b/croaring/benches/benches.rs index 31638ce..e362f33 100644 --- a/croaring/benches/benches.rs +++ b/croaring/benches/benches.rs @@ -131,10 +131,25 @@ fn flip(c: &mut Criterion) { } fn to_vec(c: &mut Criterion) { - c.bench_function("to_vec", |b| { - let bitmap = Bitmap::of(&[1, 2, 3]); + const N: usize = 100_000; + let bitmap: Bitmap = random_iter().take(N).collect(); + let mut g = c.benchmark_group("collect"); + g.bench_function("to_vec", |b| { b.iter(|| bitmap.to_vec()); }); + g.bench_function("via_iter", |b| { + b.iter(|| bitmap.iter().collect::>()); + }); + g.bench_function("foreach", |b| { + b.iter(|| { + let mut vec = Vec::with_capacity(bitmap.cardinality() as usize); + bitmap.for_each(|item| -> ControlFlow<()> { + vec.push(item); + ControlFlow::Continue(()) + }); + vec + }); + }); } fn get_serialized_size_in_bytes(c: &mut Criterion) { @@ -214,7 +229,27 @@ fn bulk_new(c: &mut Criterion) { group.finish(); } -fn random_iter(c: &mut Criterion) { +#[derive(Clone)] +struct RandomIter { + x: u32, +} + +impl Iterator for RandomIter { + type Item = u32; + + fn next(&mut self) -> Option { + const MULTIPLIER: u32 = 742938285; + const MODULUS: u32 = (1 << 31) - 1; + self.x = (MULTIPLIER.wrapping_mul(self.x)) % MODULUS; + Some(self.x) + } +} + +fn random_iter() -> RandomIter { + RandomIter { x: 20170705 } +} + +fn create_random(c: &mut Criterion) { const N: u32 = 5_000; // Clamp values so we get some re-use of containers const MAX: u32 = 8 * (u16::MAX as u32 + 1); @@ -222,16 +257,7 @@ fn random_iter(c: &mut Criterion) { let mut group = c.benchmark_group("random_iter"); group.throughput(Throughput::Elements(N.into())); - let rand_iter = { - const MULTIPLIER: u32 = 742938285; - const MODULUS: u32 = (1 << 31) - 1; - // Super simple LCG iterator - let mut z = 20170705; // seed - std::iter::from_fn(move || { - z = (MULTIPLIER.wrapping_mul(z)) % MODULUS; - Some(z % MAX) - }) - }; + let rand_iter = random_iter(); group.bench_function("random_adds", |b| { b.iter(|| { @@ -360,7 +386,7 @@ criterion_group!( serialize, deserialize, bulk_new, - random_iter, + create_random, collect_bitmap64_to_vec, iterate_bitmap64, ); diff --git a/croaring/src/bitmap/imp.rs b/croaring/src/bitmap/imp.rs index d89d8fa..0a0ffa3 100644 --- a/croaring/src/bitmap/imp.rs +++ b/croaring/src/bitmap/imp.rs @@ -1,8 +1,9 @@ +use crate::callback::CallbackWrapper; use crate::Bitset; use ffi::roaring_bitmap_t; use std::ffi::{c_void, CStr}; -use std::ops::{Bound, RangeBounds}; -use std::{mem, ptr}; +use std::ops::{Bound, ControlFlow, RangeBounds}; +use std::{mem, panic, ptr}; use super::serialization::{Deserializer, Serializer}; use super::{Bitmap, Statistics}; @@ -743,6 +744,49 @@ impl Bitmap { unsafe { ffi::roaring_bitmap_flip_inplace(&mut self.bitmap, start, end) } } + /// Iterate over the values in the bitmap in sorted order + /// + /// If `f` returns `Break`, iteration will stop and the value will be returned, + /// Otherwise, iteration continues. If `f` never returns break, `None` is returned after all values are visited. + /// + /// # Examples + /// + /// ``` + /// use croaring::Bitmap; + /// use std::ops::ControlFlow; + /// + /// let bitmap = Bitmap::of(&[1, 2, 3, 14, 20, 21, 100]); + /// let mut even_nums_under_50 = vec![]; + /// + /// let first_over_50 = bitmap.for_each(|value| { + /// if value > 50 { + /// return ControlFlow::Break(value); + /// } + /// if value % 2 == 0 { + /// even_nums_under_50.push(value); + /// } + /// ControlFlow::Continue(()) + /// }); + /// + /// assert_eq!(even_nums_under_50, vec![2, 14, 20]); + /// assert_eq!(first_over_50, ControlFlow::Break(100)); + /// ``` + #[inline] + pub fn for_each(&self, f: F) -> ControlFlow + where + F: FnMut(u32) -> ControlFlow, + { + let mut callback_wrapper = CallbackWrapper::new(f); + let (callback, context) = callback_wrapper.callback_and_ctx(); + unsafe { + ffi::roaring_iterate(&self.bitmap, Some(callback), context); + } + match callback_wrapper.result() { + Ok(cf) => cf, + Err(e) => panic::resume_unwind(e), + } + } + /// Returns a vector containing all of the integers stored in the Bitmap /// in sorted order. /// diff --git a/croaring/src/bitmap64/imp.rs b/croaring/src/bitmap64/imp.rs index 10d55ce..9a7e2ee 100644 --- a/croaring/src/bitmap64/imp.rs +++ b/croaring/src/bitmap64/imp.rs @@ -1,11 +1,11 @@ use crate::bitmap64::Bitmap64; use crate::bitmap64::{Deserializer, Serializer}; -use std::any::Any; +use crate::callback::CallbackWrapper; use std::collections::Bound; use std::ffi::CStr; use std::mem::MaybeUninit; use std::ops::{ControlFlow, RangeBounds}; -use std::panic::{self, AssertUnwindSafe}; +use std::panic; use std::ptr; use std::ptr::NonNull; @@ -910,43 +910,12 @@ impl Bitmap64 { where F: FnMut(u64) -> ControlFlow, { - struct State { - f: F, - result: Result, Box>, - } - - unsafe extern "C" fn callback(value: u64, arg: *mut std::ffi::c_void) -> bool - where - F: FnMut(u64) -> ControlFlow, - { - let state: &mut State = unsafe { &mut *arg.cast::>() }; - let mut f = AssertUnwindSafe(&mut state.f); - let result = panic::catch_unwind(move || f(value)); - match result { - Ok(ControlFlow::Continue(())) => true, - Ok(ControlFlow::Break(val)) => { - state.result = Ok(ControlFlow::Break(val)); - false - } - Err(e) => { - state.result = Err(e); - false - } - } - } - - let mut state = State { - f, - result: Ok(ControlFlow::Continue(())), - }; + let mut callback_wrapper = CallbackWrapper::new(f); + let (callback, context) = callback_wrapper.callback_and_ctx(); unsafe { - ffi::roaring64_bitmap_iterate( - self.raw.as_ptr(), - Some(callback::), - ptr::addr_of_mut!(state).cast(), - ); + ffi::roaring64_bitmap_iterate(self.raw.as_ptr(), Some(callback), context); } - match state.result { + match callback_wrapper.result() { Ok(cf) => cf, Err(e) => panic::resume_unwind(e), } diff --git a/croaring/src/callback.rs b/croaring/src/callback.rs new file mode 100644 index 0000000..7ce0434 --- /dev/null +++ b/croaring/src/callback.rs @@ -0,0 +1,56 @@ +use std::any::Any; +use std::ops::ControlFlow; +use std::panic::AssertUnwindSafe; +use std::{panic, ptr}; + +pub struct CallbackWrapper { + f: F, + result: Result, Box>, +} + +impl CallbackWrapper { + pub fn new(f: F) -> Self { + Self { + f, + result: Ok(ControlFlow::Continue(())), + } + } + + unsafe extern "C" fn raw_callback(value: I, arg: *mut std::ffi::c_void) -> bool + where + I: panic::UnwindSafe, + F: FnMut(I) -> ControlFlow, + { + let wrapper = &mut *(arg as *mut Self); + let mut f = AssertUnwindSafe(&mut wrapper.f); + let result = panic::catch_unwind(move || f(value)); + match result { + Ok(ControlFlow::Continue(())) => true, + Ok(cf @ ControlFlow::Break(_)) => { + wrapper.result = Ok(cf); + false + } + Err(err) => { + wrapper.result = Err(err); + false + } + } + } + + pub fn callback_and_ctx( + &mut self, + ) -> ( + unsafe extern "C" fn(I, *mut std::ffi::c_void) -> bool, + *mut std::ffi::c_void, + ) + where + I: panic::UnwindSafe, + F: FnMut(I) -> ControlFlow, + { + (Self::raw_callback::, ptr::addr_of_mut!(*self).cast()) + } + + pub fn result(self) -> Result, Box> { + self.result + } +} diff --git a/croaring/src/lib.rs b/croaring/src/lib.rs index 953a59e..52b6fe1 100644 --- a/croaring/src/lib.rs +++ b/croaring/src/lib.rs @@ -8,6 +8,7 @@ pub mod bitmap64; pub mod bitset; pub mod treemap; +mod callback; mod serialization; pub use serialization::*;