Skip to content

Commit

Permalink
feat: array_contains (#6618)
Browse files Browse the repository at this point in the history
* feat: array_contains

* feat: regen.sh

* docs: array_contains

* fix: merge

* Update docs/source/user-guide/sql/scalar_functions.md

---------

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
izveigor and alamb committed Jun 27, 2023
1 parent 4f2933f commit 1dd1fbd
Show file tree
Hide file tree
Showing 12 changed files with 229 additions and 25 deletions.
36 changes: 36 additions & 0 deletions datafusion/core/tests/sqllogictests/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -357,3 +357,39 @@ query ?
select make_array(x, y) from foo2;
----
[1.0, 1]

# array_contains scalar function #1
query BBB rowsort
select array_contains(make_array(1, 2, 3), make_array(1, 1, 2, 3)), array_contains([1, 2, 3], [1, 1, 2]), array_contains([1, 2, 3], [2, 1, 3, 1]);
----
true true true

# array_contains scalar function #2
query BB rowsort
select array_contains([[1, 2], [3, 4]], [[1, 2], [3, 4], [1, 3]]), array_contains([[[1], [2]], [[3], [4]]], [1, 2, 2, 3, 4]);
----
true true

# array_contains scalar function #3
query BBB rowsort
select array_contains(make_array(1, 2, 3), make_array(1, 2, 3, 4)), array_contains([1, 2, 3], [1, 1, 4]), array_contains([1, 2, 3], [2, 1, 3, 4]);
----
false false false

# array_contains scalar function #4
query BB rowsort
select array_contains([[1, 2], [3, 4]], [[1, 2], [3, 4], [1, 5]]), array_contains([[[1], [2]], [[3], [4]]], [1, 2, 2, 3, 5]);
----
false false

# array_contains scalar function #5
query BB rowsort
select array_contains([true, true, false, true, false], [true, false, false]), array_contains([true, false, true], [true, true]);
----
true true

# array_contains scalar function #6
query BB rowsort
select array_contains(make_array(true, true, true), make_array(false, false)), array_contains([false, false, false], [true, true]);
----
false false
6 changes: 6 additions & 0 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ pub enum BuiltinScalarFunction {
ArrayAppend,
/// array_concat
ArrayConcat,
/// array_contains
ArrayContains,
/// array_dims
ArrayDims,
/// array_fill
Expand Down Expand Up @@ -319,6 +321,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Trunc => Volatility::Immutable,
BuiltinScalarFunction::ArrayAppend => Volatility::Immutable,
BuiltinScalarFunction::ArrayConcat => Volatility::Immutable,
BuiltinScalarFunction::ArrayContains => Volatility::Immutable,
BuiltinScalarFunction::ArrayDims => Volatility::Immutable,
BuiltinScalarFunction::ArrayFill => Volatility::Immutable,
BuiltinScalarFunction::ArrayLength => Volatility::Immutable,
Expand Down Expand Up @@ -460,6 +463,7 @@ impl BuiltinScalarFunction {
"The {self} function can only accept fixed size list as the args."
))),
},
BuiltinScalarFunction::ArrayContains => Ok(Boolean),
BuiltinScalarFunction::ArrayDims => Ok(UInt8),
BuiltinScalarFunction::ArrayFill => Ok(List(Arc::new(Field::new(
"item",
Expand Down Expand Up @@ -741,6 +745,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayConcat => {
Signature::variadic_any(self.volatility())
}
BuiltinScalarFunction::ArrayContains => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArrayDims => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayFill => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArrayLength => {
Expand Down Expand Up @@ -1166,6 +1171,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] {
// array functions
BuiltinScalarFunction::ArrayAppend => &["array_append"],
BuiltinScalarFunction::ArrayConcat => &["array_concat"],
BuiltinScalarFunction::ArrayContains => &["array_contains"],
BuiltinScalarFunction::ArrayDims => &["array_dims"],
BuiltinScalarFunction::ArrayFill => &["array_fill"],
BuiltinScalarFunction::ArrayLength => &["array_length"],
Expand Down
7 changes: 7 additions & 0 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,13 @@ scalar_expr!(
"appends an element to the end of an array."
);
nary_scalar_expr!(ArrayConcat, array_concat, "concatenates arrays.");
scalar_expr!(
ArrayContains,
array_contains,
first_array second_array,
"returns true, if each element of the second array appe
aring in the first array, otherwise false."
);
scalar_expr!(
ArrayDims,
array_dims,
Expand Down
126 changes: 124 additions & 2 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use datafusion_common::cast::as_list_array;
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::ColumnarValue;
use itertools::Itertools;
use std::sync::Arc;

macro_rules! downcast_vec {
Expand Down Expand Up @@ -1070,6 +1071,70 @@ pub fn array_ndims(args: &[ColumnarValue]) -> Result<ColumnarValue> {
]))))
}

macro_rules! contains {
($FIRST_ARRAY:expr, $SECOND_ARRAY:expr, $ARRAY_TYPE:ident) => {{
let first_array = downcast_arg!($FIRST_ARRAY, $ARRAY_TYPE);
let second_array = downcast_arg!($SECOND_ARRAY, $ARRAY_TYPE);
let mut res = true;
for x in second_array.values().iter().dedup() {
if !first_array.values().contains(x) {
res = false;
}
}

res
}};
}

/// Array_contains SQL function
pub fn array_contains(args: &[ArrayRef]) -> Result<ArrayRef> {
fn concat_inner_lists(arg: ArrayRef) -> Result<ArrayRef> {
match arg.data_type() {
DataType::List(field) => match field.data_type() {
DataType::List(..) => {
concat_inner_lists(array_concat(&[as_list_array(&arg)?
.values()
.clone()])?)
}
_ => Ok(as_list_array(&arg)?.values().clone()),
},
data_type => Err(DataFusionError::NotImplemented(format!(
"Array is not type '{data_type:?}'."
))),
}
}

let concat_first_array = concat_inner_lists(args[0].clone())?.clone();
let concat_second_array = concat_inner_lists(args[1].clone())?.clone();

let res = match (concat_first_array.data_type(), concat_second_array.data_type()) {
(DataType::Utf8, DataType::Utf8) => contains!(concat_first_array, concat_second_array, StringArray),
(DataType::LargeUtf8, DataType::LargeUtf8) => contains!(concat_first_array, concat_second_array, LargeStringArray),
(DataType::Boolean, DataType::Boolean) => {
let first_array = downcast_arg!(concat_first_array, BooleanArray);
let second_array = downcast_arg!(concat_second_array, BooleanArray);
compute::bool_or(first_array) == compute::bool_or(second_array)
}
(DataType::Float32, DataType::Float32) => contains!(concat_first_array, concat_second_array, Float32Array),
(DataType::Float64, DataType::Float64) => contains!(concat_first_array, concat_second_array, Float64Array),
(DataType::Int8, DataType::Int8) => contains!(concat_first_array, concat_second_array, Int8Array),
(DataType::Int16, DataType::Int16) => contains!(concat_first_array, concat_second_array, Int16Array),
(DataType::Int32, DataType::Int32) => contains!(concat_first_array, concat_second_array, Int32Array),
(DataType::Int64, DataType::Int64) => contains!(concat_first_array, concat_second_array, Int64Array),
(DataType::UInt8, DataType::UInt8) => contains!(concat_first_array, concat_second_array, UInt8Array),
(DataType::UInt16, DataType::UInt16) => contains!(concat_first_array, concat_second_array, UInt16Array),
(DataType::UInt32, DataType::UInt32) => contains!(concat_first_array, concat_second_array, UInt32Array),
(DataType::UInt64, DataType::UInt64) => contains!(concat_first_array, concat_second_array, UInt64Array),
(first_array_data_type, second_array_data_type) => {
return Err(DataFusionError::NotImplemented(format!(
"Array_contains is not implemented for types '{first_array_data_type:?}' and '{second_array_data_type:?}'."
)))
}
};

Ok(Arc::new(BooleanArray::from(vec![res])))
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -1588,7 +1653,7 @@ mod tests {

#[test]
fn test_array_ndims() {
// array_ndims([1, 2]) = 1
// array_ndims([1, 2, 3, 4]) = 1
let list_array = return_array();

let array = array_ndims(&[list_array])
Expand All @@ -1602,7 +1667,7 @@ mod tests {

#[test]
fn test_nested_array_ndims() {
// array_ndims([[1, 2], [3, 4]]) = 2
// array_ndims([[1, 2, 3, 4], [5, 6, 7, 8]]) = 2
let list_array = return_nested_array();

let array = array_ndims(&[list_array])
Expand All @@ -1614,6 +1679,63 @@ mod tests {
assert_eq!(result, &UInt8Array::from(vec![2]));
}

#[test]
fn test_array_contains() {
// array_contains([1, 2, 3, 4], array_append([1, 2, 3, 4], 3)) = t
let first_array = return_array().into_array(1);
let second_array = array_append(&[
first_array.clone(),
Arc::new(Int64Array::from(vec![Some(3)])),
])
.expect("failed to initialize function array_contains");

let arr = array_contains(&[first_array.clone(), second_array])
.expect("failed to initialize function array_contains");
let result = as_boolean_array(&arr);

assert_eq!(result, &BooleanArray::from(vec![true]));

// array_contains([1, 2, 3, 4], array_append([1, 2, 3, 4], 5)) = f
let second_array = array_append(&[
first_array.clone(),
Arc::new(Int64Array::from(vec![Some(5)])),
])
.expect("failed to initialize function array_contains");

let arr = array_contains(&[first_array.clone(), second_array])
.expect("failed to initialize function array_contains");
let result = as_boolean_array(&arr);

assert_eq!(result, &BooleanArray::from(vec![false]));
}

#[test]
fn test_nested_array_contains() {
// array_contains([[1, 2, 3, 4], [5, 6, 7, 8]], array_append([1, 2, 3, 4], 3)) = t
let first_array = return_nested_array().into_array(1);
let array = return_array().into_array(1);
let second_array =
array_append(&[array.clone(), Arc::new(Int64Array::from(vec![Some(3)]))])
.expect("failed to initialize function array_contains");

let arr = array_contains(&[first_array.clone(), second_array])
.expect("failed to initialize function array_contains");
let result = as_boolean_array(&arr);

assert_eq!(result, &BooleanArray::from(vec![true]));

// array_contains([[1, 2, 3, 4], [5, 6, 7, 8]], array_append([1, 2, 3, 4], 9)) = f
let second_array =
array_append(&[array.clone(), Arc::new(Int64Array::from(vec![Some(9)]))])
.expect("failed to initialize function array_contains");

let arr = array_contains(&[first_array.clone(), second_array])
.expect("failed to initialize function array_contains");
let result = as_boolean_array(&arr);

assert_eq!(result, &BooleanArray::from(vec![false]));
}

fn return_array() -> ColumnarValue {
let args = [
ColumnarValue::Scalar(ScalarValue::Int64(Some(1))),
Expand Down
3 changes: 3 additions & 0 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,9 @@ pub fn create_physical_fun(
BuiltinScalarFunction::ArrayConcat => {
Arc::new(|args| make_scalar_function(array_expressions::array_concat)(args))
}
BuiltinScalarFunction::ArrayContains => {
Arc::new(|args| make_scalar_function(array_expressions::array_contains)(args))
}
BuiltinScalarFunction::ArrayDims => Arc::new(array_expressions::array_dims),
BuiltinScalarFunction::ArrayFill => Arc::new(array_expressions::array_fill),
BuiltinScalarFunction::ArrayLength => Arc::new(array_expressions::array_length),
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,7 @@ enum ScalarFunction {
ArrayToString = 97;
Cardinality = 98;
TrimArray = 99;
ArrayContains = 100;
}

message ScalarFunctionNode {
Expand Down
3 changes: 3 additions & 0 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 11 additions & 6 deletions datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ use datafusion_common::{
};
use datafusion_expr::expr::Placeholder;
use datafusion_expr::{
abs, acos, acosh, array, array_append, array_concat, array_dims, array_fill,
array_length, array_ndims, array_position, array_positions, array_prepend,
array_remove, array_replace, array_to_string, ascii, asin, asinh, atan, atan2, atanh,
bit_length, btrim, cardinality, cbrt, ceil, character_length, chr, coalesce,
concat_expr, concat_ws_expr, cos, cosh, date_bin, date_part, date_trunc, degrees,
digest, exp,
abs, acos, acosh, array, array_append, array_concat, array_contains, array_dims,
array_fill, array_length, array_ndims, array_position, array_positions,
array_prepend, array_remove, array_replace, array_to_string, ascii, asin, asinh,
atan, atan2, atanh, bit_length, btrim, cardinality, cbrt, ceil, character_length,
chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, date_bin, date_part,
date_trunc, degrees, digest, exp,
expr::{self, InList, Sort, WindowFunction},
factorial, floor, from_unixtime, gcd, lcm, left, ln, log, log10, log2,
logical_plan::{PlanType, StringifiedPlan},
Expand Down Expand Up @@ -450,6 +450,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
ScalarFunction::ToTimestamp => Self::ToTimestamp,
ScalarFunction::ArrayAppend => Self::ArrayAppend,
ScalarFunction::ArrayConcat => Self::ArrayConcat,
ScalarFunction::ArrayContains => Self::ArrayContains,
ScalarFunction::ArrayDims => Self::ArrayDims,
ScalarFunction::ArrayFill => Self::ArrayFill,
ScalarFunction::ArrayLength => Self::ArrayLength,
Expand Down Expand Up @@ -1192,6 +1193,10 @@ pub fn parse_expr(
.map(|expr| parse_expr(expr, registry))
.collect::<Result<Vec<_>, _>>()?,
)),
ScalarFunction::ArrayContains => Ok(array_contains(
parse_expr(&args[0], registry)?,
parse_expr(&args[1], registry)?,
)),
ScalarFunction::ArrayFill => Ok(array_fill(
parse_expr(&args[0], registry)?,
parse_expr(&args[1], registry)?,
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1344,6 +1344,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
BuiltinScalarFunction::ToTimestamp => Self::ToTimestamp,
BuiltinScalarFunction::ArrayAppend => Self::ArrayAppend,
BuiltinScalarFunction::ArrayConcat => Self::ArrayConcat,
BuiltinScalarFunction::ArrayContains => Self::ArrayContains,
BuiltinScalarFunction::ArrayDims => Self::ArrayDims,
BuiltinScalarFunction::ArrayFill => Self::ArrayFill,
BuiltinScalarFunction::ArrayLength => Self::ArrayLength,
Expand Down
Loading

0 comments on commit 1dd1fbd

Please sign in to comment.