Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: New functions and operations for working with arrays #6384

Merged
merged 14 commits into from
Jun 6, 2023
7 changes: 6 additions & 1 deletion datafusion/common/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use arrow::{
IntervalYearMonthArray, LargeListArray, ListArray, MapArray, NullArray,
OffsetSizeTrait, PrimitiveArray, StringArray, StructArray,
TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
TimestampSecondArray, UInt32Array, UInt64Array, UnionArray,
TimestampSecondArray, UInt32Array, UInt64Array, UInt8Array, UnionArray,
},
datatypes::{ArrowDictionaryKeyType, ArrowPrimitiveType},
};
Expand All @@ -45,6 +45,11 @@ pub fn as_struct_array(array: &dyn Array) -> Result<&StructArray> {
Ok(downcast_value!(array, StructArray))
}

// Downcast ArrayRef to UInt8Array
pub fn as_uint8_array(array: &dyn Array) -> Result<&UInt8Array> {
Ok(downcast_value!(array, UInt8Array))
}

// Downcast ArrayRef to Int32Array
pub fn as_int32_array(array: &dyn Array) -> Result<&Int32Array> {
Ok(downcast_value!(array, Int32Array))
Expand Down
206 changes: 206 additions & 0 deletions datafusion/core/tests/sqllogictests/test_files/array.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

#############
## Array expressions Tests
#############

# array scalar function #1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are great @izveigor -- thank you so much

the only thing I recommend is adding some additional tests that have null in the lists.

query ??? rowsort
select make_array(1, 2, 3), make_array(1.0, 2.0, 3.0), make_array('h', 'e', 'l', 'l', 'o');
----
[1, 2, 3] [1.0, 2.0, 3.0] [h, e, l, l, o]

# array scalar function #2
query ??? rowsort
select make_array(1, 2, 3), make_array(make_array(1, 2), make_array(3, 4)), make_array([[[[1], [2]]]]);
----
[1, 2, 3] [[1, 2], [3, 4]] [[[[[1], [2]]]]]

# array scalar function #3
query ?? rowsort
select make_array([1, 2, 3], [4, 5, 6], [7, 8, 9]), make_array([[1, 2], [3, 4]], [[5, 6], [7, 8]]);
----
[[1, 2, 3], [4, 5, 6], [7, 8, 9]] [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]

# array scalar function #4
query ?? rowsort
select make_array([1.0, 2.0], [3.0, 4.0]), make_array('h', 'e', 'l', 'l', 'o');
----
[[1.0, 2.0], [3.0, 4.0]] [h, e, l, l, o]

# array scalar function #5
query ? rowsort
select make_array(make_array(make_array(make_array(1, 2, 3), make_array(4, 5, 6)), make_array(make_array(7, 8, 9), make_array(10, 11, 12))))
----
[[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]]

# array_append scalar function
query ??? rowsort
select array_append(make_array(1, 2, 3), 4), array_append(make_array(1.0, 2.0, 3.0), 4.0), array_append(make_array('h', 'e', 'l', 'l'), 'o');
----
[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o]

# array_prepend scalar function
query ??? rowsort
select array_prepend(1, make_array(2, 3, 4)), array_prepend(1.0, make_array(2.0, 3.0, 4.0)), array_prepend('h', make_array('e', 'l', 'l', 'o'));
----
[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o]

# array_fill scalar function #1
query ??? rowsort
select array_fill(11, make_array(1, 2, 3)), array_fill(3, make_array(2, 3)), array_fill(2, make_array(2));
----
[[[11, 11, 11], [11, 11, 11]]] [[3, 3, 3], [3, 3, 3]] [2, 2]

# array_fill scalar function #2
query ?? rowsort
select array_fill(1, make_array(1, 1, 1)), array_fill(2, make_array(2, 2, 2, 2, 2));
----
[[[1]]] [[[[[2, 2], [2, 2]], [[2, 2], [2, 2]]], [[[2, 2], [2, 2]], [[2, 2], [2, 2]]]], [[[[2, 2], [2, 2]], [[2, 2], [2, 2]]], [[[2, 2], [2, 2]], [[2, 2], [2, 2]]]]]

# array_concat scalar function #1
query ?? rowsort
select array_concat(make_array(1, 2, 3), make_array(4, 5, 6), make_array(7, 8, 9)), array_concat(make_array([1], [2]), make_array([3], [4]));
----
[1, 2, 3, 4, 5, 6, 7, 8, 9] [[1], [2], [3], [4]]

# array_concat scalar function #2
query ? rowsort
select array_concat(make_array(make_array(1, 2), make_array(3, 4)), make_array(make_array(5, 6), make_array(7, 8)));
----
[[1, 2], [3, 4], [5, 6], [7, 8]]

# array_concat scalar function #3
query ? rowsort
select array_concat(make_array([1], [2], [3]), make_array([4], [5], [6]), make_array([7], [8], [9]));
----
[[1], [2], [3], [4], [5], [6], [7], [8], [9]]

# array_concat scalar function #4
query ? rowsort
select array_concat(make_array([[1]]), make_array([[2]]));
----
[[[1]], [[2]]]

# array_position scalar function #1
query III
select array_position(['h', 'e', 'l', 'l', 'o'], 'l'), array_position([1, 2, 3, 4, 5], 5), array_position([1, 1, 1], 1);
----
3 5 1

# array_position scalar function #2
query III
select array_position(['h', 'e', 'l', 'l', 'o'], 'l', 4), array_position([1, 2, 5, 4, 5], 5, 4), array_position([1, 1, 1], 1, 2);
----
4 5 2

# array_positions scalar function
query III
select array_positions(['h', 'e', 'l', 'l', 'o'], 'l'), array_positions([1, 2, 3, 4, 5], 5), array_positions([1, 1, 1], 1);
----
[3, 4] [5] [1, 2, 3]

# array_replace scalar function
query ???
select array_replace(make_array(1, 2, 3, 4), 2, 3), array_replace(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0), array_replace(make_array(1, 2, 3), 4, 0);
----
[1, 3, 3, 4] [1, 0, 0, 5, 0, 6, 7] [1, 2, 3]

# array_to_string scalar function
query ???
select array_to_string(['h', 'e', 'l', 'l', 'o'], ','), array_to_string([1, 2, 3, 4, 5], '-'), array_to_string([1.0, 2.0, 3.0], '|');
----
h,e,l,l,o 1-2-3-4-5 1|2|3

# array_to_string scalar function #2
query ???
select array_to_string([1, 1, 1], '1'), array_to_string([[1, 2], [3, 4], [5, 6]], '+'), array_to_string(array_fill(3, [3, 2, 2]), '/\');
----
11111 1+2+3+4+5+6 3/\3/\3/\3/\3/\3/\3/\3/\3/\3/\3/\3

# cardinality scalar function
query III
select cardinality(make_array(1, 2, 3, 4, 5)), cardinality([1, 3, 5]), cardinality(make_array('h', 'e', 'l', 'l', 'o'));
----
5 3 5

# cardinality scalar function #2
query II
select cardinality(make_array([1, 2], [3, 4], [5, 6])), cardinality(array_fill(3, array[3, 2, 3]));
----
6 18

# trim_array scalar function
query ???
select trim_array(make_array(1, 2, 3, 4, 5), 2), trim_array(['h', 'e', 'l', 'l', 'o'], 3), trim_array([1.0, 2.0, 3.0], 2);
----
[1, 2, 3] [h, e] [1.0]

# trim_array scalar function #2
query ??
select trim_array([[1, 2], [3, 4], [5, 6]], 2), trim_array(array_fill(4, [3, 4, 2]), 2);
----
[[1, 2]] [[[4, 4], [4, 4], [4, 4], [4, 4]]]

# array_length scalar function
query III rowsort
select array_length(make_array(1, 2, 3, 4, 5)), array_length(make_array(1, 2, 3)), array_length(make_array([1, 2], [3, 4], [5, 6]));
----
5 3 3

# array_length scalar function #2
query III rowsort
select array_length(make_array(1, 2, 3, 4, 5), 1), array_length(make_array(1, 2, 3), 1), array_length(make_array([1, 2], [3, 4], [5, 6]), 1);
----
5 3 3

# array_length scalar function #3
query III rowsort
select array_length(make_array(1, 2, 3, 4, 5), 2), array_length(make_array(1, 2, 3), 2), array_length(make_array([1, 2], [3, 4], [5, 6]), 2);
----
NULL NULL 2

# array_length scalar function #4
query IIII rowsort
select array_length(array_fill(3, [3, 2, 5]), 1), array_length(array_fill(3, [3, 2, 5]), 2), array_length(array_fill(3, [3, 2, 5]), 3), array_length(array_fill(3, [3, 2, 5]), 4);
----
3 2 5 NULL

# array_dims scalar function
query III rowsort
select array_dims(make_array(1, 2, 3)), array_dims(make_array([1, 2], [3, 4])), array_dims(make_array([[[[1], [2]]]]));
----
[3] [2, 2] [1, 1, 1, 2, 1]

# array_dims scalar function #2
query II rowsort
select array_dims(array_fill(2, [1, 2, 3])), array_dims(array_fill(3, [2, 5, 4]));
----
[1, 2, 3] [2, 5, 4]

# array_ndims scalar function
query III rowsort
select array_ndims(make_array(1, 2, 3)), array_ndims(make_array([1, 2], [3, 4])), array_ndims(make_array([[[[1], [2]]]]));
----
1 2 5

# array_ndims scalar function #2
query II rowsort
select array_ndims(array_fill(1, [1, 2, 3])), array_ndims([[[[[[[[[[[[[[[[[[[[[1]]]]]]]]]]]]]]]]]]]]]);
----
3 21
60 changes: 59 additions & 1 deletion datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,39 @@ pub enum BuiltinScalarFunction {
/// trunc
Trunc,

// string functions
// array functions
/// array_append
ArrayAppend,
/// array_concat
ArrayConcat,
/// array_dims
ArrayDims,
/// array_fill
ArrayFill,
/// array_length
ArrayLength,
/// array_ndims
ArrayNdims,
/// array_position
ArrayPosition,
/// array_positions
ArrayPositions,
/// array_prepend
ArrayPrepend,
/// array_remove
ArrayRemove,
/// array_replace
ArrayReplace,
/// array_to_string
ArrayToString,
/// cardinality
Cardinality,
/// construct an array from columns
MakeArray,
/// trim_array
TrimArray,

// string functions
/// ascii
Ascii,
/// bit_length
Expand Down Expand Up @@ -313,7 +343,21 @@ lazy_static! {
("arrow_typeof", BuiltinScalarFunction::ArrowTypeof),

// array functions
("array_append", BuiltinScalarFunction::ArrayAppend),
("array_concat", BuiltinScalarFunction::ArrayConcat),
("array_dims", BuiltinScalarFunction::ArrayDims),
("array_fill", BuiltinScalarFunction::ArrayFill),
("array_length", BuiltinScalarFunction::ArrayLength),
("array_ndims", BuiltinScalarFunction::ArrayNdims),
("array_position", BuiltinScalarFunction::ArrayPosition),
("array_positions", BuiltinScalarFunction::ArrayPositions),
("array_prepend", BuiltinScalarFunction::ArrayPrepend),
("array_remove", BuiltinScalarFunction::ArrayRemove),
("array_replace", BuiltinScalarFunction::ArrayReplace),
("array_to_string", BuiltinScalarFunction::ArrayToString),
("cardinality", BuiltinScalarFunction::Cardinality),
("make_array", BuiltinScalarFunction::MakeArray),
("trim_array", BuiltinScalarFunction::TrimArray),
];
}

Expand Down Expand Up @@ -368,7 +412,21 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Tan => Volatility::Immutable,
BuiltinScalarFunction::Tanh => Volatility::Immutable,
BuiltinScalarFunction::Trunc => Volatility::Immutable,
BuiltinScalarFunction::ArrayAppend => Volatility::Immutable,
BuiltinScalarFunction::ArrayConcat => Volatility::Immutable,
BuiltinScalarFunction::ArrayDims => Volatility::Immutable,
BuiltinScalarFunction::ArrayFill => Volatility::Immutable,
BuiltinScalarFunction::ArrayLength => Volatility::Immutable,
BuiltinScalarFunction::ArrayNdims => Volatility::Immutable,
BuiltinScalarFunction::ArrayPosition => Volatility::Immutable,
BuiltinScalarFunction::ArrayPositions => Volatility::Immutable,
BuiltinScalarFunction::ArrayPrepend => Volatility::Immutable,
BuiltinScalarFunction::ArrayRemove => Volatility::Immutable,
BuiltinScalarFunction::ArrayReplace => Volatility::Immutable,
BuiltinScalarFunction::ArrayToString => Volatility::Immutable,
BuiltinScalarFunction::Cardinality => Volatility::Immutable,
BuiltinScalarFunction::MakeArray => Volatility::Immutable,
BuiltinScalarFunction::TrimArray => Volatility::Immutable,
BuiltinScalarFunction::Ascii => Volatility::Immutable,
BuiltinScalarFunction::BitLength => Volatility::Immutable,
BuiltinScalarFunction::Btrim => Volatility::Immutable,
Expand Down
Loading