Skip to content
341 changes: 337 additions & 4 deletions arrow-select/src/zip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,17 @@

use crate::filter::{SlicesIterator, prep_null_mask_filter};
use arrow_array::cast::AsArray;
use arrow_array::types::{BinaryType, ByteArrayType, LargeBinaryType, LargeUtf8Type, Utf8Type};
use arrow_array::types::{
BinaryType, BinaryViewType, ByteArrayType, ByteViewType, LargeBinaryType, LargeUtf8Type,
StringViewType, Utf8Type,
};
use arrow_array::*;
use arrow_buffer::{
BooleanBuffer, Buffer, MutableBuffer, NullBuffer, OffsetBuffer, OffsetBufferBuilder,
ScalarBuffer,
ScalarBuffer, ToByteSlice,
};
use arrow_data::ArrayData;
use arrow_data::transform::MutableArrayData;
use arrow_data::{ArrayData, ByteView};
use arrow_schema::{ArrowError, DataType};
use std::fmt::{Debug, Formatter};
use std::hash::Hash;
Expand Down Expand Up @@ -284,7 +287,12 @@ impl ScalarZipper {
DataType::LargeBinary => {
Arc::new(BytesScalarImpl::<LargeBinaryType>::new(truthy, falsy)) as Arc<dyn ZipImpl>
},
// TODO: Handle Utf8View https://github.com/apache/arrow-rs/issues/8724
DataType::Utf8View => {
Arc::new(ByteViewScalarImpl::<StringViewType>::new(truthy, falsy)) as Arc<dyn ZipImpl>
},
DataType::BinaryView => {
Arc::new(ByteViewScalarImpl::<BinaryViewType>::new(truthy, falsy)) as Arc<dyn ZipImpl>
},
_ => {
Arc::new(FallbackImpl::new(truthy, falsy)) as Arc<dyn ZipImpl>
},
Expand Down Expand Up @@ -657,6 +665,177 @@ fn maybe_prep_null_mask_filter(predicate: &BooleanArray) -> BooleanBuffer {
}
}

struct ByteViewScalarImpl<T: ByteViewType> {
truthy_view: Option<u128>,
truthy_buffers: Vec<Buffer>,
falsy_view: Option<u128>,
falsy_buffers: Vec<Buffer>,
phantom: PhantomData<T>,
}

impl<T: ByteViewType> ByteViewScalarImpl<T> {
fn new(truthy: &dyn Array, falsy: &dyn Array) -> Self {
let (truthy_view, truthy_buffers) = Self::get_value_from_scalar(truthy);
let (falsy_view, falsy_buffers) = Self::get_value_from_scalar(falsy);
Self {
truthy_view,
truthy_buffers,
falsy_view,
falsy_buffers,
phantom: PhantomData,
}
}

fn get_value_from_scalar(scalar: &dyn Array) -> (Option<u128>, Vec<Buffer>) {
if scalar.is_null(0) {
(None, vec![])
} else {
let (views, buffers, _) = scalar.as_byte_view::<T>().clone().into_parts();
(views.first().copied(), buffers)
}
}

fn get_views_for_single_non_nullable(
predicate: BooleanBuffer,
value: u128,
buffers: Vec<Buffer>,
) -> (ScalarBuffer<u128>, Vec<Buffer>, Option<NullBuffer>) {
let number_of_true = predicate.count_set_bits();
let number_of_values = predicate.len();

// Fast path for all nulls
if number_of_true == 0 {
// All values are null
return (
vec![0; number_of_values].into(),
vec![],
Some(NullBuffer::new_null(number_of_values)),
);
}
let bytes = vec![value; number_of_values];

// If value is true and we want to handle the TRUTHY case, the null buffer will have 1 (meaning not null)
// If value is false and we want to handle the FALSY case, the null buffer will have 0 (meaning null)
let nulls = NullBuffer::new(predicate);
(bytes.into(), buffers, Some(nulls))
}

fn get_views_for_non_nullable(
predicate: BooleanBuffer,
result_len: usize,
truthy_view: u128,
truthy_buffers: Vec<Buffer>,
falsy_view: u128,
falsy_buffers: Vec<Buffer>,
) -> (ScalarBuffer<u128>, Vec<Buffer>, Option<NullBuffer>) {
let true_count = predicate.count_set_bits();
match true_count {
0 => {
// all values are falsy
(vec![falsy_view; result_len].into(), falsy_buffers, None)
}
n if n == predicate.len() => {
// all values are truthy
(vec![truthy_view; result_len].into(), truthy_buffers, None)
}
_ => {
let true_count = predicate.count_set_bits();
let mut buffers: Vec<Buffer> = truthy_buffers.to_vec();

// If the falsy buffers are empty, we can use the falsy view as it is, because the value
// is completely inlined. Otherwise, we have non-inlined values in the buffer, and we need
// to recalculate the falsy view
let view_falsy = if falsy_buffers.is_empty() {
falsy_view
} else {
let byte_view_falsy = ByteView::from(falsy_view);
let new_index_falsy_buffers =
buffers.len() as u32 + byte_view_falsy.buffer_index;
buffers.extend(falsy_buffers);
let byte_view_falsy =
byte_view_falsy.with_buffer_index(new_index_falsy_buffers);
byte_view_falsy.as_u128()
};

let total_number_of_bytes = true_count * 16 + (predicate.len() - true_count) * 16;
let mut mutable = MutableBuffer::new(total_number_of_bytes);
let mut filled = 0;

SlicesIterator::from(&predicate).for_each(|(start, end)| {
if start > filled {
let false_repeat_count = start - filled;
mutable
.repeat_slice_n_times(view_falsy.to_byte_slice(), false_repeat_count);
}
let true_repeat_count = end - start;
mutable.repeat_slice_n_times(truthy_view.to_byte_slice(), true_repeat_count);
filled = end;
});

if filled < predicate.len() {
let false_repeat_count = predicate.len() - filled;
mutable.repeat_slice_n_times(view_falsy.to_byte_slice(), false_repeat_count);
}

let bytes = Buffer::from(mutable);
(bytes.into(), buffers, None)
}
}
}
}

