diff --git a/benches/comparison_kernels.rs b/benches/comparison_kernels.rs index bc938d43fe2..f42da77276d 100644 --- a/benches/comparison_kernels.rs +++ b/benches/comparison_kernels.rs @@ -19,28 +19,38 @@ extern crate criterion; use criterion::Criterion; -use arrow2::{compute::comparison::*, datatypes::DataType, types::NativeType}; +use arrow2::array::*; use arrow2::util::bench_util::*; -use arrow2::{array::*}; +use arrow2::{compute::comparison::*, datatypes::DataType, types::NativeType}; fn bench_eq(arr_a: &PrimitiveArray, arr_b: &PrimitiveArray) where T: NativeType, { - compare(criterion::black_box(arr_a), criterion::black_box(arr_b), Operator::Eq).unwrap(); + compare( + criterion::black_box(arr_a), + criterion::black_box(arr_b), + Operator::Eq, + ) + .unwrap(); } fn bench_eq_scalar(arr_a: &PrimitiveArray, value_b: T) where T: NativeType + std::cmp::PartialOrd, { - primtive_compare_scalar(criterion::black_box(arr_a), criterion::black_box(value_b), Operator::Eq).unwrap(); + primtive_compare_scalar( + criterion::black_box(arr_a), + criterion::black_box(value_b), + Operator::Eq, + ) + .unwrap(); } fn add_benchmark(c: &mut Criterion) { let size = 65536; let arr_a = create_primitive_array::(size, DataType::Float32, 0.0); - let arr_b = create_primitive_array::(size, DataType::Float32,0.0); + let arr_b = create_primitive_array::(size, DataType::Float32, 0.0); c.bench_function("eq Float32", |b| b.iter(|| bench_eq(&arr_a, &arr_b))); c.bench_function("eq scalar Float32", |b| { diff --git a/src/compute/comparison/primitive.rs b/src/compute/comparison/primitive.rs index 1923a3cafac..471445f0df9 100644 --- a/src/compute/comparison/primitive.rs +++ b/src/compute/comparison/primitive.rs @@ -15,8 +15,12 @@ // specific language governing permissions and limitations // under the License. -use crate::error::{ArrowError, Result}; -use crate::{array::*, bitmap::Bitmap, types::NativeType}; +use crate::{array::*, types::NativeType}; +use crate::{ + bits, + buffer::MutableBuffer, + error::{ArrowError, Result}, +}; use super::{super::utils::combine_validities, Operator}; @@ -35,14 +39,44 @@ where let validity = combine_validities(lhs.validity(), rhs.validity()); - let values = lhs - .values() - .iter() - .zip(rhs.values()) - .map(|(lhs, rhs)| op(*lhs, *rhs)); - let values = unsafe { Bitmap::from_trusted_len_iter(values) }; - - Ok(BooleanArray::from_data(values, validity)) + let mut values = MutableBuffer::from_len_zeroed((lhs.len() + 7) / 8); + + let lhs_chunks_iter = lhs.values().chunks_exact(8); + let lhs_remainder = lhs_chunks_iter.remainder(); + let rhs_chunks_iter = rhs.values().chunks_exact(8); + let rhs_remainder = rhs_chunks_iter.remainder(); + + let chunks = lhs.len() / 8; + + values[..chunks] + .iter_mut() + .zip(lhs_chunks_iter) + .zip(rhs_chunks_iter) + .for_each(|((byte, lhs), rhs)| { + (0..8).for_each(|i| { + if op(lhs[i], rhs[i]) { + *byte = bits::set(*byte, i) + } + }); + }); + + if !lhs_remainder.is_empty() { + let last = &mut values[chunks]; + lhs_remainder + .iter() + .zip(rhs_remainder.iter()) + .enumerate() + .for_each(|(i, (lhs, rhs))| { + if op(*lhs, *rhs) { + *last = bits::set(*last, i) + } + }); + }; + + Ok(BooleanArray::from_data( + (values, lhs.len()).into(), + validity, + )) } /// Evaluate `op(left, right)` for [`PrimitiveArray`] and scalar using @@ -54,10 +88,36 @@ where { let validity = lhs.validity().clone(); - let values = lhs.values().iter().map(|lhs| op(*lhs, rhs)); - let values = unsafe { Bitmap::from_trusted_len_iter(values) }; - - Ok(BooleanArray::from_data(values, validity)) + let mut values = MutableBuffer::from_len_zeroed((lhs.len() + 7) / 8); + + let lhs_chunks_iter = lhs.values().chunks_exact(8); + let lhs_remainder = lhs_chunks_iter.remainder(); + let chunks = lhs.len() / 8; + + values[..chunks] + .iter_mut() + .zip(lhs_chunks_iter) + .for_each(|(byte, lhs)| { + (0..8).for_each(|i| { + if op(lhs[i], rhs) { + *byte = bits::set(*byte, i) + } + }); + }); + + if !lhs_remainder.is_empty() { + let last = &mut values[chunks]; + lhs_remainder.iter().enumerate().for_each(|(i, lhs)| { + if op(*lhs, rhs) { + *last = bits::set(*last, i) + } + }); + }; + + Ok(BooleanArray::from_data( + (values, lhs.len()).into(), + validity, + )) } /// Perform `lhs == rhs` operation on two arrays.