diff --git a/rust/arrow/benches/equal.rs b/rust/arrow/benches/equal.rs index 6783662e186..a6d8b55dc45 100644 --- a/rust/arrow/benches/equal.rs +++ b/rust/arrow/benches/equal.rs @@ -50,6 +50,23 @@ fn create_string_array(size: usize, with_nulls: bool) -> ArrayRef { Arc::new(builder.finish()) } +fn create_bool_array(size: usize, with_nulls: bool) -> ArrayRef { + // use random numbers to avoid spurious compiler optimizations wrt to branching + let mut rng = seedable_rng(); + let mut builder = BooleanBuilder::new(size); + + for _ in 0..size { + let val = rng.gen::(); + let is_even = val % 2 == 0; + if with_nulls && is_even { + builder.append_null().unwrap(); + } else { + builder.append_value(is_even).unwrap(); + } + } + Arc::new(builder.finish()) +} + fn create_array(size: usize, with_nulls: bool) -> ArrayRef { // use random numbers to avoid spurious compiler optimizations wrt to branching let mut rng = seedable_rng(); @@ -83,6 +100,14 @@ fn add_benchmark(c: &mut Criterion) { c.bench_function("equal_string_nulls_512", |b| { b.iter(|| bench_equal(&arr_a_nulls)) }); + + let arr_a = create_bool_array(512, false); + c.bench_function("equal_bool_512", |b| b.iter(|| bench_equal(&arr_a))); + + let arr_a_nulls = create_bool_array(512, true); + c.bench_function("equal_bool_nulls_512", |b| { + b.iter(|| bench_equal(&arr_a_nulls)) + }); } criterion_group!(benches, add_benchmark); diff --git a/rust/arrow/src/array/equal/boolean.rs b/rust/arrow/src/array/equal/boolean.rs index 88bd080ba53..004ae8bbfe0 100644 --- a/rust/arrow/src/array/equal/boolean.rs +++ b/rust/arrow/src/array/equal/boolean.rs @@ -15,10 +15,9 @@ // specific language governing permissions and limitations // under the License. 
+use super::utils::{equal_bits, equal_len}; use crate::array::ArrayData; -use super::utils::equal_bits; - pub(super) fn boolean_equal( lhs: &ArrayData, rhs: &ArrayData, @@ -29,21 +28,103 @@ pub(super) fn boolean_equal( let lhs_values = lhs.buffers()[0].as_slice(); let rhs_values = rhs.buffers()[0].as_slice(); - // TODO: we can do this more efficiently if all values are not-null - (0..len).all(|i| { - let lhs_pos = lhs_start + i; - let rhs_pos = rhs_start + i; - let lhs_is_null = lhs.is_null(lhs_pos); - let rhs_is_null = rhs.is_null(rhs_pos); + // Fast path: when neither side has nulls, try to compare whole bytes if both ranges share the same bit alignment. + if lhs.null_count() == 0 && rhs.null_count() == 0 { + let mut lhs_align_left = 0; + let mut lhs_prefix_bits = 0; + if lhs_start > 0 { + lhs_align_left = lhs_start / 8_usize; + if lhs_start % 8_usize > 0 { + lhs_align_left += 1_usize; + } + lhs_prefix_bits = lhs_align_left * 8_usize - lhs_start; + } + + let mut rhs_prefix_bits = 0; + if rhs_start > 0 { + let mut align = rhs_start / 8_usize; + if rhs_start % 8_usize > 0 { + align += 1; + } + rhs_prefix_bits = align * 8_usize - rhs_start; + } + + // `lhs_prefix_bits == rhs_prefix_bits` means both ranges have the same layout: + // prefix_bits | aligned_bytes | suffix_bits + if lhs_prefix_bits == rhs_prefix_bits { + // Compare the unaligned leading bits. NOTE(review): `lhs_prefix_bits` is derived from `lhs_start` only and is not capped at `len`, and for a range inside a single byte `lhs_align_right - lhs_align_left` below would underflow - confirm short ranges are handled. + if lhs_prefix_bits > 0 + && !equal_bits( + lhs_values, + rhs_values, + lhs.offset() + lhs_start, + rhs.offset() + rhs_start, + lhs_prefix_bits, + ) + { + return false; + } + + let lhs_align_right = (lhs_start + len) / 8_usize; + let align_bytes_len = lhs_align_right - lhs_align_left; + let align_bits_len = align_bytes_len * 8_usize; + let suffix_len = len - align_bits_len; + + // Compare the trailing bits. NOTE(review): this starts at `start + align_bits_len` without adding the prefix bits, so it overlaps bits already covered when a prefix exists - confirm intended. + if suffix_len > 0 + && !equal_bits( + lhs_values, + rhs_values, + lhs.offset() + lhs_start + align_bits_len, + rhs.offset() + rhs_start + align_bits_len, + suffix_len, + ) + { + return false; + } - lhs_is_null - || (lhs_is_null == rhs_is_null) - && equal_bits( + // Compare the byte-aligned middle. NOTE(review): the third and fourth arguments are both `lhs_align_left` and neither adds `offset()` - confirm this is right when `rhs_start != lhs_start` or an offset is non-zero. 
+ if align_bytes_len > 0 + && !equal_len( lhs_values, rhs_values, - lhs_pos + lhs.offset(), - rhs_pos + rhs.offset(), - 1, + lhs_align_left, + lhs_align_left, + align_bytes_len, ) - }) + { + return false; + } + + true + } else { + equal_bits( + lhs_values, + rhs_values, + lhs.offset() + lhs_start, + rhs.offset() + rhs_start, + len, + ) + } + } else { + (0..len).all(|i| { + let lhs_pos = lhs_start + i; + let rhs_pos = rhs_start + i; + let lhs_is_null = lhs.is_null(lhs_pos); + let rhs_is_null = rhs.is_null(rhs_pos); + if lhs_is_null != rhs_is_null { + return false; + } + if lhs_is_null && rhs_is_null { + return true; + } + equal_bits( + lhs_values, + rhs_values, + lhs_pos + lhs.offset(), + rhs_pos + rhs.offset(), + 1, + ) + }) + } }