Skip to content

Commit

Permalink
Added array_any_value function
Browse files Browse the repository at this point in the history
  • Loading branch information
athultr1997 committed Sep 5, 2024
1 parent 650dfdc commit 21b9524
Show file tree
Hide file tree
Showing 6 changed files with 287 additions and 1 deletion.
122 changes: 121 additions & 1 deletion datafusion/functions-nested/src/extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.

//! [`ScalarUDFImpl`] definitions for array_element, array_slice, array_pop_front and array_pop_back functions.
//! [`ScalarUDFImpl`] definitions for array_element, array_slice, array_pop_front, array_pop_back, and array_any_value functions.

use arrow::array::Array;
use arrow::array::ArrayRef;
Expand Down Expand Up @@ -69,6 +69,14 @@ make_udf_expr_and_func!(
array_pop_back_udf
);

make_udf_expr_and_func!(
ArrayAnyValue,
array_any_value,
array,
"returns the first non-null element in the array.",
array_any_value_udf
);

#[derive(Debug)]
pub(super) struct ArrayElement {
signature: Signature,
Expand Down Expand Up @@ -687,3 +695,115 @@ where
);
general_array_slice::<O>(array, &from_array, &to_array, None)
}

#[derive(Debug)]
pub(super) struct ArrayAnyValue {
signature: Signature,
aliases: Vec<String>,
}

impl ArrayAnyValue {
pub fn new() -> Self {
Self {
signature: Signature::array(Volatility::Immutable),
aliases: vec![String::from("list_any_value")],
}
}
}

impl ScalarUDFImpl for ArrayAnyValue {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"array_any_value"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
match &arg_types[0] {
List(field)
| LargeList(field)
| FixedSizeList(field, _) => Ok(field.data_type().clone()),
_ => plan_err!(
"array_any_value can only accept List, LargeList or FixedSizeList as the argument"
),
}
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
make_scalar_function(array_any_value_inner)(args)
}
fn aliases(&self) -> &[String] {
&self.aliases
}
}

fn array_any_value_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 1 {
return exec_err!("array_any_value expects one argument");
}

match &args[0].data_type() {
List(_) => {
let array = as_list_array(&args[0])?;
general_array_any_value::<i32>(array)
}
LargeList(_) => {
let array = as_large_list_array(&args[0])?;
general_array_any_value::<i64>(array)
}
data_type => exec_err!("array_any_value does not support type: {:?}", data_type),
}
}