impl<T: ByteViewType> Debug for ByteViewScalarImpl<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ByteViewScalarImpl")
.field("truthy", &self.truthy_view)
.field("falsy", &self.falsy_view)
.finish()
}
}

impl<T: ByteViewType> ZipImpl for ByteViewScalarImpl<T> {
fn create_output(&self, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> {
let result_len = predicate.len();
// Nulls are treated as false
let predicate = maybe_prep_null_mask_filter(predicate);

let (views, buffers, nulls) = match (self.truthy_view, self.falsy_view) {
(Some(truthy), Some(falsy)) => Self::get_views_for_non_nullable(
predicate,
result_len,
truthy,
self.truthy_buffers.clone(),
falsy,
self.falsy_buffers.clone(),
),
(Some(truthy), None) => Self::get_views_for_single_non_nullable(
predicate,
truthy,
self.truthy_buffers.clone(),
),
(None, Some(falsy)) => {
let predicate = predicate.not();
Self::get_views_for_single_non_nullable(
predicate,
falsy,
self.falsy_buffers.clone(),
)
}
(None, None) => {
// All values are null
(
vec![0; result_len].into(),
vec![],
Some(NullBuffer::new_null(result_len)),
)
}
};

let result = unsafe { GenericByteViewArray::<T>::new_unchecked(views, buffers, nulls) };
Ok(Arc::new(result))
}
}

#[cfg(test)]
mod test {
use super::*;
Expand Down Expand Up @@ -1222,4 +1401,158 @@ mod test {
]);
assert_eq!(actual, &expected);
}

#[test]
fn test_zip_kernel_scalar_strings_array_view() {
let scalar_truthy = Scalar::new(StringViewArray::from(vec!["hello"]));
let scalar_falsy = Scalar::new(StringViewArray::from(vec!["world"]));

let mask = BooleanArray::from(vec![true, false, true, false]);
let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
let actual = out.as_string_view();
let expected = StringViewArray::from(vec![
Some("hello"),
Some("world"),
Some("hello"),
Some("world"),
]);
assert_eq!(actual, &expected);
}

#[test]
fn test_zip_kernel_scalar_binary_array_view() {
let scalar_truthy = Scalar::new(BinaryViewArray::from_iter_values(vec![b"hello"]));
let scalar_falsy = Scalar::new(BinaryViewArray::from_iter_values(vec![b"world"]));

let mask = BooleanArray::from(vec![true, false]);
let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
let actual = out.as_byte_view();
let expected = BinaryViewArray::from_iter_values(vec![b"hello", b"world"]);
assert_eq!(actual, &expected);
}

#[test]
fn test_zip_kernel_scalar_strings_array_view_with_nulls() {
let scalar_truthy = Scalar::new(StringViewArray::from_iter_values(["hello"]));
let scalar_falsy = Scalar::new(StringViewArray::new_null(1));

let mask = BooleanArray::from(vec![true, true, false, false, true]);
let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
let actual = out.as_any().downcast_ref::<StringViewArray>().unwrap();
let expected = StringViewArray::from_iter(vec![
Some("hello"),
Some("hello"),
None,
None,
Some("hello"),
]);
assert_eq!(actual, &expected);
}

