From 8f551d3acea78c32e27cd6905d266e7008c9824a Mon Sep 17 00:00:00 2001 From: Weijun Huang Date: Sat, 25 Nov 2023 19:45:04 +0100 Subject: [PATCH 1/2] support LargeList for array_has, array_has_all and array_has_any --- .../physical-expr/src/array_expressions.rs | 115 +++++++++++------- datafusion/sqllogictest/test_files/array.slt | 111 +++++++++++++++++ 2 files changed, 184 insertions(+), 42 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 103a392b199d..f753ca5996fe 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -1774,11 +1774,67 @@ pub fn array_ndims(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } +/// general internal function for array_has_all and array_has_all +/// if is_all is true, then it is array_has_all, otherwise it is array_has_any +fn general_array_has_dispatch( + array: &ArrayRef, + sub_array: &ArrayRef, + is_all: bool, +) -> Result { + check_datatypes("array_has", &[array, sub_array])?; + + let array = as_generic_list_array::(array)?; + let sub_array = as_generic_list_array::(&sub_array)?; + + let mut boolean_builder = BooleanArray::builder(array.len()); + + let converter = RowConverter::new(vec![SortField::new(array.value_type())])?; + for (arr, sub_arr) in array.iter().zip(sub_array.iter()) { + if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) { + let arr_values = converter.convert_columns(&[arr])?; + let sub_arr_values = converter.convert_columns(&[sub_arr])?; + + let mut res = if is_all { + sub_arr_values + .iter() + .dedup() + .all(|elem| arr_values.iter().dedup().any(|x| x == elem)) + } else { + sub_arr_values + .iter() + .dedup() + .any(|elem| arr_values.iter().dedup().any(|x| x == elem)) + }; + + if is_all { + res |= res; + } + + boolean_builder.append_value(res); + } + } + Ok(Arc::new(boolean_builder.finish())) +} + /// Array_has SQL function pub fn array_has(args: &[ArrayRef]) -> Result { - let array = as_list_array(&args[0])?; + let array_type = args[0].data_type(); + let array = &args[0]; let element = &args[1]; + match array_type { + DataType::List(_) => array_has_dispatch::(array, element), + DataType::LargeList(_) => array_has_dispatch::(array, element), + _ => internal_err!("array_has does not support type '{array_type:?}'."), + } +} + +fn array_has_dispatch( + array: &ArrayRef, + element: &ArrayRef, +) -> Result { + let array = as_generic_list_array::(array)?; + check_datatypes("array_has", &[array.values(), element])?; let mut boolean_builder = BooleanArray::builder(array.len()); @@ -1799,57 +1855,32 @@ pub fn array_has(args: &[ArrayRef]) -> Result { /// Array_has_any SQL function pub fn array_has_any(args: &[ArrayRef]) -> Result { - check_datatypes("array_has_any", &[&args[0], &args[1]])?; - - let array = as_list_array(&args[0])?; - let sub_array = as_list_array(&args[1])?; - let mut boolean_builder = BooleanArray::builder(array.len()); - - let converter = RowConverter::new(vec![SortField::new(array.value_type())])?; - for (arr, sub_arr) in array.iter().zip(sub_array.iter()) { - if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) { - let arr_values = converter.convert_columns(&[arr])?; - let sub_arr_values = converter.convert_columns(&[sub_arr])?; + let array_type = args[0].data_type(); + let array = &args[0]; + let sub_array = &args[1]; - let mut res = false; - for elem in sub_arr_values.iter().dedup() { - res |= arr_values.iter().dedup().any(|x| x == elem); - if res { - break; - } - } - boolean_builder.append_value(res); + match array_type { + DataType::List(_) => general_array_has_dispatch::(array, sub_array, false), + DataType::LargeList(_) => { + general_array_has_dispatch::(array, sub_array, false) } + _ => internal_err!("array_has_any does not support type '{array_type:?}'."), } - Ok(Arc::new(boolean_builder.finish())) } /// Array_has_all SQL function pub fn array_has_all(args: &[ArrayRef]) -> Result { - check_datatypes("array_has_all", &[&args[0], &args[1]])?; - - let array = as_list_array(&args[0])?; - let sub_array = as_list_array(&args[1])?; - - let mut boolean_builder = BooleanArray::builder(array.len()); - - let converter = RowConverter::new(vec![SortField::new(array.value_type())])?; - for (arr, sub_arr) in array.iter().zip(sub_array.iter()) { - if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) { - let arr_values = converter.convert_columns(&[arr])?; - let sub_arr_values = converter.convert_columns(&[sub_arr])?; + let array_type = args[0].data_type(); + let array = &args[0]; + let sub_array = &args[1]; - let mut res = true; - for elem in sub_arr_values.iter().dedup() { - res &= arr_values.iter().dedup().any(|x| x == elem); - if !res { - break; - } - } - boolean_builder.append_value(res); + match array_type { + DataType::List(_) => general_array_has_dispatch::(array, sub_array, true), + DataType::LargeList(_) => { + general_array_has_dispatch::(array, sub_array, true) } + _ => internal_err!("array_has_all does not support type '{array_type:?}'."), } - Ok(Arc::new(boolean_builder.finish())) } /// Splits string at occurrences of delimiter and returns an array of parts diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 3b45d995e1a2..5c6c4dbf68f0 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -2538,6 +2538,23 @@ select array_has(make_array(1,2), 1), ---- true true true true true false true false true false true false +query BBBBBBBBBBBB +select array_has(arrow_cast(make_array(1,2), 'LargeList(Int64)'), 1), + array_has(arrow_cast(make_array(1,2,NULL), 'LargeList(Int64)'), 1), + array_has(arrow_cast(make_array([2,3], [3,4]), 'LargeList(List(Int64))'), make_array(2,3)), + array_has(arrow_cast(make_array([[1], [2,3]], [[4,5], [6]]), 'LargeList(List(List(Int64)))'), make_array([1], [2,3])), + array_has(arrow_cast(make_array([[1], [2,3]], [[4,5], [6]]), 'LargeList(List(List(Int64)))'), make_array([4,5], [6])), + array_has(arrow_cast(make_array([[1], [2,3]], [[4,5], [6]]), 'LargeList(List(List(Int64)))'), make_array([1])), + array_has(arrow_cast(make_array([[[1]]]), 'LargeList(List(List(List(Int64))))'), make_array([[1]])), + array_has(arrow_cast(make_array([[[1]]], [[[1], [2]]]), 'LargeList(List(List(List(Int64))))'), make_array([[2]])), + array_has(arrow_cast(make_array([[[1]]], [[[1], [2]]]), 'LargeList(List(List(List(Int64))))'), make_array([[1], [2]])), + list_has(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), 4), + array_contains(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), 3), + list_contains(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), 0) +; +---- +true true true true true false true false true false true false + query BBB select array_has(column1, column2), array_has_all(column3, column4), @@ -2547,6 +2564,15 @@ from array_has_table_1D; true true true false false false +query BBB +select array_has(arrow_cast(column1, 'LargeList(Int64)'), column2), + array_has_all(arrow_cast(column3, 'LargeList(Int64)'), arrow_cast(column4, 'LargeList(Int64)')), + array_has_any(arrow_cast(column5, 'LargeList(Int64)'), arrow_cast(column6, 'LargeList(Int64)')) +from array_has_table_1D; +---- +true true true +false false false + query BBB select array_has(column1, column2), array_has_all(column3, column4), @@ -2556,6 +2582,15 @@ from array_has_table_1D_Float; true true false false false true +query BBB +select array_has(arrow_cast(column1, 'LargeList(Float64)'), column2), + array_has_all(arrow_cast(column3, 'LargeList(Float64)'), arrow_cast(column4, 'LargeList(Float64)')), + array_has_any(arrow_cast(column5, 'LargeList(Float64)'), arrow_cast(column6, 'LargeList(Float64)')) +from array_has_table_1D_Float; +---- +true true false +false false true + query BBB select array_has(column1, column2), array_has_all(column3, column4), @@ -2565,6 +2600,15 @@ from array_has_table_1D_Boolean; false true true true true true +query BBB +select array_has(arrow_cast(column1, 'LargeList(Boolean)'), column2), + array_has_all(arrow_cast(column3, 'LargeList(Boolean)'), arrow_cast(column4, 'LargeList(Boolean)')), + array_has_any(arrow_cast(column5, 'LargeList(Boolean)'), arrow_cast(column6, 'LargeList(Boolean)')) +from array_has_table_1D_Boolean; +---- +false true true +true true true + query BBB select array_has(column1, column2), array_has_all(column3, column4), @@ -2574,6 +2618,15 @@ from array_has_table_1D_UTF8; true true false false false true +query BBB +select array_has(arrow_cast(column1, 'LargeList(Utf8)'), column2), + array_has_all(arrow_cast(column3, 'LargeList(Utf8)'), arrow_cast(column4, 'LargeList(Utf8)')), + array_has_any(arrow_cast(column5, 'LargeList(Utf8)'), arrow_cast(column6, 'LargeList(Utf8)')) +from array_has_table_1D_UTF8; +---- +true true false +false false true + query BB select array_has(column1, column2), array_has_all(column3, column4) @@ -2582,6 +2635,14 @@ from array_has_table_2D; false true true false +query BB +select array_has(arrow_cast(column1, 'LargeList(List(Int64))'), column2), + array_has_all(arrow_cast(column3, 'LargeList(List(Int64))'), arrow_cast(column4, 'LargeList(List(Int64))')) +from array_has_table_2D; +---- +false true +true false + query B select array_has_all(column1, column2) from array_has_table_2D_float; @@ -2589,6 +2650,13 @@ from array_has_table_2D_float; true false +query B +select array_has_all(arrow_cast(column1, 'LargeList(List(Float64))'), arrow_cast(column2, 'LargeList(List(Float64))')) +from array_has_table_2D_float; +---- +true +false + query B select array_has(column1, column2) from array_has_table_3D; ---- @@ -2600,6 +2668,17 @@ true false true +query B +select array_has(arrow_cast(column1, 'LargeList(List(List(Int64)))'), column2) from array_has_table_3D; +---- +false +true +false +false +true +false +true + query BBBB select array_has(column1, make_array(5, 6)), array_has(column1, make_array(7, NULL)), @@ -2614,6 +2693,20 @@ false true false false false false false false false false false false +query BBBB +select array_has(arrow_cast(column1, 'LargeList(List(Int64))'), make_array(5, 6)), + array_has(arrow_cast(column1, 'LargeList(List(Int64))'), make_array(7, NULL)), + array_has(arrow_cast(column2, 'LargeList(Float64)'), 5.5), + array_has(arrow_cast(column3, 'LargeList(Utf8)'), 'o') +from arrays; +---- +false false false true +true false true false +true false false true +false true false false +false false false false +false false false false + query BBBBBBBBBBBBB select array_has_all(make_array(1,2,3), make_array(1,3)), array_has_all(make_array(1,2,3), make_array(1,4)), @@ -2632,6 +2725,24 @@ select array_has_all(make_array(1,2,3), make_array(1,3)), ---- true false true false false false true true false false true false true +query BBBBBBBBBBBBB +select array_has_all(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), arrow_cast(make_array(1,3), 'LargeList(Int64)')), + array_has_all(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(1,4), 'LargeList(Int64)')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,2]), 'LargeList(List(Int64))')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,3]), 'LargeList(List(Int64))')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,2], [3,4], [5,6]), 'LargeList(List(Int64))')), + array_has_all(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1]]), 'LargeList(List(List(Int64)))')), + array_has_all(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))')), + array_has_any(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(1,10,100), 'LargeList(Int64)')), + array_has_any(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(10,100),'LargeList(Int64)')), + array_has_any(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,10], [10,4]), 'LargeList(List(Int64))')), + array_has_any(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([10,20], [3,4]), 'LargeList(List(Int64))')), + array_has_any(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3], [4,5,6]]), 'LargeList(List(List(Int64)))')), + array_has_any(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3]], [[4,5,6]]), 'LargeList(List(List(Int64)))')) +; +---- +true false true false false false true true false false true false true + query ??? select array_intersect(column1, column2), array_intersect(column3, column4), From 8d6e4b51f6ab7711a35857a41143d550c812b27c Mon Sep 17 00:00:00 2001 From: Weijun Huang Date: Fri, 1 Dec 2023 09:39:21 +0100 Subject: [PATCH 2/2] simplify the code --- .../physical-expr/src/array_expressions.rs | 110 +++++++++--------- 1 file changed, 58 insertions(+), 52 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index f753ca5996fe..9ee597fab2c1 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -1774,39 +1774,67 @@ pub fn array_ndims(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } -/// general internal function for array_has_all and array_has_all -/// if is_all is true, then it is array_has_all, otherwise it is array_has_any +/// Represents the type of comparison for array_has. +#[derive(Debug, PartialEq)] +enum ComparisonType { + // array_has_all + All, + // array_has_any + Any, + // array_has + Single, +} + fn general_array_has_dispatch( array: &ArrayRef, sub_array: &ArrayRef, - is_all: bool, + comparison_type: ComparisonType, ) -> Result { - check_datatypes("array_has", &[array, sub_array])?; - - let array = as_generic_list_array::(array)?; - let sub_array = as_generic_list_array::(&sub_array)?; + let array = if comparison_type == ComparisonType::Single { + let arr = as_generic_list_array::(array)?; + check_datatypes("array_has", &[arr.values(), sub_array])?; + arr + } else { + check_datatypes("array_has", &[array, sub_array])?; + as_generic_list_array::(array)? + }; let mut boolean_builder = BooleanArray::builder(array.len()); let converter = RowConverter::new(vec![SortField::new(array.value_type())])?; - for (arr, sub_arr) in array.iter().zip(sub_array.iter()) { + + let element = sub_array.clone(); + let sub_array = if comparison_type != ComparisonType::Single { + as_generic_list_array::(sub_array)? + } else { + array + }; + + for (row_idx, (arr, sub_arr)) in array.iter().zip(sub_array.iter()).enumerate() { if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) { let arr_values = converter.convert_columns(&[arr])?; - let sub_arr_values = converter.convert_columns(&[sub_arr])?; + let sub_arr_values = if comparison_type != ComparisonType::Single { + converter.convert_columns(&[sub_arr])? + } else { + converter.convert_columns(&[element.clone()])? + }; - let mut res = if is_all { - sub_arr_values + let mut res = match comparison_type { + ComparisonType::All => sub_arr_values .iter() .dedup() - .all(|elem| arr_values.iter().dedup().any(|x| x == elem)) - } else { - sub_arr_values + .all(|elem| arr_values.iter().dedup().any(|x| x == elem)), + ComparisonType::Any => sub_arr_values + .iter() + .dedup() + .any(|elem| arr_values.iter().dedup().any(|x| x == elem)), + ComparisonType::Single => arr_values .iter() .dedup() - .any(|elem| arr_values.iter().dedup().any(|x| x == elem)) + .any(|x| x == sub_arr_values.row(row_idx)), }; - if is_all { + if comparison_type == ComparisonType::Any { res |= res; } @@ -1819,50 +1847,28 @@ fn general_array_has_dispatch( /// Array_has SQL function pub fn array_has(args: &[ArrayRef]) -> Result { let array_type = args[0].data_type(); - let array = &args[0]; - let element = &args[1]; match array_type { - DataType::List(_) => array_has_dispatch::(array, element), - DataType::LargeList(_) => array_has_dispatch::(array, element), - _ => internal_err!("array_has does not support type '{array_type:?}'."), - } -} - -fn array_has_dispatch( - array: &ArrayRef, - element: &ArrayRef, -) -> Result { - let array = as_generic_list_array::(array)?; - - check_datatypes("array_has", &[array.values(), element])?; - let mut boolean_builder = BooleanArray::builder(array.len()); - - let converter = RowConverter::new(vec![SortField::new(array.value_type())])?; - let r_values = converter.convert_columns(&[element.clone()])?; - for (row_idx, arr) in array.iter().enumerate() { - if let Some(arr) = arr { - let arr_values = converter.convert_columns(&[arr])?; - let res = arr_values - .iter() - .dedup() - .any(|x| x == r_values.row(row_idx)); - boolean_builder.append_value(res); + DataType::List(_) => { + general_array_has_dispatch::(&args[0], &args[1], ComparisonType::Single) } + DataType::LargeList(_) => { + general_array_has_dispatch::(&args[0], &args[1], ComparisonType::Single) + } + _ => internal_err!("array_has does not support type '{array_type:?}'."), } - Ok(Arc::new(boolean_builder.finish())) } /// Array_has_any SQL function pub fn array_has_any(args: &[ArrayRef]) -> Result { let array_type = args[0].data_type(); - let array = &args[0]; - let sub_array = &args[1]; match array_type { - DataType::List(_) => general_array_has_dispatch::(array, sub_array, false), + DataType::List(_) => { + general_array_has_dispatch::(&args[0], &args[1], ComparisonType::Any) + } DataType::LargeList(_) => { - general_array_has_dispatch::(array, sub_array, false) + general_array_has_dispatch::(&args[0], &args[1], ComparisonType::Any) } _ => internal_err!("array_has_any does not support type '{array_type:?}'."), } @@ -1871,13 +1877,13 @@ pub fn array_has_any(args: &[ArrayRef]) -> Result { /// Array_has_all SQL function pub fn array_has_all(args: &[ArrayRef]) -> Result { let array_type = args[0].data_type(); - let array = &args[0]; - let sub_array = &args[1]; match array_type { - DataType::List(_) => general_array_has_dispatch::(array, sub_array, true), + DataType::List(_) => { + general_array_has_dispatch::(&args[0], &args[1], ComparisonType::All) + } DataType::LargeList(_) => { - general_array_has_dispatch::(array, sub_array, true) + general_array_has_dispatch::(&args[0], &args[1], ComparisonType::All) } _ => internal_err!("array_has_all does not support type '{array_type:?}'."), }