fn general_array_any_value<O: OffsetSizeTrait>(
array: &GenericListArray<O>,
) -> Result<ArrayRef>
where
i64: TryInto<O>,
{
let values = array.values();
let original_data = values.to_data();
let capacity = Capacities::Array(array.len());

let mut mutable =
MutableArrayData::with_capacities(vec![&original_data], true, capacity);

for (row_index, offset_window) in array.offsets().windows(2).enumerate() {
let start = offset_window[0];
let end = offset_window[1];
let len = end - start;

// array is null
if len == O::usize_as(0) {
mutable.extend_nulls(1);
continue;
}

let row_value = array.value(row_index);
match row_value.nulls() {
Some(row_nulls_buffer) => {
// nulls are present in the array so try to take the first valid element
if let Some(first_non_null_index) =
row_nulls_buffer.valid_indices().next()
{
let index = start.as_usize() + first_non_null_index;
mutable.extend(0, index, index + 1)
} else {
// all the elements in the array are null
mutable.extend_nulls(1);
}
}
None => {
// no nulls are present in the array so take the first element
let index = start.as_usize();
mutable.extend(0, index, index + 1);
}
}
}

let data = mutable.freeze();
Ok(arrow::array::make_array(data))
}
2 changes: 2 additions & 0 deletions datafusion/functions-nested/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ pub mod expr_fn {
pub use super::distance::array_distance;
pub use super::empty::array_empty;
pub use super::except::array_except;
pub use super::extract::array_any_value;
pub use super::extract::array_element;
pub use super::extract::array_pop_back;
pub use super::extract::array_pop_front;
Expand Down Expand Up @@ -124,6 +125,7 @@ pub fn all_default_nested_functions() -> Vec<Arc<ScalarUDF>> {
extract::array_pop_back_udf(),
extract::array_pop_front_udf(),
extract::array_slice_udf(),
extract::array_any_value_udf(),
make_array::make_array_udf(),
array_has::array_has_udf(),
array_has::array_has_all_udf(),
Expand Down
6 changes: 6 additions & 0 deletions datafusion/proto/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,12 @@ async fn roundtrip_expr_api() -> Result<()> {
),
array_pop_front(make_array(vec![lit(1), lit(2), lit(3)])),
array_pop_back(make_array(vec![lit(1), lit(2), lit(3)])),
array_any_value(make_array(vec![
lit(ScalarValue::Null),
lit(1),
lit(2),
lit(3),
])),
array_reverse(make_array(vec![lit(1), lit(2), lit(3)])),
array_position(
make_array(vec![lit(1), lit(2), lit(3), lit(4)]),
Expand Down
127 changes: 127 additions & 0 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1982,6 +1982,133 @@ select array_slice(a, -1, 2, 1), array_slice(a, -1, 2),
query error DataFusion error: Error during planning: Error during planning: array_slice does not support zero arguments
select array_slice();

## array_any_value (aliases: list_any_value)

# Testing with empty arguments should result in an error
query error
select array_any_value();

# Testing with non-array arguments should result in an error
query error
select array_any_value(1), array_any_value('a'), array_any_value(NULL);

# array_any_value scalar function #1 (with null and non-null elements)

query ITI
select array_any_value(make_array(NULL, 1, 2, 3, 4, 5)), array_any_value(make_array(NULL, 'h', 'e', 'l', 'l', 'o')), array_any_value(make_array(NULL, NULL));
----
1 h NULL

query ITIT
select array_any_value(arrow_cast(make_array(NULL, 1, 2, 3, 4, 5), 'LargeList(Int64)')), array_any_value(arrow_cast(make_array(NULL, 'h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)')), array_any_value(arrow_cast(make_array(NULL, NULL), 'LargeList(Int64)')), array_any_value(arrow_cast(make_array(NULL, NULL), 'LargeList(Utf8)'));
----
1 h NULL NULL

query ITIT
select array_any_value(arrow_cast(make_array(NULL, 1, 2, 3, 4, 5), 'FixedSizeList(6, Int64)')), array_any_value(arrow_cast(make_array(NULL, 'h', 'e', 'l', 'l', 'o'), 'FixedSizeList(6, Utf8)')), array_any_value(arrow_cast(make_array(NULL, NULL), 'FixedSizeList(2, Int64)')), array_any_value(arrow_cast(make_array(NULL, NULL), 'FixedSizeList(2, Utf8)'));;
----
1 h NULL NULL

# array_any_value scalar function #2 (with nested array)

query ?
select array_any_value(make_array(NULL, make_array(NULL, 1, 2, 3, 4, 5), make_array(NULL, 6, 7, 8, 9, 10)));
----
[, 1, 2, 3, 4, 5]

query ?
select array_any_value(arrow_cast(make_array(NULL, make_array(NULL, 1, 2, 3, 4, 5), make_array(NULL, 6, 7, 8, 9, 10)), 'LargeList(List(Int64))'));
----
[, 1, 2, 3, 4, 5]

query ?
select array_any_value(arrow_cast(make_array(NULL, make_array(NULL, 1, 2, 3, 4, 5), make_array(NULL, 6, 7, 8, 9, 10)), 'FixedSizeList(3, List(Int64))'));
----
[, 1, 2, 3, 4, 5]

# array_any_value scalar function #3 (using function alias `list_any_value`)
query IT
select list_any_value(make_array(NULL, 1, 2, 3, 4, 5)), list_any_value(make_array(NULL, 'h', 'e', 'l', 'l', 'o'));
----
1 h

query IT
select list_any_value(arrow_cast(make_array(NULL, 1, 2, 3, 4, 5), 'LargeList(Int64)')), list_any_value(arrow_cast(make_array(NULL, 'h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'));
----
1 h

query IT
select list_any_value(arrow_cast(make_array(NULL, 1, 2, 3, 4, 5), 'FixedSizeList(6, Int64)')), list_any_value(arrow_cast(make_array(NULL, 'h', 'e', 'l', 'l', 'o'), 'FixedSizeList(6, Utf8)'));
----
1 h

# array_any_value with columns

query I
select array_any_value(column1) from slices;
----
2
11
21
31
NULL
41
51

query I
select array_any_value(arrow_cast(column1, 'LargeList(Int64)')) from slices;
----
2
11
21
31
NULL
41
51

query I
select array_any_value(column1) from fixed_slices;
----
2
11
21
31
41
51

# array_any_value with columns and scalars

query II
select array_any_value(make_array(NULL, 1, 2, 3, 4, 5)), array_any_value(column1) from slices;
----
1 2
1 11
1 21
1 31
1 NULL
1 41
1 51

query II
select array_any_value(arrow_cast(make_array(NULL, 1, 2, 3, 4, 5), 'LargeList(Int64)')), array_any_value(arrow_cast(column1, 'LargeList(Int64)')) from slices;
----
1 2
1 11
1 21
1 31
1 NULL
1 41
1 51

query II
select array_any_value(make_array(NULL, 1, 2, 3, 4, 5)), array_any_value(column1) from fixed_slices;
----
1 2
1 11
1 21
1 31
1 41
1 51

# make_array with nulls
query ???????
Expand Down
1 change: 1 addition & 0 deletions docs/source/user-guide/expressions.md
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ select log(-1), log(0), sqrt(-1);

| Syntax | Description |
| ---------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| array_any_value(array) | Returns the first non-null element in the array. `array_any_value([NULL, 1, 2, 3]) -> 1` |
| array_append(array, element) | Appends an element to the end of an array. `array_append([1, 2, 3], 4) -> [1, 2, 3, 4]` |
| array_concat(array[, ..., array_n]) | Concatenates arrays. `array_concat([1, 2, 3], [4, 5, 6]) -> [1, 2, 3, 4, 5, 6]` |
| array_has(array, element) | Returns true if the array contains the element `array_has([1,2,3], 1) -> true` |
Expand Down
30 changes: 30 additions & 0 deletions docs/source/user-guide/sql/scalar_functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -2087,6 +2087,7 @@ to_unixtime(expression[, ..., format_n])

## Array Functions

- [array_any_value](#array_any_value)
- [array_append](#array_append)
- [array_sort](#array_sort)
- [array_cat](#array_cat)
Expand Down Expand Up @@ -2131,6 +2132,7 @@ to_unixtime(expression[, ..., format_n])
- [empty](#empty)
- [flatten](#flatten)
- [generate_series](#generate_series)
- [list_any_value] (#list_any_value)
- [list_append](#list_append)
- [list_sort](#list_sort)
- [list_cat](#list_cat)
Expand Down Expand Up @@ -2175,6 +2177,30 @@ to_unixtime(expression[, ..., format_n])
- [unnest](#unnest)
- [range](#range)

### `array_any_value`

Appends an element to the end of an array.

```
array_any_value(array)
```

#### Arguments

- **array**: Array expression.
Can be a constant, column, or function, and any combination of array operators.

#### Example

```
> select array_any_value([NULL, 1, 2, 3]);
+--------------------------------------------------------------+
| array_any_value(List([NULL,1,2,3])) |
+--------------------------------------------------------------+
| 1 |
+--------------------------------------------------------------+
```

### `array_append`

Appends an element to the end of an array.
Expand Down Expand Up @@ -3240,6 +3266,10 @@ generate_series(start, stop, step)
+------------------------------------+
```

### `list_any_value`

_Alias of [array_any_value](#array_any_value)._

### `list_append`

_Alias of [array_append](#array_append)._
Expand Down

0 comments on commit 21b9524

Please sign in to comment.