Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions rust/arrow/src/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,7 @@ where

let remainder_bytes = ceil(left_chunks.remainder_len(), 8);
let rem = op(left_chunks.remainder_bits(), right_chunks.remainder_bits());
// we are counting bits starting from the least significant bit, so to_le_bytes should be correct
let rem = &rem.to_le_bytes()[0..remainder_bytes];
result
.write_all(rem)
Expand Down Expand Up @@ -448,6 +449,7 @@ where

let remainder_bytes = ceil(left_chunks.remainder_len(), 8);
let rem = op(left_chunks.remainder_bits());
// we are counting bits starting from the least significant bit, so to_le_bytes should be correct
let rem = &rem.to_le_bytes()[0..remainder_bytes];
result
.write_all(rem)
Expand Down
110 changes: 72 additions & 38 deletions rust/arrow/src/util/bit_chunk_iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@ use std::fmt::Debug;
pub struct BitChunks<'a> {
buffer: &'a Buffer,
raw_data: *const u8,
offset: usize,
/// offset inside a byte, guaranteed to be between 0 and 7 (inclusive)
bit_offset: usize,
/// number of complete u64 chunks
chunk_len: usize,
/// number of remaining bits, guaranteed to be between 0 and 63 (inclusive)
remainder_len: usize,
}

Expand All @@ -32,19 +35,19 @@ impl<'a> BitChunks<'a> {
assert!(ceil(offset + len, 8) <= buffer.len() * 8);

let byte_offset = offset / 8;
let offset = offset % 8;
let bit_offset = offset % 8;

let raw_data = unsafe { buffer.raw_data().add(byte_offset) };

let chunk_bits = 64;
let chunk_bits = 8 * std::mem::size_of::<u64>();

let chunk_len = len / chunk_bits;
let remainder_len = len & (chunk_bits - 1);

BitChunks::<'a> {
buffer: &buffer,
raw_data,
offset,
bit_offset,
chunk_len,
remainder_len,
}
Expand All @@ -55,48 +58,52 @@ impl<'a> BitChunks<'a> {
pub struct BitChunkIterator<'a> {
buffer: &'a Buffer,
raw_data: *const u8,
offset: usize,
bit_offset: usize,
chunk_len: usize,
index: usize,
}

impl<'a> BitChunks<'a> {
/// Returns the number of remaining bits, guaranteed to be between 0 and 63 (inclusive).
///
/// These are the trailing bits left over after all complete `u64` chunks;
/// their values can be obtained with `remainder_bits`.
#[inline]
pub const fn remainder_len(&self) -> usize {
self.remainder_len
}

/// Returns the bitmask of remaining bits
#[inline]
pub fn remainder_bits(&self) -> u64 {
let bit_len = self.remainder_len;
if bit_len == 0 {
0
} else {
let byte_len = ceil(bit_len, 8);

let mut bits = 0;
for i in 0..byte_len {
let byte = unsafe {
std::ptr::read(
self.raw_data
.add(self.chunk_len * std::mem::size_of::<u64>() + i),
)
};
bits |= (byte as u64) << (i * 8);
}
let bit_offset = self.bit_offset;
// number of bytes to read
// might be one more than sizeof(u64) if the offset is in the middle of a byte
let byte_len = ceil(bit_len + bit_offset, 8);
// pointer to remainder bytes after all complete chunks
let base = unsafe {
self.raw_data
.add(self.chunk_len * std::mem::size_of::<u64>())
};

let offset = self.offset as u64;
let mut bits = unsafe { std::ptr::read(base) } as u64 >> bit_offset;
for i in 1..byte_len {
let byte = unsafe { std::ptr::read(base.add(i)) };
bits |= (byte as u64) << (i * 8 - bit_offset);
}

(bits >> offset) & ((1 << bit_len) - 1)
bits & ((1 << bit_len) - 1)
}
}

/// Returns an iterator over chunks of 64 bits represented as an u64
#[inline]
pub const fn iter(&self) -> BitChunkIterator<'a> {
BitChunkIterator::<'a> {
buffer: self.buffer,
raw_data: self.raw_data,
offset: self.offset,
bit_offset: self.bit_offset,
chunk_len: self.chunk_len,
index: 0,
}
Expand All @@ -117,31 +124,30 @@ impl Iterator for BitChunkIterator<'_> {

