diff --git a/arrow-select/src/zip.rs b/arrow-select/src/zip.rs index e45b817dc6e8..6be034fca23d 100644 --- a/arrow-select/src/zip.rs +++ b/arrow-select/src/zip.rs @@ -19,14 +19,17 @@ use crate::filter::{SlicesIterator, prep_null_mask_filter}; use arrow_array::cast::AsArray; -use arrow_array::types::{BinaryType, ByteArrayType, LargeBinaryType, LargeUtf8Type, Utf8Type}; +use arrow_array::types::{ + BinaryType, BinaryViewType, ByteArrayType, ByteViewType, LargeBinaryType, LargeUtf8Type, + StringViewType, Utf8Type, +}; use arrow_array::*; use arrow_buffer::{ BooleanBuffer, Buffer, MutableBuffer, NullBuffer, OffsetBuffer, OffsetBufferBuilder, - ScalarBuffer, + ScalarBuffer, ToByteSlice, }; -use arrow_data::ArrayData; use arrow_data::transform::MutableArrayData; +use arrow_data::{ArrayData, ByteView}; use arrow_schema::{ArrowError, DataType}; use std::fmt::{Debug, Formatter}; use std::hash::Hash; @@ -284,7 +287,12 @@ impl ScalarZipper { DataType::LargeBinary => { Arc::new(BytesScalarImpl::::new(truthy, falsy)) as Arc }, - // TODO: Handle Utf8View https://github.com/apache/arrow-rs/issues/8724 + DataType::Utf8View => { + Arc::new(ByteViewScalarImpl::::new(truthy, falsy)) as Arc + }, + DataType::BinaryView => { + Arc::new(ByteViewScalarImpl::::new(truthy, falsy)) as Arc + }, _ => { Arc::new(FallbackImpl::new(truthy, falsy)) as Arc }, @@ -657,6 +665,177 @@ fn maybe_prep_null_mask_filter(predicate: &BooleanArray) -> BooleanBuffer { } } +struct ByteViewScalarImpl { + truthy_view: Option, + truthy_buffers: Vec, + falsy_view: Option, + falsy_buffers: Vec, + phantom: PhantomData, +} + +impl ByteViewScalarImpl { + fn new(truthy: &dyn Array, falsy: &dyn Array) -> Self { + let (truthy_view, truthy_buffers) = Self::get_value_from_scalar(truthy); + let (falsy_view, falsy_buffers) = Self::get_value_from_scalar(falsy); + Self { + truthy_view, + truthy_buffers, + falsy_view, + falsy_buffers, + phantom: PhantomData, + } + } + + fn get_value_from_scalar(scalar: &dyn Array) -> (Option, Vec) { + if scalar.is_null(0) { + (None, vec![]) + } else { + let (views, buffers, _) = scalar.as_byte_view::().clone().into_parts(); + (views.first().copied(), buffers) + } + } + + fn get_views_for_single_non_nullable( + predicate: BooleanBuffer, + value: u128, + buffers: Vec, + ) -> (ScalarBuffer, Vec, Option) { + let number_of_true = predicate.count_set_bits(); + let number_of_values = predicate.len(); + + // Fast path for all nulls + if number_of_true == 0 { + // All values are null + return ( + vec![0; number_of_values].into(), + vec![], + Some(NullBuffer::new_null(number_of_values)), + ); + } + let bytes = vec![value; number_of_values]; + + // If value is true and we want to handle the TRUTHY case, the null buffer will have 1 (meaning not null) + // If value is false and we want to handle the FALSY case, the null buffer will have 0 (meaning null) + let nulls = NullBuffer::new(predicate); + (bytes.into(), buffers, Some(nulls)) + } + + fn get_views_for_non_nullable( + predicate: BooleanBuffer, + result_len: usize, + truthy_view: u128, + truthy_buffers: Vec, + falsy_view: u128, + falsy_buffers: Vec, + ) -> (ScalarBuffer, Vec, Option) { + let true_count = predicate.count_set_bits(); + match true_count { + 0 => { + // all values are falsy + (vec![falsy_view; result_len].into(), falsy_buffers, None) + } + n if n == predicate.len() => { + // all values are truthy + (vec![truthy_view; result_len].into(), truthy_buffers, None) + } + _ => { + let true_count = predicate.count_set_bits(); + let mut buffers: Vec = truthy_buffers.to_vec(); + + // If the falsy buffers are empty, we can use the falsy view as it is, because the value + // is completely inlined. Otherwise, we have non-inlined values in the buffer, and we need + // to recalculate the falsy view + let view_falsy = if falsy_buffers.is_empty() { + falsy_view + } else { + let byte_view_falsy = ByteView::from(falsy_view); + let new_index_falsy_buffers = + buffers.len() as u32 + byte_view_falsy.buffer_index; + buffers.extend(falsy_buffers); + let byte_view_falsy = + byte_view_falsy.with_buffer_index(new_index_falsy_buffers); + byte_view_falsy.as_u128() + }; + + let total_number_of_bytes = true_count * 16 + (predicate.len() - true_count) * 16; + let mut mutable = MutableBuffer::new(total_number_of_bytes); + let mut filled = 0; + + SlicesIterator::from(&predicate).for_each(|(start, end)| { + if start > filled { + let false_repeat_count = start - filled; + mutable + .repeat_slice_n_times(view_falsy.to_byte_slice(), false_repeat_count); + } + let true_repeat_count = end - start; + mutable.repeat_slice_n_times(truthy_view.to_byte_slice(), true_repeat_count); + filled = end; + }); + + if filled < predicate.len() { + let false_repeat_count = predicate.len() - filled; + mutable.repeat_slice_n_times(view_falsy.to_byte_slice(), false_repeat_count); + } + + let bytes = Buffer::from(mutable); + (bytes.into(), buffers, None) + } + } + } +} + +impl Debug for ByteViewScalarImpl { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ByteViewScalarImpl") + .field("truthy", &self.truthy_view) + .field("falsy", &self.falsy_view) + .finish() + } +} + +impl ZipImpl for ByteViewScalarImpl { + fn create_output(&self, predicate: &BooleanArray) -> Result { + let result_len = predicate.len(); + // Nulls are treated as false + let predicate = maybe_prep_null_mask_filter(predicate); + + let (views, buffers, nulls) = match (self.truthy_view, self.falsy_view) { + (Some(truthy), Some(falsy)) => Self::get_views_for_non_nullable( + predicate, + result_len, + truthy, + self.truthy_buffers.clone(), + falsy, + self.falsy_buffers.clone(), + ), + (Some(truthy), None) => Self::get_views_for_single_non_nullable( + predicate, + truthy, + self.truthy_buffers.clone(), + ), + (None, Some(falsy)) => { + let predicate = predicate.not(); + Self::get_views_for_single_non_nullable( + predicate, + falsy, + self.falsy_buffers.clone(), + ) + } + (None, None) => { + // All values are null + ( + vec![0; result_len].into(), + vec![], + Some(NullBuffer::new_null(result_len)), + ) + } + }; + + let result = unsafe { GenericByteViewArray::::new_unchecked(views, buffers, nulls) }; + Ok(Arc::new(result)) + } +} + #[cfg(test)] mod test { use super::*; @@ -1222,4 +1401,158 @@ mod test { ]); assert_eq!(actual, &expected); } + + #[test] + fn test_zip_kernel_scalar_strings_array_view() { + let scalar_truthy = Scalar::new(StringViewArray::from(vec!["hello"])); + let scalar_falsy = Scalar::new(StringViewArray::from(vec!["world"])); + + let mask = BooleanArray::from(vec![true, false, true, false]); + let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap(); + let actual = out.as_string_view(); + let expected = StringViewArray::from(vec![ + Some("hello"), + Some("world"), + Some("hello"), + Some("world"), + ]); + assert_eq!(actual, &expected); + } + + #[test] + fn test_zip_kernel_scalar_binary_array_view() { + let scalar_truthy = Scalar::new(BinaryViewArray::from_iter_values(vec![b"hello"])); + let scalar_falsy = Scalar::new(BinaryViewArray::from_iter_values(vec![b"world"])); + + let mask = BooleanArray::from(vec![true, false]); + let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap(); + let actual = out.as_byte_view(); + let expected = BinaryViewArray::from_iter_values(vec![b"hello", b"world"]); + assert_eq!(actual, &expected); + } + + #[test] + fn test_zip_kernel_scalar_strings_array_view_with_nulls() { + let scalar_truthy = Scalar::new(StringViewArray::from_iter_values(["hello"])); + let scalar_falsy = Scalar::new(StringViewArray::new_null(1)); + + let mask = BooleanArray::from(vec![true, true, false, false, true]); + let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap(); + let actual = out.as_any().downcast_ref::().unwrap(); + let expected = StringViewArray::from_iter(vec![ + Some("hello"), + Some("hello"), + None, + None, + Some("hello"), + ]); + assert_eq!(actual, &expected); + } + + #[test] + fn test_zip_kernel_scalar_strings_array_view_all_true_null() { + let scalar_truthy = Scalar::new(StringViewArray::new_null(1)); + let scalar_falsy = Scalar::new(StringViewArray::new_null(1)); + let mask = BooleanArray::from(vec![true, true]); + let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap(); + let actual = out.as_any().downcast_ref::().unwrap(); + let expected = StringViewArray::from_iter(vec![None::, None]); + assert_eq!(actual, &expected); + } + + #[test] + fn test_zip_kernel_scalar_strings_array_view_all_false_null() { + let scalar_truthy = Scalar::new(StringViewArray::new_null(1)); + let scalar_falsy = Scalar::new(StringViewArray::new_null(1)); + let mask = BooleanArray::from(vec![false, false]); + let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap(); + let actual = out.as_any().downcast_ref::().unwrap(); + let expected = StringViewArray::from_iter(vec![None::, None]); + assert_eq!(actual, &expected); + } + + #[test] + fn test_zip_kernel_scalar_string_array_view_all_true() { + let scalar_truthy = Scalar::new(StringViewArray::from(vec!["hello"])); + let scalar_falsy = Scalar::new(StringViewArray::from(vec!["world"])); + + let mask = BooleanArray::from(vec![true, true]); + let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap(); + let actual = out.as_string_view(); + let expected = StringViewArray::from(vec![Some("hello"), Some("hello")]); + assert_eq!(actual, &expected); + } + + #[test] + fn test_zip_kernel_scalar_string_array_view_all_false() { + let scalar_truthy = Scalar::new(StringViewArray::from(vec!["hello"])); + let scalar_falsy = Scalar::new(StringViewArray::from(vec!["world"])); + + let mask = BooleanArray::from(vec![false, false]); + let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap(); + let actual = out.as_string_view(); + let expected = StringViewArray::from(vec![Some("world"), Some("world")]); + assert_eq!(actual, &expected); + } + + #[test] + fn test_zip_kernel_scalar_strings_large_strings() { + let scalar_truthy = Scalar::new(StringViewArray::from(vec!["longer than 12 bytes"])); + let scalar_falsy = Scalar::new(StringViewArray::from(vec!["another longer than 12 bytes"])); + + let mask = BooleanArray::from(vec![true, false]); + let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap(); + let actual = out.as_string_view(); + let expected = StringViewArray::from(vec![ + Some("longer than 12 bytes"), + Some("another longer than 12 bytes"), + ]); + assert_eq!(actual, &expected); + } + + #[test] + fn test_zip_kernel_scalar_strings_array_view_large_short_strings() { + let scalar_truthy = Scalar::new(StringViewArray::from(vec!["hello"])); + let scalar_falsy = Scalar::new(StringViewArray::from(vec!["longer than 12 bytes"])); + + let mask = BooleanArray::from(vec![true, false, true, false]); + let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap(); + let actual = out.as_string_view(); + let expected = StringViewArray::from(vec![ + Some("hello"), + Some("longer than 12 bytes"), + Some("hello"), + Some("longer than 12 bytes"), + ]); + assert_eq!(actual, &expected); + } + #[test] + fn test_zip_kernel_scalar_strings_array_view_large_all_true() { + let scalar_truthy = Scalar::new(StringViewArray::from(vec!["longer than 12 bytes"])); + let scalar_falsy = Scalar::new(StringViewArray::from(vec!["another longer than 12 bytes"])); + + let mask = BooleanArray::from(vec![true, true]); + let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap(); + let actual = out.as_string_view(); + let expected = StringViewArray::from(vec![ + Some("longer than 12 bytes"), + Some("longer than 12 bytes"), + ]); + assert_eq!(actual, &expected); + } + + #[test] + fn test_zip_kernel_scalar_strings_array_view_large_all_false() { + let scalar_truthy = Scalar::new(StringViewArray::from(vec!["longer than 12 bytes"])); + let scalar_falsy = Scalar::new(StringViewArray::from(vec!["another longer than 12 bytes"])); + + let mask = BooleanArray::from(vec![false, false]); + let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap(); + let actual = out.as_string_view(); + let expected = StringViewArray::from(vec![ + Some("another longer than 12 bytes"), + Some("another longer than 12 bytes"), + ]); + assert_eq!(actual, &expected); + } }