Add NEON backend for RawTable
The core algorithm is based on the NEON support in [SwissTable], adapted
for the different control byte encodings used in hashbrown.

[SwissTable]: abseil/abseil-cpp@6481443
Amanieu committed May 6, 2023
1 parent 59c0e15 commit 5355386
Showing 5 changed files with 173 additions and 30 deletions.
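
The key difference from the SSE2 backend is that AArch64 NEON has no movemask-style instruction, so a group match is kept as one full byte per element (0xFF or 0x00) packed into a u64, rather than one bit per element. A minimal standalone sketch of that idea, mirroring the Group::match_byte added below (the wrapper function and its name are illustrative, not part of the commit):

#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
unsafe fn match_byte_sketch(ctrl: *const u8, byte: u8) -> u64 {
    use core::arch::aarch64::{vceq_u8, vdup_n_u8, vget_lane_u64, vld1_u8, vreinterpret_u64_u8};
    let group = vld1_u8(ctrl);                     // unaligned 8-byte load of one group
    let cmp = vceq_u8(group, vdup_n_u8(byte));     // per-byte equality: 0xFF or 0x00 per lane
    vget_lane_u64(vreinterpret_u64_u8(cmp), 0)     // lanes packed into a u64 mask word
}

Because each matching element contributes 8 set bits to the mask word, this backend uses BITMASK_STRIDE = 8, which is what the bitmask.rs changes below accommodate.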
52 changes: 32 additions & 20 deletions src/raw/bitmask.rs
@@ -1,4 +1,4 @@
use super::imp::{BitMaskWord, BITMASK_MASK, BITMASK_STRIDE};
use super::imp::{BitMaskWord, BITMASK_ITER_MASK, BITMASK_MASK, BITMASK_STRIDE};
#[cfg(feature = "nightly")]
use core::intrinsics;

@@ -8,11 +8,16 @@ use core::intrinsics;
/// The bit mask is arranged so that low-order bits represent lower memory
/// addresses for group match results.
///
/// For implementation reasons, the bits in the set may be sparsely packed, so
/// that there is only one bit-per-byte used (the high bit, 7). If this is the
/// For implementation reasons, the bits in the set may be sparsely packed with
/// groups of 8 bits representing one element. If any of these bits are non-zero
/// then this element is considered to be true in the mask. If this is the
/// case, `BITMASK_STRIDE` will be 8 to indicate a divide-by-8 should be
/// performed on counts/indices to normalize this difference. `BITMASK_MASK` is
/// similarly a mask of all the actually-used bits.
///
/// To iterate over a bit mask, it must be converted to a form where only 1 bit
/// is set per element. This is done by applying `BITMASK_ITER_MASK` on the
/// mask bits.
#[derive(Copy, Clone)]
pub(crate) struct BitMask(pub(crate) BitMaskWord);
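
Since each element can span 8 mask bits under this encoding, any index or count derived from the raw word has to be divided by BITMASK_STRIDE, as the comment above says. A hedged sketch of that normalization (the free function is illustrative; in the module this is handled by BitMask's existing lowest_set_bit / trailing_zeros helpers):

fn first_match_index(word: BitMaskWord) -> Option<usize> {
    // With BITMASK_STRIDE == 8 a match in element 2 lands somewhere in bits
    // 16..24, so trailing_zeros() / BITMASK_STRIDE yields 2 no matter which
    // of those 8 bits the backend happened to set.
    if word == 0 {
        None
    } else {
        Some(word.trailing_zeros() as usize / BITMASK_STRIDE)
    }
}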

@@ -21,30 +26,18 @@ impl BitMask {
/// Returns a new `BitMask` with all bits inverted.
#[inline]
#[must_use]
#[allow(dead_code)]
pub(crate) fn invert(self) -> Self {
BitMask(self.0 ^ BITMASK_MASK)
}

/// Flip the bit in the mask for the entry at the given index.
///
/// Returns the bit's previous state.
#[inline]
#[allow(clippy::cast_ptr_alignment)]
#[cfg(feature = "raw")]
pub(crate) unsafe fn flip(&mut self, index: usize) -> bool {
// NOTE: The + BITMASK_STRIDE - 1 is to set the high bit.
let mask = 1 << (index * BITMASK_STRIDE + BITMASK_STRIDE - 1);
self.0 ^= mask;
// The bit was set if the bit is now 0.
self.0 & mask == 0
}

/// Returns a new `BitMask` with the lowest bit removed.
#[inline]
#[must_use]
pub(crate) fn remove_lowest_bit(self) -> Self {
fn remove_lowest_bit(self) -> Self {
BitMask(self.0 & (self.0 - 1))
}

/// Returns whether the `BitMask` has at least one set bit.
#[inline]
pub(crate) fn any_bit_set(self) -> bool {
@@ -102,13 +95,32 @@ impl IntoIterator for BitMask {

#[inline]
fn into_iter(self) -> BitMaskIter {
BitMaskIter(self)
// A BitMask only requires each element (group of bits) to be non-zero.
// However for iteration we need each element to only contain 1 bit.
BitMaskIter(BitMask(self.0 & BITMASK_ITER_MASK))
}
}

/// Iterator over the contents of a `BitMask`, returning the indices of set
/// bits.
pub(crate) struct BitMaskIter(BitMask);
#[derive(Copy, Clone)]
pub(crate) struct BitMaskIter(pub(crate) BitMask);

impl BitMaskIter {
/// Flip the bit in the mask for the entry at the given index.
///
/// Returns the bit's previous state.
#[inline]
#[allow(clippy::cast_ptr_alignment)]
#[cfg(feature = "raw")]
pub(crate) unsafe fn flip(&mut self, index: usize) -> bool {
// NOTE: The + BITMASK_STRIDE - 1 is to set the high bit.
let mask = 1 << (index * BITMASK_STRIDE + BITMASK_STRIDE - 1);
self.0 .0 ^= mask;
// The bit was set if the bit is now 0.
self.0 .0 & mask == 0
}
}

impl Iterator for BitMaskIter {
type Item = usize;
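The body of the Iterator impl is collapsed in this view. For orientation, once the word holds at most one bit per element it advances roughly like this (a sketch of the usual pattern, not a verbatim copy of the hidden lines):

fn next(&mut self) -> Option<usize> {
    // lowest_set_bit() already divides by BITMASK_STRIDE, so this is an
    // element index; remove_lowest_bit() then clears exactly one element,
    // which is why into_iter() first reduces the mask to one bit per element.
    let bit = self.0.lowest_set_bit()?;
    self.0 = self.0.remove_lowest_bit();
    Some(bit)
}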
1 change: 1 addition & 0 deletions src/raw/generic.rs
@@ -25,6 +25,7 @@ pub(crate) const BITMASK_STRIDE: usize = 8;
// We only care about the highest bit of each byte for the mask.
#[allow(clippy::cast_possible_truncation, clippy::unnecessary_cast)]
pub(crate) const BITMASK_MASK: BitMaskWord = 0x8080_8080_8080_8080_u64 as GroupWord;
pub(crate) const BITMASK_ITER_MASK: BitMaskWord = !0;

/// Helper function to replicate a byte across a `GroupWord`.
#[inline]
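As the comment above notes, the portable backend only ever sets the high bit of each byte in its match results, so there is nothing extra to strip before iterating and BITMASK_ITER_MASK can simply be all-ones. A small worked example (values invented for illustration):

// A hit in elements 0 and 3 of an 8-byte group: only bit 7 of those bytes is set.
let mask: u64 = 0x0000_0000_8000_0080;
assert_eq!(mask & 0x8080_8080_8080_8080, mask); // already covered by BITMASK_MASK
assert_eq!(mask & !0u64, mask);                 // so the iteration mask is a no-op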
27 changes: 17 additions & 10 deletions src/raw/mod.rs
@@ -25,8 +25,10 @@ cfg_if! {
))] {
mod sse2;
use sse2 as imp;
} else if #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] {
mod neon;
use neon as imp;
} else {
#[path = "generic.rs"]
mod generic;
use generic as imp;
}
@@ -37,7 +39,7 @@ pub(crate) use self::alloc::{do_alloc, Allocator, Global};

mod bitmask;

use self::bitmask::{BitMask, BitMaskIter};
use self::bitmask::BitMaskIter;
use self::imp::Group;

// Branch prediction hint. This is currently only available on nightly but it
@@ -2716,7 +2718,7 @@ impl<T, A: Allocator + Clone> IntoIterator for RawTable<T, A> {
pub(crate) struct RawIterRange<T> {
// Mask of full buckets in the current group. Bits are cleared from this
// mask as each element is processed.
current_group: BitMask,
current_group: BitMaskIter,

// Pointer to the buckets for the current group.
data: Bucket<T>,
@@ -2744,7 +2746,7 @@ impl<T> RawIterRange<T> {
let next_ctrl = ctrl.add(Group::WIDTH);

Self {
current_group,
current_group: current_group.into_iter(),
data,
next_ctrl,
end,
@@ -2801,8 +2803,7 @@ impl<T> RawIterRange<T> {
#[cfg_attr(feature = "inline-more", inline)]
unsafe fn next_impl<const DO_CHECK_PTR_RANGE: bool>(&mut self) -> Option<Bucket<T>> {
loop {
if let Some(index) = self.current_group.lowest_set_bit() {
self.current_group = self.current_group.remove_lowest_bit();
if let Some(index) = self.current_group.next() {
return Some(self.data.next_n(index));
}

@@ -2815,7 +2816,7 @@
// than the group size where the trailing control bytes are all
// EMPTY. On larger tables self.end is guaranteed to be aligned
// to the group size (since tables are power-of-two sized).
self.current_group = Group::load_aligned(self.next_ctrl).match_full();
self.current_group = Group::load_aligned(self.next_ctrl).match_full().into_iter();
self.data = self.data.next_n(Group::WIDTH);
self.next_ctrl = self.next_ctrl.add(Group::WIDTH);
}
@@ -2956,7 +2957,7 @@ impl<T> RawIter<T> {
// - Otherwise, update the iterator cached group so that it won't
// yield a to-be-removed bucket, or _will_ yield a to-be-added bucket.
// We'll also need to update the item count accordingly.
if let Some(index) = self.iter.current_group.lowest_set_bit() {
if let Some(index) = self.iter.current_group.0.lowest_set_bit() {
let next_bucket = self.iter.data.next_n(index);
if b.as_ptr() > next_bucket.as_ptr() {
// The toggled bucket is "before" the bucket the iterator would yield next. We
@@ -2989,10 +2990,16 @@ impl<T> RawIter<T> {
if cfg!(debug_assertions) {
if b.as_ptr() == next_bucket.as_ptr() {
// The removed bucket should no longer be next
debug_assert_ne!(self.iter.current_group.lowest_set_bit(), Some(index));
debug_assert_ne!(
self.iter.current_group.0.lowest_set_bit(),
Some(index)
);
} else {
// We should not have changed what bucket comes next.
debug_assert_eq!(self.iter.current_group.lowest_set_bit(), Some(index));
debug_assert_eq!(
self.iter.current_group.0.lowest_set_bit(),
Some(index)
);
}
}
}
122 changes: 122 additions & 0 deletions src/raw/neon.rs
@@ -0,0 +1,122 @@
use super::bitmask::BitMask;
use super::EMPTY;
use core::arch::aarch64 as neon;
use core::mem;

pub(crate) type BitMaskWord = u64;
pub(crate) const BITMASK_STRIDE: usize = 8;
pub(crate) const BITMASK_MASK: BitMaskWord = !0;
pub(crate) const BITMASK_ITER_MASK: BitMaskWord = 0x8080_8080_8080_8080;

/// Abstraction over a group of control bytes which can be scanned in
/// parallel.
///
/// This implementation uses a 64-bit NEON value.
#[derive(Copy, Clone)]
pub(crate) struct Group(neon::uint8x8_t);

#[allow(clippy::use_self)]
impl Group {
/// Number of bytes in the group.
pub(crate) const WIDTH: usize = mem::size_of::<Self>();

/// Returns a full group of empty bytes, suitable for use as the initial
/// value for an empty hash table.
///
/// This is guaranteed to be aligned to the group size.
#[inline]
pub(crate) const fn static_empty() -> &'static [u8; Group::WIDTH] {
#[repr(C)]
struct AlignedBytes {
_align: [Group; 0],
bytes: [u8; Group::WIDTH],
}
const ALIGNED_BYTES: AlignedBytes = AlignedBytes {
_align: [],
bytes: [EMPTY; Group::WIDTH],
};
&ALIGNED_BYTES.bytes
}

/// Loads a group of bytes starting at the given address.
#[inline]
#[allow(clippy::cast_ptr_alignment)] // unaligned load
pub(crate) unsafe fn load(ptr: *const u8) -> Self {
Group(neon::vld1_u8(ptr))
}

/// Loads a group of bytes starting at the given address, which must be
/// aligned to `mem::align_of::<Group>()`.
#[inline]
#[allow(clippy::cast_ptr_alignment)]
pub(crate) unsafe fn load_aligned(ptr: *const u8) -> Self {
// FIXME: use align_offset once it stabilizes
debug_assert_eq!(ptr as usize & (mem::align_of::<Self>() - 1), 0);
Group(neon::vld1_u8(ptr))
}

/// Stores the group of bytes to the given address, which must be
/// aligned to `mem::align_of::<Group>()`.
#[inline]
#[allow(clippy::cast_ptr_alignment)]
pub(crate) unsafe fn store_aligned(self, ptr: *mut u8) {
// FIXME: use align_offset once it stabilizes
debug_assert_eq!(ptr as usize & (mem::align_of::<Self>() - 1), 0);
neon::vst1_u8(ptr, self.0);
}

/// Returns a `BitMask` indicating all bytes in the group which *may*
/// have the given value.
#[inline]
pub(crate) fn match_byte(self, byte: u8) -> BitMask {
unsafe {
let cmp = neon::vceq_u8(self.0, neon::vdup_n_u8(byte));
BitMask(neon::vget_lane_u64(neon::vreinterpret_u64_u8(cmp), 0))
}
}

/// Returns a `BitMask` indicating all bytes in the group which are
/// `EMPTY`.
#[inline]
pub(crate) fn match_empty(self) -> BitMask {
self.match_byte(EMPTY)
}

/// Returns a `BitMask` indicating all bytes in the group which are
/// `EMPTY` or `DELETED`.
#[inline]
pub(crate) fn match_empty_or_deleted(self) -> BitMask {
unsafe {
let cmp = neon::vcltz_s8(neon::vreinterpret_s8_u8(self.0));
BitMask(neon::vget_lane_u64(neon::vreinterpret_u64_u8(cmp), 0))
}
}

/// Returns a `BitMask` indicating all bytes in the group which are full.
#[inline]
pub(crate) fn match_full(self) -> BitMask {
unsafe {
let cmp = neon::vcgez_s8(neon::vreinterpret_s8_u8(self.0));
BitMask(neon::vget_lane_u64(neon::vreinterpret_u64_u8(cmp), 0))
}
}

/// Performs the following transformation on all bytes in the group:
/// - `EMPTY => EMPTY`
/// - `DELETED => EMPTY`
/// - `FULL => DELETED`
#[inline]
pub(crate) fn convert_special_to_empty_and_full_to_deleted(self) -> Self {
// Map high_bit = 1 (EMPTY or DELETED) to 1111_1111
// and high_bit = 0 (FULL) to 1000_0000
//
// Here's this logic expanded to concrete values:
// let special = 0 > byte = 1111_1111 (true) or 0000_0000 (false)
// 1111_1111 | 1000_0000 = 1111_1111
// 0000_0000 | 1000_0000 = 1000_0000
unsafe {
let special = neon::vcltz_s8(neon::vreinterpret_s8_u8(self.0));
Group(neon::vorr_u8(special, neon::vdup_n_u8(0x80)))
}
}
}
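
Taken together, these primitives are what the probing code in mod.rs drives. A simplified sketch of a single-group lookup step (hypothetical helper, ignoring wraparound, deletion markers and the rest of the real probe loop):

// `ctrl` points at this group's control bytes, `h2` is the 7-bit hash
// fragment stored in full control bytes, and `eq(i)` compares the key at
// element `i` of the group.
unsafe fn find_in_group(ctrl: *const u8, h2: u8, mut eq: impl FnMut(usize) -> bool) -> Option<usize> {
    let group = Group::load(ctrl);
    // Iterating the BitMask yields element indices, already divided by BITMASK_STRIDE.
    for i in group.match_byte(h2) {
        if eq(i) {
            return Some(i);
        }
    }
    // In the real probe loop, group.match_empty().any_bit_set() here means the
    // key is absent; otherwise the search continues with the next group.
    None
}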
1 change: 1 addition & 0 deletions src/raw/sse2.rs
@@ -10,6 +10,7 @@ use core::arch::x86_64 as x86;
pub(crate) type BitMaskWord = u16;
pub(crate) const BITMASK_STRIDE: usize = 1;
pub(crate) const BITMASK_MASK: BitMaskWord = 0xffff;
pub(crate) const BITMASK_ITER_MASK: BitMaskWord = !0;

/// Abstraction over a group of control bytes which can be scanned in
/// parallel.
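For contrast, the SSE2 backend can compress a whole-group comparison down to one bit per element with a movemask, which is why its BITMASK_STRIDE is 1 and its iteration mask is likewise a no-op. A hedged sketch of the equivalent match on x86_64 (illustrative wrapper, mirroring what the existing sse2.rs does):

#[cfg(target_arch = "x86_64")]
unsafe fn match_byte_sse2(group: core::arch::x86_64::__m128i, byte: u8) -> u16 {
    use core::arch::x86_64::{_mm_cmpeq_epi8, _mm_movemask_epi8, _mm_set1_epi8};
    // Lane-wise equality, then pack the 16 lane results into 16 bits:
    // exactly one mask bit per element, so no stride scaling is needed.
    let cmp = _mm_cmpeq_epi8(group, _mm_set1_epi8(byte as i8));
    _mm_movemask_epi8(cmp) as u16
}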