#[inline]
fn next(&mut self) -> Option<u64> {
if self.index >= self.chunk_len {
let index = self.index;
if index >= self.chunk_len {
return None;
}

// cast to *const u64 should be fine since we are using read_unaligned
// cast to *const u64 should be fine since we are using read_unaligned below
#[allow(clippy::cast_ptr_alignment)]
let current = unsafe {
std::ptr::read_unaligned((self.raw_data as *const u64).add(self.index))
};
let raw_data = self.raw_data as *const u64;

// bit-packed buffers are stored starting with the least-significant byte first
// so when reading as u64 on a big-endian machine, the bytes need to be swapped
let current = unsafe { std::ptr::read_unaligned(raw_data.add(index)).to_le() };

let combined = if self.offset == 0 {
let combined = if self.bit_offset == 0 {
current
} else {
// cast to *const u64 should be fine since we are using read_unaligned
#[allow(clippy::cast_ptr_alignment)]
let next = unsafe {
std::ptr::read_unaligned(
(self.raw_data as *const u64).add(self.index + 1),
)
};
current >> self.offset
| (next & ((1 << self.offset) - 1)) << (64 - self.offset)
let next =
unsafe { std::ptr::read_unaligned(raw_data.add(index + 1)).to_le() };

current >> self.bit_offset
| (next & ((1 << self.bit_offset) - 1)) << (64 - self.bit_offset)
};

self.index += 1;
self.index = index + 1;

Some(combined)
}
Expand Down Expand Up @@ -192,7 +198,6 @@ mod tests {

let result = bitchunks.into_iter().collect::<Vec<_>>();

//assert_eq!(vec![0b00010000, 0b00100000, 0b01000000, 0b10000000, 0b00000000, 0b00000001, 0b00000010, 0b11110100], result);
assert_eq!(
vec![0b1111010000000010000000010000000010000000010000000010000000010000],
result
Expand All @@ -214,10 +219,39 @@ mod tests {

let result = bitchunks.into_iter().collect::<Vec<_>>();

//assert_eq!(vec![0b00010000, 0b00100000, 0b01000000, 0b10000000, 0b00000000, 0b00000001, 0b00000010, 0b11110100], result);
assert_eq!(
vec![0b1111010000000010000000010000000010000000010000000010000000010000],
result
);
}

#[test]
fn test_iter_unaligned_remainder_bits_across_bytes() {
    // Two bytes; the requested bit range straddles the boundary between them.
    let bytes: &[u8] = &[0b00111111, 0b11111100];
    let buffer = Buffer::from(bytes);

    // Starting at bit 6 and taking 7 bits, the remainder is made of the
    // highest 2 bits of the first byte followed by the lowest 5 bits of the
    // second byte.
    let chunks = buffer.bit_chunks(6, 7);

    let remainder_len = chunks.remainder_len();
    assert_eq!(7, remainder_len);

    let remainder_bits = chunks.remainder_bits();
    assert_eq!(0b1110000, remainder_bits);
}

#[test]
fn test_iter_unaligned_remainder_bits_large() {
    // Nine bytes alternating between all-ones and all-zeros.
    let bytes: &[u8] = &[
        0b11111111, 0b00000000, 0b11111111, 0b00000000, 0b11111111, 0b00000000,
        0b11111111, 0b00000000, 0b11111111,
    ];
    let buffer = Buffer::from(bytes);

    // 63 bits starting at bit offset 2: no complete 64-bit chunk, so every
    // requested bit ends up in the remainder, which spans all nine input bytes.
    let chunks = buffer.bit_chunks(2, 63);

    let remainder_len = chunks.remainder_len();
    assert_eq!(63, remainder_len);

    let remainder_bits = chunks.remainder_bits();
    assert_eq!(
        0b1000000_00111111_11000000_00111111_11000000_00111111_11000000_00111111,
        remainder_bits
    );
}
}