diff --git a/datafusion/functions-nested/src/extract.rs b/datafusion/functions-nested/src/extract.rs index b9e82f371369a..ef6e426d77bfc 100644 --- a/datafusion/functions-nested/src/extract.rs +++ b/datafusion/functions-nested/src/extract.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! [`ScalarUDFImpl`] definitions for array_element, array_slice, array_pop_front and array_pop_back functions. +//! [`ScalarUDFImpl`] definitions for array_element, array_slice, array_pop_front, array_pop_back, and array_any_value functions. use arrow::array::Array; use arrow::array::ArrayRef; @@ -69,6 +69,14 @@ make_udf_expr_and_func!( array_pop_back_udf ); +make_udf_expr_and_func!( + ArrayAnyValue, + array_any_value, + array, + "returns the first non-null element in the array.", + array_any_value_udf +); + #[derive(Debug)] pub(super) struct ArrayElement { signature: Signature, @@ -687,3 +695,118 @@ where ); general_array_slice::(array, &from_array, &to_array, None) } + +#[derive(Debug)] +pub(super) struct ArrayAnyValue { + signature: Signature, + aliases: Vec, +} + +impl ArrayAnyValue { + pub fn new() -> Self { + Self { + signature: Signature::array(Volatility::Immutable), + aliases: vec![String::from("list_any_value")], + } + } +} + +impl ScalarUDFImpl for ArrayAnyValue { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "array_any_value" + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, arg_types: &[DataType]) -> Result { + match &arg_types[0] { + List(field) + | LargeList(field) + | FixedSizeList(field, _) => Ok(field.data_type().clone()), + _ => plan_err!( + "array_any_value can only accept List, LargeList or FixedSizeList as the argument" + ), + } + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(array_any_value_inner)(args) + } + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +fn array_any_value_inner(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("array_any_value expects one argument"); + } + + match &args[0].data_type() { + List(_) => { + let array = as_list_array(&args[0])?; + general_array_any_value::(array) + } + LargeList(_) => { + let array = as_large_list_array(&args[0])?; + general_array_any_value::(array) + } + _ => exec_err!( + "array_any_value does not support type: {:?}", + args[0].data_type() + ), + } +} + +fn general_array_any_value( + array: &GenericListArray, +) -> Result +where + i64: TryInto, +{ + let values = array.values(); + let original_data = values.to_data(); + let capacity = Capacities::Array(array.len()); + + let mut mutable = + MutableArrayData::with_capacities(vec![&original_data], true, capacity); + + for (row_index, offset_window) in array.offsets().windows(2).enumerate() { + let start = offset_window[0]; + let end = offset_window[1]; + let len = end - start; + + // array is null + if len == O::usize_as(0) { + mutable.extend_nulls(1); + continue; + } + + let row_value = array.value(row_index); + match row_value.nulls() { + Some(row_nulls_buffer) => { + // nulls are present in the array so try to take the first valid element + if let Some(first_non_null_index) = + row_nulls_buffer.valid_indices().next() + { + let index = start.as_usize() + first_non_null_index; + mutable.extend(0, index, index + 1) + } else { + // all the elements in the array are null + mutable.extend_nulls(1); + } + } + None => { + // no nulls are present in the array so take the first element + let index = start.as_usize(); + mutable.extend(0, index, index + 1); + } + } + } + + let data = mutable.freeze(); + Ok(arrow::array::make_array(data)) +} diff --git a/datafusion/functions-nested/src/lib.rs b/datafusion/functions-nested/src/lib.rs index 863b5a876adc1..6f302a9671026 100644 --- a/datafusion/functions-nested/src/lib.rs +++ b/datafusion/functions-nested/src/lib.rs @@ -77,6 +77,7 @@ pub mod expr_fn { pub use super::distance::array_distance; pub use super::empty::array_empty; pub use super::except::array_except; + pub use super::extract::array_any_value; pub use super::extract::array_element; pub use super::extract::array_pop_back; pub use super::extract::array_pop_front; @@ -124,6 +125,7 @@ pub fn all_default_nested_functions() -> Vec> { extract::array_pop_back_udf(), extract::array_pop_front_udf(), extract::array_slice_udf(), + extract::array_any_value_udf(), make_array::make_array_udf(), array_has::array_has_udf(), array_has::array_has_all_udf(), diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index e174d1b507130..2553518e609cb 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -839,6 +839,12 @@ async fn roundtrip_expr_api() -> Result<()> { ), array_pop_front(make_array(vec![lit(1), lit(2), lit(3)])), array_pop_back(make_array(vec![lit(1), lit(2), lit(3)])), + array_any_value(make_array(vec![ + lit(ScalarValue::Null), + lit(1), + lit(2), + lit(3), + ])), array_reverse(make_array(vec![lit(1), lit(2), lit(3)])), array_position( make_array(vec![lit(1), lit(2), lit(3), lit(4)]), diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index edc0cd7577e1e..0e728eb92a198 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -1982,6 +1982,133 @@ select array_slice(a, -1, 2, 1), array_slice(a, -1, 2), query error DataFusion error: Error during planning: Error during planning: array_slice does not support zero arguments select array_slice(); +## array_any_value (aliases: list_any_value) + +# Testing with empty arguments should result in an error +query error +select array_any_value(); + +# Testing with non-array arguments should result in an error +query error +select array_any_value(1), array_any_value('a'), array_any_value(NULL); + +# array_any_value scalar function #1 (with null and non-null elements) + +query ITI +select array_any_value(make_array(NULL, 1, 2, 3, 4, 5)), array_any_value(make_array(NULL, 'h', 'e', 'l', 'l', 'o')), array_any_value(make_array(NULL, NULL)); +---- +1 h NULL + +query ITIT +select array_any_value(arrow_cast(make_array(NULL, 1, 2, 3, 4, 5), 'LargeList(Int64)')), array_any_value(arrow_cast(make_array(NULL, 'h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)')), array_any_value(arrow_cast(make_array(NULL, NULL), 'LargeList(Int64)')), array_any_value(arrow_cast(make_array(NULL, NULL), 'LargeList(Utf8)')); +---- +1 h NULL NULL + +query ITIT +select array_any_value(arrow_cast(make_array(NULL, 1, 2, 3, 4, 5), 'FixedSizeList(6, Int64)')), array_any_value(arrow_cast(make_array(NULL, 'h', 'e', 'l', 'l', 'o'), 'FixedSizeList(6, Utf8)')), array_any_value(arrow_cast(make_array(NULL, NULL), 'FixedSizeList(2, Int64)')), array_any_value(arrow_cast(make_array(NULL, NULL), 'FixedSizeList(2, Utf8)'));; +---- +1 h NULL NULL + +# array_any_value scalar function #2 (with nested array) + +query ? +select array_any_value(make_array(NULL, make_array(NULL, 1, 2, 3, 4, 5), make_array(NULL, 6, 7, 8, 9, 10))); +---- +[, 1, 2, 3, 4, 5] + +query ? +select array_any_value(arrow_cast(make_array(NULL, make_array(NULL, 1, 2, 3, 4, 5), make_array(NULL, 6, 7, 8, 9, 10)), 'LargeList(List(Int64))')); +---- +[, 1, 2, 3, 4, 5] + +query ? +select array_any_value(arrow_cast(make_array(NULL, make_array(NULL, 1, 2, 3, 4, 5), make_array(NULL, 6, 7, 8, 9, 10)), 'FixedSizeList(3, List(Int64))')); +---- +[, 1, 2, 3, 4, 5] + +# array_any_value scalar function #3 (using function alias `list_any_value`) +query IT +select list_any_value(make_array(NULL, 1, 2, 3, 4, 5)), list_any_value(make_array(NULL, 'h', 'e', 'l', 'l', 'o')); +---- +1 h + +query IT +select list_any_value(arrow_cast(make_array(NULL, 1, 2, 3, 4, 5), 'LargeList(Int64)')), list_any_value(arrow_cast(make_array(NULL, 'h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)')); +---- +1 h + +query IT +select list_any_value(arrow_cast(make_array(NULL, 1, 2, 3, 4, 5), 'FixedSizeList(6, Int64)')), list_any_value(arrow_cast(make_array(NULL, 'h', 'e', 'l', 'l', 'o'), 'FixedSizeList(6, Utf8)')); +---- +1 h + +# array_any_value with columns + +query I +select array_any_value(column1) from slices; +---- +2 +11 +21 +31 +NULL +41 +51 + +query I +select array_any_value(arrow_cast(column1, 'LargeList(Int64)')) from slices; +---- +2 +11 +21 +31 +NULL +41 +51 + +query I +select array_any_value(column1) from fixed_slices; +---- +2 +11 +21 +31 +41 +51 + +# array_any_value with columns and scalars + +query II +select array_any_value(make_array(NULL, 1, 2, 3, 4, 5)), array_any_value(column1) from slices; +---- +1 2 +1 11 +1 21 +1 31 +1 NULL +1 41 +1 51 + +query II +select array_any_value(arrow_cast(make_array(NULL, 1, 2, 3, 4, 5), 'LargeList(Int64)')), array_any_value(arrow_cast(column1, 'LargeList(Int64)')) from slices; +---- +1 2 +1 11 +1 21 +1 31 +1 NULL +1 41 +1 51 + +query II +select array_any_value(make_array(NULL, 1, 2, 3, 4, 5)), array_any_value(column1) from fixed_slices; +---- +1 2 +1 11 +1 21 +1 31 +1 41 +1 51 # make_array with nulls query ??????? diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index ad5a9cb75152c..d7deb5a3e0a0f 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -211,6 +211,7 @@ select log(-1), log(0), sqrt(-1); | Syntax | Description | | ---------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| array_any_value(array) | Returns the first non-null element in the array. `array_any_value([NULL, 1, 2, 3]) -> 1` | | array_append(array, element) | Appends an element to the end of an array. `array_append([1, 2, 3], 4) -> [1, 2, 3, 4]` | | array_concat(array[, ..., array_n]) | Concatenates arrays. `array_concat([1, 2, 3], [4, 5, 6]) -> [1, 2, 3, 4, 5, 6]` | | array_has(array, element) | Returns true if the array contains the element `array_has([1,2,3], 1) -> true` | diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 9569f4e65ff3a..b66ffad0cf4a6 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -2087,6 +2087,7 @@ to_unixtime(expression[, ..., format_n]) ## Array Functions +- [array_any_value](#array_any_value) - [array_append](#array_append) - [array_sort](#array_sort) - [array_cat](#array_cat) @@ -2175,6 +2176,30 @@ to_unixtime(expression[, ..., format_n]) - [unnest](#unnest) - [range](#range) +### `array_any_value` + +Appends an element to the end of an array. + +``` +array_any_value(array) +``` + +#### Arguments + +- **array**: Array expression. + Can be a constant, column, or function, and any combination of array operators. + +#### Example + +``` +> select array_any_value([NULL, 1, 2, 3]); ++--------------------------------------------------------------+ +| array_any_value(List([NULL,1,2,3])) | ++--------------------------------------------------------------+ +| 1 | ++--------------------------------------------------------------+ +``` + ### `array_append` Appends an element to the end of an array. @@ -3240,6 +3265,10 @@ generate_series(start, stop, step) +------------------------------------+ ``` +### `list_any_value` + +_Alias of [array_any_value](#array_any_value)._ + ### `list_append` _Alias of [array_append](#array_append)._