diff --git a/rust/arrow/src/compute/kernels/comparison.rs b/rust/arrow/src/compute/kernels/comparison.rs
index 37b9c2ec2d9..8f81c3bcf49 100644
--- a/rust/arrow/src/compute/kernels/comparison.rs
+++ b/rust/arrow/src/compute/kernels/comparison.rs
@@ -27,9 +27,11 @@ use std::collections::HashMap;
 use std::sync::Arc;
 
 use crate::array::*;
+use crate::buffer::{Buffer, MutableBuffer};
 use crate::compute::util::combine_option_bitmap;
 use crate::datatypes::{ArrowNumericType, BooleanType, DataType};
 use crate::error::{ArrowError, Result};
+use crate::util::bit_util;
 
 /// Helper function to perform boolean lambda function on values from two arrays, this
 /// version does not attempt to use SIMD.
@@ -258,7 +260,6 @@ where
     T: ArrowNumericType,
     F: Fn(T::Simd, T::Simd) -> T::SimdMask,
 {
-    use crate::buffer::MutableBuffer;
     use std::io::Write;
     use std::mem;
 
@@ -320,7 +321,6 @@ where
     T: ArrowNumericType,
     F: Fn(T::Simd, T::Simd) -> T::SimdMask,
 {
-    use crate::buffer::MutableBuffer;
     use std::io::Write;
     use std::mem;
 
@@ -555,11 +555,154 @@ where
     compare_op_scalar!(left, right, |a, b| a >= b)
 }
 
+/// Checks, slot by slot, whether the list in `right` contains the value in `left`.
+///
+/// Slot `i` of the result is `true` only when `left[i]` and `right[i]` are both
+/// valid and a valid element of `right[i]` equals `left[i]`; otherwise it is
+/// `false`. The result has no null bitmap, so null inputs map to `false`.
+///
+/// # Errors
+///
+/// Returns `ArrowError::ComputeError` when the arrays differ in length, and
+/// forwards any error reported by `combine_option_bitmap`.
+pub fn contains<T, OffsetSize>(
+    left: &PrimitiveArray<T>,
+    right: &GenericListArray<OffsetSize>,
+) -> Result<BooleanArray>
+where
+    T: ArrowNumericType,
+    OffsetSize: OffsetSizeTrait,
+{
+    let left_len = left.len();
+    if left_len != right.len() {
+        return Err(ArrowError::ComputeError(
+            "Cannot perform comparison operation on arrays of different length"
+                .to_string(),
+        ));
+    }
+
+    let num_bytes = bit_util::ceil(left_len, 8);
+
+    let both_valid_bit_buffer =
+        match combine_option_bitmap(left.data_ref(), right.data_ref(), left_len)? {
+            Some(buff) => buff,
+            None => new_all_set_buffer(num_bytes),
+        };
+    let both_valid_bitmap = both_valid_bit_buffer.data();
+
+    let mut bool_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, false);
+    let bool_slice = bool_buf.data_mut();
+
+    // if both array slots are valid, check if list contains primitive
+    for i in 0..left_len {
+        if bit_util::get_bit(both_valid_bitmap, i) {
+            let list = right.value(i);
+            let list = list.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+
+            for j in 0..list.len() {
+                if list.is_valid(j) && (left.value(i) == list.value(j)) {
+                    // first match settles this slot; stop scanning the list
+                    bit_util::set_bit(bool_slice, i);
+                    break;
+                }
+            }
+        }
+    }
+
+    let data = ArrayData::new(
+        DataType::Boolean,
+        left.len(),
+        None,
+        None,
+        0,
+        vec![bool_buf.freeze()],
+        vec![],
+    );
+    Ok(PrimitiveArray::<BooleanType>::from(Arc::new(data)))
+}
+
+/// Checks, slot by slot, whether the list in `right` contains the string in `left`.
+///
+/// Slot `i` of the result is `true` only when `left[i]` and `right[i]` are both
+/// valid and a valid element of `right[i]` equals `left[i]`; otherwise it is
+/// `false`. The result has no null bitmap, so null inputs map to `false`.
+///
+/// # Errors
+///
+/// Returns `ArrowError::ComputeError` when the arrays differ in length, and
+/// forwards any error reported by `combine_option_bitmap`.
+pub fn contains_utf8<OffsetSize>(
+    left: &GenericStringArray<OffsetSize>,
+    right: &ListArray,
+) -> Result<BooleanArray>
+where
+    OffsetSize: OffsetSizeTrait,
+{
+    let left_len = left.len();
+    if left_len != right.len() {
+        return Err(ArrowError::ComputeError(
+            "Cannot perform comparison operation on arrays of different length"
+                .to_string(),
+        ));
+    }
+
+    let num_bytes = bit_util::ceil(left_len, 8);
+
+    let both_valid_bit_buffer =
+        match combine_option_bitmap(left.data_ref(), right.data_ref(), left_len)? {
+            Some(buff) => buff,
+            None => new_all_set_buffer(num_bytes),
+        };
+    let both_valid_bitmap = both_valid_bit_buffer.data();
+
+    let mut bool_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, false);
+    let bool_slice = bool_buf.data_mut();
+
+    for i in 0..left_len {
+        // only scan the list when both slots are valid; a null on either side is false
+        if bit_util::get_bit(both_valid_bitmap, i) {
+            let list = right.value(i);
+            let list = list
+                .as_any()
+                .downcast_ref::<GenericStringArray<OffsetSize>>()
+                .unwrap();
+
+            for j in 0..list.len() {
+                if list.is_valid(j) && (left.value(i) == list.value(j)) {
+                    // first match settles this slot; stop scanning the list
+                    bit_util::set_bit(bool_slice, i);
+                    break;
+                }
+            }
+        }
+    }
+
+    let data = ArrayData::new(
+        DataType::Boolean,
+        left.len(),
+        None,
+        None,
+        0,
+        vec![bool_buf.freeze()],
+        vec![],
+    );
+    Ok(PrimitiveArray::<BooleanType>::from(Arc::new(data)))
+}
+
+/// Creates a `len`-byte buffer with every bit set, i.e. every slot marked valid.
+#[inline]
+fn new_all_set_buffer(len: usize) -> Buffer {
+    let buffer = MutableBuffer::new(len);
+    let buffer = buffer.with_bitset(len, true);
+
+    buffer.freeze()
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
     use crate::array::Int32Array;
-    use crate::datatypes::Int8Type;
+    use crate::datatypes::{Int8Type, ToByteSlice};
 
     #[test]
     fn test_primitive_array_eq() {
@@ -807,6 +950,111 @@ mod tests {
         );
     }
 
+    // Expected behaviour:
+    // contains(1, [1, 2, null]) = true
+    // contains(3, [1, 2, null]) = false
+    // contains(null, [1, 2, null]) = false
+    // contains(null, null) = false
+    #[test]
+    fn test_contains() {
+        let value_data = Int32Array::from(vec![
+            Some(0),
+            Some(1),
+            Some(2),
+            Some(3),
+            Some(4),
+            Some(5),
+            Some(6),
+            None,
+            Some(7),
+        ])
+        .data();
+        let value_offsets = Buffer::from(&[0i64, 3, 6, 6, 9].to_byte_slice());
+        let list_data_type = DataType::LargeList(Box::new(DataType::Int32));
+        let list_data = ArrayData::builder(list_data_type)
+            .len(4)
+            .add_buffer(value_offsets)
+            .null_count(1)
+            .add_child_data(value_data)
+            .null_bit_buffer(Buffer::from([0b00001011]))
+            .build();
+
+        //  [[0, 1, 2], [3, 4, 5], null, [6, null, 7]]
+        let list_array = LargeListArray::from(list_data);
+
+        let nulls = Int32Array::from(vec![None, None, None, None]);
+        let nulls_result = contains(&nulls, &list_array).unwrap();
+        assert_eq!(
+            nulls_result
+                .as_any()
+                .downcast_ref::<BooleanArray>()
+                .unwrap(),
+            &BooleanArray::from(vec![false, false, false, false]),
+        );
+
+        let values = Int32Array::from(vec![Some(0), Some(0), Some(0), Some(0)]);
+        let values_result = contains(&values, &list_array).unwrap();
+        assert_eq!(
+            values_result
+                .as_any()
+                .downcast_ref::<BooleanArray>()
+                .unwrap(),
+            &BooleanArray::from(vec![true, false, false, false]),
+        );
+    }
+
+    // Expected behaviour:
+    // contains("ab", ["ab", "cd", null]) = true
+    // contains("ef", ["ab", "cd", null]) = false
+    // contains(null, ["ab", "cd", null]) = false
+    // contains(null, null) = false
+    #[test]
+    fn test_contains_utf8() {
+        let values_builder = StringBuilder::new(10);
+        let mut builder = ListBuilder::new(values_builder);
+
+        builder.values().append_value("Lorem").unwrap();
+        builder.values().append_value("ipsum").unwrap();
+        builder.values().append_null().unwrap();
+        builder.append(true).unwrap();
+        builder.values().append_value("sit").unwrap();
+        builder.values().append_value("amet").unwrap();
+        builder.values().append_value("Lorem").unwrap();
+        builder.append(true).unwrap();
+        builder.append(false).unwrap();
+        builder.values().append_value("ipsum").unwrap();
+        builder.append(true).unwrap();
+
+        //  [["Lorem", "ipsum", null], ["sit", "amet", "Lorem"], null, ["ipsum"]]
+        // value_offsets = [0, 3, 6, 6, 7]
+        let list_array = builder.finish();
+
+        let nulls = StringArray::from(vec![None, None, None, None]);
+        let nulls_result = contains_utf8(&nulls, &list_array).unwrap();
+        assert_eq!(
+            nulls_result
+                .as_any()
+                .downcast_ref::<BooleanArray>()
+                .unwrap(),
+            &BooleanArray::from(vec![false, false, false, false]),
+        );
+
+        let values = StringArray::from(vec![
+            Some("Lorem"),
+            Some("Lorem"),
+            Some("Lorem"),
+            Some("Lorem"),
+        ]);
+        let values_result = contains_utf8(&values, &list_array).unwrap();
+        assert_eq!(
+            values_result
+                .as_any()
+                .downcast_ref::<BooleanArray>()
+                .unwrap(),
+            &BooleanArray::from(vec![true, true, false, false]),
+        );
+    }
+
     macro_rules! test_utf8 {
         ($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => {
             #[test]
diff --git a/rust/arrow/src/compute/util.rs b/rust/arrow/src/compute/util.rs
index 85b6296ecd7..8ba35fbf8a7 100644
--- a/rust/arrow/src/compute/util.rs
+++ b/rust/arrow/src/compute/util.rs
@@ -20,7 +20,7 @@
 use crate::array::*;
 #[cfg(feature = "simd")]
 use crate::bitmap::Bitmap;
-use crate::buffer::{buffer_bin_and, Buffer};
+use crate::buffer::{buffer_bin_and, buffer_bin_or, Buffer};
 #[cfg(feature = "simd")]
 use crate::datatypes::*;
 use crate::error::{ArrowError, Result};
@@ -71,6 +71,55 @@ pub(super) fn combine_option_bitmap(
     }
 }
 
+/// Combines the null bitmaps of two arrays using a bitwise `or` operation.
+///
+/// Returns `None` when neither array has a null bitmap, the only existing bitmap
+/// (sliced at its byte offset) when just one array has one, and the `or` of both
+/// otherwise. Fails with `ArrowError::ComputeError` when an array that has a null
+/// bitmap starts at a bit offset that is not a multiple of 8.
+///
+/// NOTE(review): a missing bitmap is treated as all-valid elsewhere, so a true
+/// `or` with `None` would be all-valid rather than the other bitmap; confirm intent.
+///
+/// This function is useful when implementing operations on higher level arrays.
+pub(super) fn compare_option_bitmap(
+    left_data: &ArrayDataRef,
+    right_data: &ArrayDataRef,
+    len_in_bits: usize,
+) -> Result<Option<Buffer>> {
+    let left_offset_in_bits = left_data.offset();
+    let right_offset_in_bits = right_data.offset();
+
+    let left = left_data.null_buffer();
+    let right = right_data.null_buffer();
+
+    if (left.is_some() && left_offset_in_bits % 8 != 0)
+        || (right.is_some() && right_offset_in_bits % 8 != 0)
+    {
+        return Err(ArrowError::ComputeError(
+            "Cannot compare option bitmaps that are not byte-aligned.".to_string(),
+        ));
+    }
+
+    let left_offset = left_offset_in_bits / 8;
+    let right_offset = right_offset_in_bits / 8;
+
+    match left {
+        None => match right {
+            None => Ok(None),
+            Some(r) => Ok(Some(r.slice(right_offset))),
+        },
+        Some(l) => match right {
+            None => Ok(Some(l.slice(left_offset))),
+
+            Some(r) => {
+                let len = ceil(len_in_bits, 8);
+                Ok(Some(buffer_bin_or(&l, left_offset, &r, right_offset, len)))
+            }
+        },
+    }
+}
+
 /// Takes/filters a list array's inner data using the offsets of the list array.
 ///
 /// Where a list array has indices `[0,2,5,10]`, taking indices of `[2,0]` returns
@@ -208,6 +257,8 @@ mod tests {
         let none_bitmap = make_data_with_null_bit_buffer(8, 0, None);
         let some_bitmap =
             make_data_with_null_bit_buffer(8, 0, Some(Buffer::from([0b01001010])));
+        let inverse_bitmap =
+            make_data_with_null_bit_buffer(8, 0, Some(Buffer::from([0b10110101])));
         assert_eq!(
             None,
             combine_option_bitmap(&none_bitmap, &none_bitmap, 8).unwrap()
@@ -224,6 +275,39 @@ mod tests {
             Some(Buffer::from([0b01001010])),
             combine_option_bitmap(&some_bitmap, &some_bitmap, 8,).unwrap()
         );
+        assert_eq!(
+            Some(Buffer::from([0b0])),
+            combine_option_bitmap(&some_bitmap, &inverse_bitmap, 8,).unwrap()
+        );
+    }
+
+    #[test]
+    fn test_compare_option_bitmap() {
+        let none_bitmap = make_data_with_null_bit_buffer(8, 0, None);
+        let some_bitmap =
+            make_data_with_null_bit_buffer(8, 0, Some(Buffer::from([0b01001010])));
+        let inverse_bitmap =
+            make_data_with_null_bit_buffer(8, 0, Some(Buffer::from([0b10110101])));
+        assert_eq!(
+            None,
+            compare_option_bitmap(&none_bitmap, &none_bitmap, 8).unwrap()
+        );
+        assert_eq!(
+            Some(Buffer::from([0b01001010])),
+            compare_option_bitmap(&some_bitmap, &none_bitmap, 8).unwrap()
+        );
+        assert_eq!(
+            Some(Buffer::from([0b01001010])),
+            compare_option_bitmap(&none_bitmap, &some_bitmap, 8,).unwrap()
+        );
+        assert_eq!(
+            Some(Buffer::from([0b01001010])),
+            compare_option_bitmap(&some_bitmap, &some_bitmap, 8,).unwrap()
+        );
+        assert_eq!(
+            Some(Buffer::from([0b11111111])),
+            compare_option_bitmap(&some_bitmap, &inverse_bitmap, 8,).unwrap()
+        );
     }
 
     #[test]