Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support LargeList for array_has, array_has_all and array_has_any #8322

Merged
merged 3 commits into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 90 additions & 53 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1774,82 +1774,119 @@ pub fn array_ndims(args: &[ArrayRef]) -> Result<ArrayRef> {
Ok(Arc::new(result) as ArrayRef)
}

/// Array_has SQL function
pub fn array_has(args: &[ArrayRef]) -> Result<ArrayRef> {
let array = as_list_array(&args[0])?;
let element = &args[1];
/// Represents the type of comparison for array_has.
#[derive(Debug, PartialEq)]
enum ComparisonType {
// array_has_all
All,
// array_has_any
Any,
// array_has
Single,
}

fn general_array_has_dispatch<O: OffsetSizeTrait>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

array: &ArrayRef,
sub_array: &ArrayRef,
comparison_type: ComparisonType,
) -> Result<ArrayRef> {
let array = if comparison_type == ComparisonType::Single {
let arr = as_generic_list_array::<O>(array)?;
check_datatypes("array_has", &[arr.values(), sub_array])?;
arr
} else {
check_datatypes("array_has", &[array, sub_array])?;
as_generic_list_array::<O>(array)?
};

check_datatypes("array_has", &[array.values(), element])?;
let mut boolean_builder = BooleanArray::builder(array.len());

let converter = RowConverter::new(vec![SortField::new(array.value_type())])?;
let r_values = converter.convert_columns(&[element.clone()])?;
for (row_idx, arr) in array.iter().enumerate() {
if let Some(arr) = arr {

let element = sub_array.clone();
let sub_array = if comparison_type != ComparisonType::Single {
as_generic_list_array::<O>(sub_array)?
} else {
array
};

for (row_idx, (arr, sub_arr)) in array.iter().zip(sub_array.iter()).enumerate() {
if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) {
let arr_values = converter.convert_columns(&[arr])?;
let res = arr_values
.iter()
.dedup()
.any(|x| x == r_values.row(row_idx));
let sub_arr_values = if comparison_type != ComparisonType::Single {
converter.convert_columns(&[sub_arr])?
} else {
converter.convert_columns(&[element.clone()])?
};

let mut res = match comparison_type {
ComparisonType::All => sub_arr_values
.iter()
.dedup()
.all(|elem| arr_values.iter().dedup().any(|x| x == elem)),
ComparisonType::Any => sub_arr_values
.iter()
.dedup()
.any(|elem| arr_values.iter().dedup().any(|x| x == elem)),
ComparisonType::Single => arr_values
.iter()
.dedup()
.any(|x| x == sub_arr_values.row(row_idx)),
};

if comparison_type == ComparisonType::Any {
res |= res;
}

boolean_builder.append_value(res);
}
}
Ok(Arc::new(boolean_builder.finish()))
}

/// Array_has_any SQL function
pub fn array_has_any(args: &[ArrayRef]) -> Result<ArrayRef> {
check_datatypes("array_has_any", &[&args[0], &args[1]])?;
/// Array_has SQL function
pub fn array_has(args: &[ArrayRef]) -> Result<ArrayRef> {
let array_type = args[0].data_type();

let array = as_list_array(&args[0])?;
let sub_array = as_list_array(&args[1])?;
let mut boolean_builder = BooleanArray::builder(array.len());
match array_type {
DataType::List(_) => {
general_array_has_dispatch::<i32>(&args[0], &args[1], ComparisonType::Single)
}
DataType::LargeList(_) => {
general_array_has_dispatch::<i64>(&args[0], &args[1], ComparisonType::Single)
}
_ => internal_err!("array_has does not support type '{array_type:?}'."),
}
}

let converter = RowConverter::new(vec![SortField::new(array.value_type())])?;
for (arr, sub_arr) in array.iter().zip(sub_array.iter()) {
if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) {
let arr_values = converter.convert_columns(&[arr])?;
let sub_arr_values = converter.convert_columns(&[sub_arr])?;
/// Array_has_any SQL function
pub fn array_has_any(args: &[ArrayRef]) -> Result<ArrayRef> {
let array_type = args[0].data_type();

let mut res = false;
for elem in sub_arr_values.iter().dedup() {
res |= arr_values.iter().dedup().any(|x| x == elem);
if res {
break;
}
}
boolean_builder.append_value(res);
match array_type {
DataType::List(_) => {
general_array_has_dispatch::<i32>(&args[0], &args[1], ComparisonType::Any)
}
DataType::LargeList(_) => {
general_array_has_dispatch::<i64>(&args[0], &args[1], ComparisonType::Any)
}
_ => internal_err!("array_has_any does not support type '{array_type:?}'."),
}
Ok(Arc::new(boolean_builder.finish()))
}

/// Array_has_all SQL function
pub fn array_has_all(args: &[ArrayRef]) -> Result<ArrayRef> {
check_datatypes("array_has_all", &[&args[0], &args[1]])?;

let array = as_list_array(&args[0])?;
let sub_array = as_list_array(&args[1])?;

let mut boolean_builder = BooleanArray::builder(array.len());

let converter = RowConverter::new(vec![SortField::new(array.value_type())])?;
for (arr, sub_arr) in array.iter().zip(sub_array.iter()) {
if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) {
let arr_values = converter.convert_columns(&[arr])?;
let sub_arr_values = converter.convert_columns(&[sub_arr])?;
let array_type = args[0].data_type();

let mut res = true;
for elem in sub_arr_values.iter().dedup() {
res &= arr_values.iter().dedup().any(|x| x == elem);
if !res {
break;
}
}
boolean_builder.append_value(res);
match array_type {
DataType::List(_) => {
general_array_has_dispatch::<i32>(&args[0], &args[1], ComparisonType::All)
}
DataType::LargeList(_) => {
general_array_has_dispatch::<i64>(&args[0], &args[1], ComparisonType::All)
}
_ => internal_err!("array_has_all does not support type '{array_type:?}'."),
}
Ok(Arc::new(boolean_builder.finish()))
}

/// Splits string at occurrences of delimiter and returns an array of parts
Expand Down
111 changes: 111 additions & 0 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -2538,6 +2538,23 @@ select array_has(make_array(1,2), 1),
----
true true true true true false true false true false true false

query BBBBBBBBBBBB
select array_has(arrow_cast(make_array(1,2), 'LargeList(Int64)'), 1),
array_has(arrow_cast(make_array(1,2,NULL), 'LargeList(Int64)'), 1),
array_has(arrow_cast(make_array([2,3], [3,4]), 'LargeList(List(Int64))'), make_array(2,3)),
array_has(arrow_cast(make_array([[1], [2,3]], [[4,5], [6]]), 'LargeList(List(List(Int64)))'), make_array([1], [2,3])),
array_has(arrow_cast(make_array([[1], [2,3]], [[4,5], [6]]), 'LargeList(List(List(Int64)))'), make_array([4,5], [6])),
array_has(arrow_cast(make_array([[1], [2,3]], [[4,5], [6]]), 'LargeList(List(List(Int64)))'), make_array([1])),
array_has(arrow_cast(make_array([[[1]]]), 'LargeList(List(List(List(Int64))))'), make_array([[1]])),
array_has(arrow_cast(make_array([[[1]]], [[[1], [2]]]), 'LargeList(List(List(List(Int64))))'), make_array([[2]])),
array_has(arrow_cast(make_array([[[1]]], [[[1], [2]]]), 'LargeList(List(List(List(Int64))))'), make_array([[1], [2]])),
list_has(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), 4),
array_contains(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), 3),
list_contains(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), 0)
;
----
true true true true true false true false true false true false

query BBB
select array_has(column1, column2),
array_has_all(column3, column4),
Expand All @@ -2547,6 +2564,15 @@ from array_has_table_1D;
true true true
false false false

query BBB
select array_has(arrow_cast(column1, 'LargeList(Int64)'), column2),
array_has_all(arrow_cast(column3, 'LargeList(Int64)'), arrow_cast(column4, 'LargeList(Int64)')),
array_has_any(arrow_cast(column5, 'LargeList(Int64)'), arrow_cast(column6, 'LargeList(Int64)'))
from array_has_table_1D;
----
true true true
false false false

query BBB
select array_has(column1, column2),
array_has_all(column3, column4),
Expand All @@ -2556,6 +2582,15 @@ from array_has_table_1D_Float;
true true false
false false true

query BBB
select array_has(arrow_cast(column1, 'LargeList(Float64)'), column2),
array_has_all(arrow_cast(column3, 'LargeList(Float64)'), arrow_cast(column4, 'LargeList(Float64)')),
array_has_any(arrow_cast(column5, 'LargeList(Float64)'), arrow_cast(column6, 'LargeList(Float64)'))
from array_has_table_1D_Float;
----
true true false
false false true

query BBB
select array_has(column1, column2),
array_has_all(column3, column4),
Expand All @@ -2565,6 +2600,15 @@ from array_has_table_1D_Boolean;
false true true
true true true

query BBB
select array_has(arrow_cast(column1, 'LargeList(Boolean)'), column2),
array_has_all(arrow_cast(column3, 'LargeList(Boolean)'), arrow_cast(column4, 'LargeList(Boolean)')),
array_has_any(arrow_cast(column5, 'LargeList(Boolean)'), arrow_cast(column6, 'LargeList(Boolean)'))
from array_has_table_1D_Boolean;
----
false true true
true true true

query BBB
select array_has(column1, column2),
array_has_all(column3, column4),
Expand All @@ -2574,6 +2618,15 @@ from array_has_table_1D_UTF8;
true true false
false false true

query BBB
select array_has(arrow_cast(column1, 'LargeList(Utf8)'), column2),
array_has_all(arrow_cast(column3, 'LargeList(Utf8)'), arrow_cast(column4, 'LargeList(Utf8)')),
array_has_any(arrow_cast(column5, 'LargeList(Utf8)'), arrow_cast(column6, 'LargeList(Utf8)'))
from array_has_table_1D_UTF8;
----
true true false
false false true

query BB
select array_has(column1, column2),
array_has_all(column3, column4)
Expand All @@ -2582,13 +2635,28 @@ from array_has_table_2D;
false true
true false

query BB
select array_has(arrow_cast(column1, 'LargeList(List(Int64))'), column2),
array_has_all(arrow_cast(column3, 'LargeList(List(Int64))'), arrow_cast(column4, 'LargeList(List(Int64))'))
from array_has_table_2D;
----
false true
true false

query B
select array_has_all(column1, column2)
from array_has_table_2D_float;
----
true
false

query B
select array_has_all(arrow_cast(column1, 'LargeList(List(Float64))'), arrow_cast(column2, 'LargeList(List(Float64))'))
from array_has_table_2D_float;
----
true
false

query B
select array_has(column1, column2) from array_has_table_3D;
----
Expand All @@ -2600,6 +2668,17 @@ true
false
true

query B
select array_has(arrow_cast(column1, 'LargeList(List(List(Int64)))'), column2) from array_has_table_3D;
----
false
true
false
false
true
false
true

query BBBB
select array_has(column1, make_array(5, 6)),
array_has(column1, make_array(7, NULL)),
Expand All @@ -2614,6 +2693,20 @@ false true false false
false false false false
false false false false

query BBBB
select array_has(arrow_cast(column1, 'LargeList(List(Int64))'), make_array(5, 6)),
array_has(arrow_cast(column1, 'LargeList(List(Int64))'), make_array(7, NULL)),
array_has(arrow_cast(column2, 'LargeList(Float64)'), 5.5),
array_has(arrow_cast(column3, 'LargeList(Utf8)'), 'o')
from arrays;
----
false false false true
true false true false
true false false true
false true false false
false false false false
false false false false

query BBBBBBBBBBBBB
select array_has_all(make_array(1,2,3), make_array(1,3)),
array_has_all(make_array(1,2,3), make_array(1,4)),
Expand All @@ -2632,6 +2725,24 @@ select array_has_all(make_array(1,2,3), make_array(1,3)),
----
true false true false false false true true false false true false true

query BBBBBBBBBBBBB
select array_has_all(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), arrow_cast(make_array(1,3), 'LargeList(Int64)')),
array_has_all(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(1,4), 'LargeList(Int64)')),
array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,2]), 'LargeList(List(Int64))')),
array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,3]), 'LargeList(List(Int64))')),
array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,2], [3,4], [5,6]), 'LargeList(List(Int64))')),
array_has_all(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1]]), 'LargeList(List(List(Int64)))')),
array_has_all(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))')),
array_has_any(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(1,10,100), 'LargeList(Int64)')),
array_has_any(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(10,100),'LargeList(Int64)')),
array_has_any(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,10], [10,4]), 'LargeList(List(Int64))')),
array_has_any(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([10,20], [3,4]), 'LargeList(List(Int64))')),
array_has_any(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3], [4,5,6]]), 'LargeList(List(List(Int64)))')),
array_has_any(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3]], [[4,5,6]]), 'LargeList(List(List(Int64)))'))
;
----
true false true false false false true true false false true false true

query ???
select array_intersect(column1, column2),
array_intersect(column3, column4),
Expand Down