#[test]
fn test_zip_kernel_scalar_strings_array_view_all_true_null() {
let scalar_truthy = Scalar::new(StringViewArray::new_null(1));
let scalar_falsy = Scalar::new(StringViewArray::new_null(1));
let mask = BooleanArray::from(vec![true, true]);
let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
let actual = out.as_any().downcast_ref::<StringViewArray>().unwrap();
let expected = StringViewArray::from_iter(vec![None::<String>, None]);
assert_eq!(actual, &expected);
}

#[test]
fn test_zip_kernel_scalar_strings_array_view_all_false_null() {
let scalar_truthy = Scalar::new(StringViewArray::new_null(1));
let scalar_falsy = Scalar::new(StringViewArray::new_null(1));
let mask = BooleanArray::from(vec![false, false]);
let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
let actual = out.as_any().downcast_ref::<StringViewArray>().unwrap();
let expected = StringViewArray::from_iter(vec![None::<String>, None]);
assert_eq!(actual, &expected);
}

#[test]
fn test_zip_kernel_scalar_string_array_view_all_true() {
let scalar_truthy = Scalar::new(StringViewArray::from(vec!["hello"]));
let scalar_falsy = Scalar::new(StringViewArray::from(vec!["world"]));

let mask = BooleanArray::from(vec![true, true]);
let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
let actual = out.as_string_view();
let expected = StringViewArray::from(vec![Some("hello"), Some("hello")]);
assert_eq!(actual, &expected);
}

#[test]
fn test_zip_kernel_scalar_string_array_view_all_false() {
let scalar_truthy = Scalar::new(StringViewArray::from(vec!["hello"]));
let scalar_falsy = Scalar::new(StringViewArray::from(vec!["world"]));

let mask = BooleanArray::from(vec![false, false]);
let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
let actual = out.as_string_view();
let expected = StringViewArray::from(vec![Some("world"), Some("world")]);
assert_eq!(actual, &expected);
}

#[test]
fn test_zip_kernel_scalar_strings_large_strings() {
let scalar_truthy = Scalar::new(StringViewArray::from(vec!["longer than 12 bytes"]));
let scalar_falsy = Scalar::new(StringViewArray::from(vec!["another longer than 12 bytes"]));

let mask = BooleanArray::from(vec![true, false]);
let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
let actual = out.as_string_view();
let expected = StringViewArray::from(vec![
Some("longer than 12 bytes"),
Some("another longer than 12 bytes"),
]);
assert_eq!(actual, &expected);
}

#[test]
fn test_zip_kernel_scalar_strings_array_view_large_short_strings() {
let scalar_truthy = Scalar::new(StringViewArray::from(vec!["hello"]));
let scalar_falsy = Scalar::new(StringViewArray::from(vec!["longer than 12 bytes"]));

let mask = BooleanArray::from(vec![true, false, true, false]);
let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
let actual = out.as_string_view();
let expected = StringViewArray::from(vec![
Some("hello"),
Some("longer than 12 bytes"),
Some("hello"),
Some("longer than 12 bytes"),
]);
assert_eq!(actual, &expected);
}
#[test]
fn test_zip_kernel_scalar_strings_array_view_large_all_true() {
let scalar_truthy = Scalar::new(StringViewArray::from(vec!["longer than 12 bytes"]));
let scalar_falsy = Scalar::new(StringViewArray::from(vec!["another longer than 12 bytes"]));

let mask = BooleanArray::from(vec![true, true]);
let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
let actual = out.as_string_view();
let expected = StringViewArray::from(vec![
Some("longer than 12 bytes"),
Some("longer than 12 bytes"),
]);
assert_eq!(actual, &expected);
}

#[test]
fn test_zip_kernel_scalar_strings_array_view_large_all_false() {
let scalar_truthy = Scalar::new(StringViewArray::from(vec!["longer than 12 bytes"]));
let scalar_falsy = Scalar::new(StringViewArray::from(vec!["another longer than 12 bytes"]));

let mask = BooleanArray::from(vec![false, false]);
let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
let actual = out.as_string_view();
let expected = StringViewArray::from(vec![
Some("another longer than 12 bytes"),
Some("another longer than 12 bytes"),
]);
assert_eq!(actual, &expected);
}
}
Loading