array_has with eq kernel (apache#12125)
* first draft

Signed-off-by: jayzhan211 <[email protected]>

* avoid row converter for string

Signed-off-by: jayzhan211 <[email protected]>

* cleanup

Signed-off-by: jayzhan211 <[email protected]>

* string view

Signed-off-by: jayzhan211 <[email protected]>

* trigger ci

Signed-off-by: jayzhan211 <[email protected]>

* refactor

Signed-off-by: jayzhan211 <[email protected]>

* cleanup

Signed-off-by: jayzhan211 <[email protected]>

* array_has_all

Signed-off-by: jayzhan211 <[email protected]>

* array_has_any

Signed-off-by: jayzhan211 <[email protected]>

* cleanup

Signed-off-by: jayzhan211 <[email protected]>

* cleanup

Signed-off-by: jayzhan211 <[email protected]>

* fmt

Signed-off-by: jayzhan211 <[email protected]>

* rm unused import

Signed-off-by: jayzhan211 <[email protected]>

* cleanup

Signed-off-by: jayzhan211 <[email protected]>

* backup

Signed-off-by: jayzhan211 <[email protected]>

* add bench

Signed-off-by: jayzhan211 <[email protected]>

* new approach

Signed-off-by: jayzhan211 <[email protected]>

* general scalar wins

Signed-off-by: jayzhan211 <[email protected]>

* cleanup

Signed-off-by: jayzhan211 <[email protected]>

* cleanup

Signed-off-by: jayzhan211 <[email protected]>

* revert query

Signed-off-by: jayzhan211 <[email protected]>

* cleanup

Signed-off-by: jayzhan211 <[email protected]>

* cleanup

Signed-off-by: jayzhan211 <[email protected]>

* reuse slice and fix typo

Signed-off-by: jayzhan211 <[email protected]>

---------

Signed-off-by: jayzhan211 <[email protected]>
jayzhan211 authored Aug 24, 2024
1 parent a58416c commit 14d6404
Showing 4 changed files with 130 additions and 74 deletions.
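
Before the per-file diffs, here is a minimal standalone sketch of the technique this commit adopts for the scalar-needle case: run the arrow eq kernel once over the flattened child values of the list column, then use the list offsets to decide per row whether any match landed inside that row. The helper name list_has_scalar, the Int32 element type, and the example values are illustrative only; null handling is simplified compared to the actual array_has_dispatch_for_scalar shown further down, and the only real APIs assumed are those of the arrow-array and arrow-ord crates.

use arrow_array::types::Int32Type;
use arrow_array::{Array, BooleanArray, Int32Array, ListArray, Scalar};
use arrow_ord::cmp::eq;

// Sketch of the scalar-needle path: one vectorized `eq` over all child values,
// then per-row containment via the list offsets.
fn list_has_scalar(haystack: &ListArray, needle: i32) -> BooleanArray {
    let values = haystack.values();
    let needle = Scalar::new(Int32Array::from(vec![needle]));
    // Compare every child value against the needle in a single kernel call.
    let eq_array = eq(values, &needle).unwrap();

    let offsets = haystack.value_offsets();
    let mut contained = vec![None; haystack.len()];
    for (i, w) in offsets.windows(2).enumerate() {
        let (start, len) = (w[0] as usize, (w[1] - w[0]) as usize);
        if len == 0 {
            continue; // empty (or null) rows stay NULL in this simplified version
        }
        let row = eq_array.slice(start, len);
        contained[i] = Some(row.true_count() > 0);
    }
    BooleanArray::from(contained)
}

fn main() {
    let haystack = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
        Some(vec![Some(1), Some(2)]),
        Some(vec![Some(3)]),
        Some(vec![]),
    ]);
    // Prints a BooleanArray equivalent to [true, false, null] for needle = 2.
    println!("{:?}", list_has_scalar(&haystack, 2));
}
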
1 change: 1 addition & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions datafusion/functions-nested/Cargo.toml
@@ -50,6 +50,7 @@ datafusion-execution = { workspace = true }
datafusion-expr = { workspace = true }
datafusion-functions = { workspace = true }
datafusion-functions-aggregate = { workspace = true }
datafusion-physical-expr-common = { workspace = true }
itertools = { workspace = true, features = ["use_std"] }
log = { workspace = true }
paste = "1.0.14"
188 changes: 117 additions & 71 deletions datafusion/functions-nested/src/array_has.rs
@@ -20,12 +20,13 @@
use arrow::array::{Array, ArrayRef, BooleanArray, OffsetSizeTrait};
use arrow::datatypes::DataType;
use arrow::row::{RowConverter, Rows, SortField};
use arrow_array::GenericListArray;
use arrow_array::{Datum, GenericListArray, Scalar};
use datafusion_common::cast::as_generic_list_array;
use datafusion_common::utils::string_utils::string_array_to_vec;
use datafusion_common::{exec_err, Result};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use datafusion_common::{exec_err, Result, ScalarValue};
use datafusion_expr::{ColumnarValue, Operator, ScalarUDFImpl, Signature, Volatility};

use datafusion_physical_expr_common::datum::compare_op_for_nested;
use itertools::Itertools;

use crate::utils::make_scalar_function;
@@ -95,25 +96,132 @@ impl ScalarUDFImpl for ArrayHas {
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
make_scalar_function(array_has_inner)(args)
// Always return null if the second argument is null
// i.e. array_has(array, null) -> null
if let ColumnarValue::Scalar(s) = &args[1] {
if s.is_null() {
return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None)));
}
}

// first, identify if any of the arguments is an Array. If yes, store its `len`,
// as any scalar will need to be converted to an array of len `len`.
let len = args
.iter()
.fold(Option::<usize>::None, |acc, arg| match arg {
ColumnarValue::Scalar(_) => acc,
ColumnarValue::Array(a) => Some(a.len()),
});

let is_scalar = len.is_none();

let result = match args[1] {
ColumnarValue::Array(_) => {
let args = ColumnarValue::values_to_arrays(args)?;
array_has_inner_for_array(&args[0], &args[1])
}
ColumnarValue::Scalar(_) => {
let haystack = args[0].to_owned().into_array(1)?;
let needle = args[1].to_owned().into_array(1)?;
let needle = Scalar::new(needle);
array_has_inner_for_scalar(&haystack, &needle)
}
};

if is_scalar {
// If all inputs are scalar, keep the output as a scalar
let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
result.map(ColumnarValue::Scalar)
} else {
result.map(ColumnarValue::Array)
}
}

fn aliases(&self) -> &[String] {
&self.aliases
}
}

fn array_has_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
match args[0].data_type() {
DataType::List(_) => array_has_dispatch::<i32>(&args[0], &args[1]),
DataType::LargeList(_) => array_has_dispatch::<i64>(&args[0], &args[1]),
fn array_has_inner_for_scalar(
haystack: &ArrayRef,
needle: &dyn Datum,
) -> Result<ArrayRef> {
match haystack.data_type() {
DataType::List(_) => array_has_dispatch_for_scalar::<i32>(haystack, needle),
DataType::LargeList(_) => array_has_dispatch_for_scalar::<i64>(haystack, needle),
_ => exec_err!(
"array_has does not support type '{:?}'.",
args[0].data_type()
haystack.data_type()
),
}
}

fn array_has_inner_for_array(haystack: &ArrayRef, needle: &ArrayRef) -> Result<ArrayRef> {
match haystack.data_type() {
DataType::List(_) => array_has_dispatch_for_array::<i32>(haystack, needle),
DataType::LargeList(_) => array_has_dispatch_for_array::<i64>(haystack, needle),
_ => exec_err!(
"array_has does not support type '{:?}'.",
haystack.data_type()
),
}
}

fn array_has_dispatch_for_array<O: OffsetSizeTrait>(
haystack: &ArrayRef,
needle: &ArrayRef,
) -> Result<ArrayRef> {
let haystack = as_generic_list_array::<O>(haystack)?;
let mut boolean_builder = BooleanArray::builder(haystack.len());

for (i, arr) in haystack.iter().enumerate() {
if arr.is_none() || needle.is_null(i) {
boolean_builder.append_null();
continue;
}
let arr = arr.unwrap();
let needle_row = Scalar::new(needle.slice(i, 1));
let eq_array = compare_op_for_nested(Operator::Eq, &arr, &needle_row)?;
let is_contained = eq_array.true_count() > 0;
boolean_builder.append_value(is_contained)
}

Ok(Arc::new(boolean_builder.finish()))
}

fn array_has_dispatch_for_scalar<O: OffsetSizeTrait>(
haystack: &ArrayRef,
needle: &dyn Datum,
) -> Result<ArrayRef> {
let haystack = as_generic_list_array::<O>(haystack)?;
let values = haystack.values();
let offsets = haystack.value_offsets();
// If the first argument is an empty list (and the second argument is non-null), return false
// i.e. array_has([], non-null element) -> false
if values.len() == 0 {
return Ok(Arc::new(BooleanArray::from(vec![Some(false)])));
}
let eq_array = compare_op_for_nested(Operator::Eq, values, needle)?;
let mut final_contained = vec![None; haystack.len()];
for (i, offset) in offsets.windows(2).enumerate() {
let start = offset[0].to_usize().unwrap();
let end = offset[1].to_usize().unwrap();
let length = end - start;
// For non-nested list, length is 0 for null
if length == 0 {
continue;
}
let sliced_array = eq_array.slice(start, length);
// For nested list, check number of nulls
if sliced_array.null_count() == length {
continue;
}
final_contained[i] = Some(sliced_array.true_count() > 0);
}

Ok(Arc::new(BooleanArray::from(final_contained)))
}

fn array_has_all_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
match args[0].data_type() {
DataType::List(_) => {
@@ -245,19 +353,6 @@ enum ComparisonType {
Any,
}

fn array_has_dispatch<O: OffsetSizeTrait>(
haystack: &ArrayRef,
needle: &ArrayRef,
) -> Result<ArrayRef> {
let haystack = as_generic_list_array::<O>(haystack)?;
match needle.data_type() {
DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => {
array_has_string_internal::<O>(haystack, needle)
}
_ => general_array_has::<O>(haystack, needle),
}
}

fn array_has_all_and_any_dispatch<O: OffsetSizeTrait>(
haystack: &ArrayRef,
needle: &ArrayRef,
@@ -273,55 +368,6 @@ fn array_has_all_and_any_dispatch<O: OffsetSizeTrait>(
}
}

fn array_has_string_internal<O: OffsetSizeTrait>(
haystack: &GenericListArray<O>,
needle: &ArrayRef,
) -> Result<ArrayRef> {
let mut boolean_builder = BooleanArray::builder(haystack.len());
for (arr, element) in haystack.iter().zip(string_array_to_vec(needle).into_iter()) {
match (arr, element) {
(Some(arr), Some(element)) => {
boolean_builder.append_value(
string_array_to_vec(&arr)
.into_iter()
.flatten()
.any(|x| x == element),
);
}
(_, _) => {
boolean_builder.append_null();
}
}
}

Ok(Arc::new(boolean_builder.finish()))
}

fn general_array_has<O: OffsetSizeTrait>(
array: &GenericListArray<O>,
needle: &ArrayRef,
) -> Result<ArrayRef> {
let mut boolean_builder = BooleanArray::builder(array.len());
let converter = RowConverter::new(vec![SortField::new(array.value_type())])?;
let sub_arr_values = converter.convert_columns(&[Arc::clone(needle)])?;

for (row_idx, arr) in array.iter().enumerate() {
if let Some(arr) = arr {
let arr_values = converter.convert_columns(&[arr])?;
boolean_builder.append_value(
arr_values
.iter()
.dedup()
.any(|x| x == sub_arr_values.row(row_idx)),
);
} else {
boolean_builder.append_null();
}
}

Ok(Arc::new(boolean_builder.finish()))
}

// String comparison for array_has_all and array_has_any
fn array_has_all_and_any_string_internal<O: OffsetSizeTrait>(
array: &GenericListArray<O>,
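
For the array-needle path (array_has_dispatch_for_array above), the needle is itself a column, so row i of the needle is wrapped in a Scalar and compared against the values of list row i. A rough standalone equivalent under the same assumptions as the earlier sketch (Int32 values, arrow_ord's eq kernel, hypothetical helper name list_has_array):

use arrow_array::{Array, BooleanArray, Int32Array, ListArray, Scalar};
use arrow_ord::cmp::eq;

// Sketch of the array-needle path: one comparison per row, each against a
// single-row Scalar sliced out of the needle column.
fn list_has_array(haystack: &ListArray, needle: &Int32Array) -> BooleanArray {
    let mut out = BooleanArray::builder(haystack.len());
    for (i, row) in haystack.iter().enumerate() {
        match row {
            Some(values) if !needle.is_null(i) => {
                let needle_row = Scalar::new(needle.slice(i, 1));
                let hits = eq(&values, &needle_row).unwrap();
                out.append_value(hits.true_count() > 0);
            }
            // A null list row or a null needle row yields NULL, mirroring the
            // null propagation in the diff above.
            _ => out.append_null(),
        }
    }
    out.finish()
}

As with the first sketch, this compiles against the arrow-array and arrow-ord crates and is meant only to illustrate the shape of the new code, not to stand in for the DataFusion implementation.
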
14 changes: 11 additions & 3 deletions datafusion/sqllogictest/test_files/array.slt
@@ -4972,11 +4972,19 @@ NULL 1 1

## array_has/array_has_all/array_has_any

query BB
# If lhs is empty, return false
query B
select array_has([], 1);
----
false

# If rhs is Null, we return Null
query BBB
select array_has([], null),
array_has([1, 2, 3], null);
array_has([1, 2, 3], null),
array_has([null, 1], null);
----
false false
NULL NULL NULL

#TODO: array_has_all and array_has_any cannot handle NULL
#query BBBB
