diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 56f3029a4d7a..1bfae28af840 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -19,6 +19,7 @@ //! and return types of functions in DataFusion. use std::fmt::Display; +use std::num::NonZeroUsize; use crate::type_coercion::aggregates::NUMERICS; use arrow::datatypes::{DataType, IntervalUnit, TimeUnit}; @@ -236,9 +237,9 @@ pub enum ArrayFunctionSignature { /// The first argument should be non-list or list, and the second argument should be List/LargeList. /// The first argument's list dimension should be one dimension less than the second argument's list dimension. ElementAndArray, - /// Specialized Signature for Array functions of the form (List/LargeList, Index) - /// The first argument should be List/LargeList/FixedSizedList, and the second argument should be Int64. - ArrayAndIndex, + /// Specialized Signature for Array functions of the form (List/LargeList, Index+) + /// The first argument should be List/LargeList/FixedSizedList, and the next n arguments should be Int64. 
+ ArrayAndIndexes(NonZeroUsize), /// Specialized Signature for Array functions of the form (List/LargeList, Element, Optional Index) ArrayAndElementAndOptionalIndex, /// Specialized Signature for ArrayEmpty and similar functions @@ -265,8 +266,12 @@ impl Display for ArrayFunctionSignature { ArrayFunctionSignature::ElementAndArray => { write!(f, "element, array") } - ArrayFunctionSignature::ArrayAndIndex => { - write!(f, "array, index") + ArrayFunctionSignature::ArrayAndIndexes(count) => { + write!(f, "array")?; + for _ in 0..count.get() { + write!(f, ", index")?; + } + Ok(()) } ArrayFunctionSignature::Array => { write!(f, "array") @@ -600,9 +605,13 @@ impl Signature { } /// Specialized Signature for ArrayElement and similar functions pub fn array_and_index(volatility: Volatility) -> Self { + Self::array_and_indexes(volatility, NonZeroUsize::new(1).expect("1 is non-zero")) + } + /// Specialized Signature for ArraySlice and similar functions + pub fn array_and_indexes(volatility: Volatility, count: NonZeroUsize) -> Self { Signature { type_signature: TypeSignature::ArraySignature( - ArrayFunctionSignature::ArrayAndIndex, + ArrayFunctionSignature::ArrayAndIndexes(count), ), volatility, } diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 017415da8f23..b9a309faed59 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -94,8 +94,8 @@ pub use udaf::{ aggregate_doc_sections, AggregateUDF, AggregateUDFImpl, ReversedUDAF, StatisticsArgs, }; pub use udf::{ - scalar_doc_sections, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDF, - ScalarUDFImpl, + scalar_doc_sections, NullHandling, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, + ScalarUDF, ScalarUDFImpl, }; pub use udwf::{window_doc_sections, ReversedUDWF, WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 
650619e6de4c..8cebf7c3db12 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -671,13 +671,20 @@ fn get_valid_types( ArrayFunctionSignature::ElementAndArray => { array_append_or_prepend_valid_types(current_types, false)? } - ArrayFunctionSignature::ArrayAndIndex => { - if current_types.len() != 2 { + ArrayFunctionSignature::ArrayAndIndexes(count) => { + if current_types.len() != count.get() + 1 { return Ok(vec![vec![]]); } array(&current_types[0]).map_or_else( || vec![vec![]], - |array_type| vec![vec![array_type, DataType::Int64]], + |array_type| { + let mut inner = Vec::with_capacity(count.get() + 1); + inner.push(array_type); + for _ in 0..count.get() { + inner.push(DataType::Int64); + } + vec![inner] + }, ) } ArrayFunctionSignature::ArrayAndElementAndOptionalIndex => { diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 56c9822495f8..44c5f6cfbe82 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -389,7 +389,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// Whether the aggregate function is nullable. /// - /// Nullable means that that the function could return `null` for any inputs. + /// Nullable means that the function could return `null` for any inputs. /// For example, aggregate functions like `COUNT` always return a non null value /// but others like `MIN` will return `NULL` if there is nullable input. /// Note that if the function is declared as *not* nullable, make sure the [`AggregateUDFImpl::default_value`] is `non-null` diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index aa6a5cddad95..7c91b6b3b4ab 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -200,6 +200,11 @@ impl ScalarUDF { self.inner.return_type_from_args(args) } + /// Returns the behavior that this function has when any of the inputs are Null. 
+ pub fn null_handling(&self) -> NullHandling { + self.inner.null_handling() + } + /// Do the function rewrite /// /// See [`ScalarUDFImpl::simplify`] for more details. @@ -417,6 +422,15 @@ impl ReturnInfo { } } +/// A function's behavior when the input is Null. +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] +pub enum NullHandling { + /// Null inputs are passed into the function implementation. + PassThrough, + /// Any Null input causes the function to return Null. + Propagate, +} + /// Trait for implementing user defined scalar functions. /// /// This trait exposes the full API for implementing user defined functions and @@ -589,6 +603,11 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { true } + /// Returns the behavior that this function has when any of the inputs are Null. + fn null_handling(&self) -> NullHandling { + NullHandling::PassThrough + } + /// Invoke the function on `args`, returning the appropriate result /// /// Note: This method is deprecated and will be removed in future releases. 
diff --git a/datafusion/functions-nested/src/extract.rs b/datafusion/functions-nested/src/extract.rs index cce10d2bf6db..c87a96dca7a4 100644 --- a/datafusion/functions-nested/src/extract.rs +++ b/datafusion/functions-nested/src/extract.rs @@ -27,6 +27,7 @@ use arrow::array::MutableArrayData; use arrow::array::OffsetSizeTrait; use arrow::buffer::OffsetBuffer; use arrow::datatypes::DataType; +use arrow_buffer::NullBufferBuilder; use arrow_schema::DataType::{FixedSizeList, LargeList, List}; use arrow_schema::Field; use datafusion_common::cast::as_int64_array; @@ -35,12 +36,13 @@ use datafusion_common::cast::as_list_array; use datafusion_common::{ exec_err, internal_datafusion_err, plan_err, DataFusionError, Result, }; -use datafusion_expr::Expr; +use datafusion_expr::{ArrayFunctionSignature, Expr, TypeSignature}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, NullHandling, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; use std::any::Any; +use std::num::NonZeroUsize; use std::sync::Arc; use crate::utils::make_scalar_function; @@ -330,7 +332,26 @@ pub(super) struct ArraySlice { impl ArraySlice { pub fn new() -> Self { Self { - signature: Signature::variadic_any(Volatility::Immutable), + signature: Signature::one_of( + vec![ + TypeSignature::ArraySignature( + ArrayFunctionSignature::ArrayAndIndexes( + NonZeroUsize::new(1).expect("1 is non-zero"), + ), + ), + TypeSignature::ArraySignature( + ArrayFunctionSignature::ArrayAndIndexes( + NonZeroUsize::new(2).expect("2 is non-zero"), + ), + ), + TypeSignature::ArraySignature( + ArrayFunctionSignature::ArrayAndIndexes( + NonZeroUsize::new(3).expect("3 is non-zero"), + ), + ), + ], + Volatility::Immutable, + ), aliases: vec![String::from("list_slice")], } } @@ -374,6 +395,10 @@ impl ScalarUDFImpl for ArraySlice { Ok(arg_types[0].clone()) } + fn null_handling(&self) -> NullHandling { + NullHandling::Propagate + } + fn 
invoke_batch( &self, args: &[ColumnarValue], @@ -430,8 +455,6 @@ fn array_slice_inner(args: &[ArrayRef]) -> Result<ArrayRef> { } LargeList(_) => { let array = as_large_list_array(&args[0])?; - let from_array = as_int64_array(&args[1])?; - let to_array = as_int64_array(&args[2])?; general_array_slice::<i64>(array, from_array, to_array, stride) } _ => exec_err!("array_slice does not support type: {:?}", array_data_type), @@ -451,9 +474,8 @@ where let original_data = values.to_data(); let capacity = Capacities::Array(original_data.len()); - // use_nulls: false, we don't need nulls but empty array for array_slice, so we don't need explicit nulls but adjust offset to indicate nulls. let mut mutable = - MutableArrayData::with_capacities(vec![&original_data], false, capacity); + MutableArrayData::with_capacities(vec![&original_data], true, capacity); // We have the slice syntax compatible with DuckDB v0.8.1. // The rule `adjusted_from_index` and `adjusted_to_index` follows the rule of array_slice in duckdb. @@ -516,30 +538,33 @@ where } let mut offsets = vec![O::usize_as(0)]; + let mut null_builder = NullBufferBuilder::new(array.len()); for (row_index, offset_window) in array.offsets().windows(2).enumerate() { let start = offset_window[0]; let end = offset_window[1]; let len = end - start; - // len 0 indicate array is null, return empty array in this row. + // If any input is null, return null. + if array.is_null(row_index) + || from_array.is_null(row_index) + || to_array.is_null(row_index) + { + mutable.extend_nulls(1); + offsets.push(offsets[row_index] + O::usize_as(1)); + null_builder.append_null(); + continue; + } + null_builder.append_non_null(); + + // Empty arrays always return an empty array. if len == O::usize_as(0) { offsets.push(offsets[row_index]); continue; } - // If index is null, we consider it as the minimum / maximum index of the array. 
- let from_index = if from_array.is_null(row_index) { - Some(O::usize_as(0)) - } else { - adjusted_from_index::<O>(from_array.value(row_index), len)? - }; - - let to_index = if to_array.is_null(row_index) { - Some(len - O::usize_as(1)) - } else { - adjusted_to_index::<O>(to_array.value(row_index), len)? - }; + let from_index = adjusted_from_index::<O>(from_array.value(row_index), len)?; + let to_index = adjusted_to_index::<O>(to_array.value(row_index), len)?; if let (Some(from), Some(to)) = (from_index, to_index) { let stride = stride.map(|s| s.value(row_index)); @@ -613,7 +638,7 @@ where Arc::new(Field::new_list_field(array.value_type(), true)), OffsetBuffer::<O>::new(offsets.into()), arrow_array::make_array(data), - None, + null_builder.finish(), )?)) } @@ -665,6 +690,10 @@ impl ScalarUDFImpl for ArrayPopFront { Ok(arg_types[0].clone()) } + fn null_handling(&self) -> NullHandling { + NullHandling::Propagate + } + fn invoke_batch( &self, args: &[ColumnarValue], @@ -765,6 +794,10 @@ impl ScalarUDFImpl for ArrayPopBack { Ok(arg_types[0].clone()) } + fn null_handling(&self) -> NullHandling { + NullHandling::Propagate + } + fn invoke_batch( &self, args: &[ColumnarValue], diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 936adbc098d6..1cd4b673ce7f 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -45,7 +45,8 @@ use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf; use datafusion_expr::{ - expr_vec_fmt, ColumnarValue, Expr, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDF, + expr_vec_fmt, ColumnarValue, Expr, NullHandling, ReturnTypeArgs, ScalarFunctionArgs, + ScalarUDF, }; /// Physical expression of a scalar function @@ -186,6 +187,15 @@ impl PhysicalExpr for ScalarFunctionExpr { .map(|e| e.evaluate(batch)) .collect::<Result<Vec<_>>>()?; + if 
self.fun.null_handling() == NullHandling::Propagate + && args.iter().any( + |arg| matches!(arg, ColumnarValue::Scalar(scalar) if scalar.is_null()), + ) + { + let null_value = ScalarValue::try_from(&self.return_type)?; + return Ok(ColumnarValue::Scalar(null_value)); + } + let input_empty = args.is_empty(); let input_all_scalar = args .iter() diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index baf4ef7795e7..f4b409b2cae6 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -950,9 +950,9 @@ select column1[2:4], column2[1:4], column3[3:4] from arrays; [[5, 6]] [NULL, 5.5, 6.6] [NULL, u] [[7, 8]] [7.7, 8.8, 9.9] [l, o] [[9, 10]] [10.1, NULL, 12.2] [t] -[] [13.3, 14.4, 15.5] [e, t] -[[13, 14]] [] [] -[[NULL, 18]] [16.6, 17.7, 18.8] [] +NULL [13.3, 14.4, 15.5] [e, t] +[[13, 14]] NULL [] +[[NULL, 18]] [16.6, 17.7, 18.8] NULL # multiple index with columns #2 (zero index) query ??? @@ -962,9 +962,9 @@ select column1[0:5], column2[0:3], column3[0:9] from arrays; [[3, 4], [5, 6]] [NULL, 5.5, 6.6] [i, p, NULL, u, m] [[5, 6], [7, 8]] [7.7, 8.8, 9.9] [d, NULL, l, o, r] [[7, NULL], [9, 10]] [10.1, NULL, 12.2] [s, i, t] -[] [13.3, 14.4, 15.5] [a, m, e, t] -[[11, 12], [13, 14]] [] [,] -[[15, 16], [NULL, 18]] [16.6, 17.7, 18.8] [] +NULL [13.3, 14.4, 15.5] [a, m, e, t] +[[11, 12], [13, 14]] NULL [,] +[[15, 16], [NULL, 18]] [16.6, 17.7, 18.8] NULL # TODO: support negative index # multiple index with columns #3 (negative index) @@ -1026,9 +1026,9 @@ select column1[2:4:2], column2[1:4:2], column3[3:4:2] from arrays; [[5, 6]] [NULL, 6.6] [NULL] [[7, 8]] [7.7, 9.9] [l] [[9, 10]] [10.1, 12.2] [t] -[] [13.3, 15.5] [e] -[[13, 14]] [] [] -[[NULL, 18]] [16.6, 18.8] [] +NULL [13.3, 15.5] [e] +[[13, 14]] NULL [] +[[NULL, 18]] [16.6, 18.8] NULL # multiple index with columns #2 (zero index) query ??? 
@@ -1038,9 +1038,9 @@ select column1[0:5:2], column2[0:3:2], column3[0:9:2] from arrays; [[3, 4]] [NULL, 6.6] [i, NULL, m] [[5, 6]] [7.7, 9.9] [d, l, r] [[7, NULL]] [10.1, 12.2] [s, t] -[] [13.3, 15.5] [a, e] -[[11, 12]] [] [,] -[[15, 16]] [16.6, 18.8] [] +NULL [13.3, 15.5] [a, e] +[[11, 12]] NULL [,] +[[15, 16]] [16.6, 18.8] NULL ### Array function tests @@ -1579,7 +1579,7 @@ select array_pop_back(column1) from arrayspop; [3, 4, 5] [6, 7, 8, NULL] [NULL, NULL] -[] +NULL [NULL, 10, 11] query ? @@ -1589,7 +1589,7 @@ select array_pop_back(arrow_cast(column1, 'LargeList(Int64)')) from arrayspop; [3, 4, 5] [6, 7, 8, NULL] [NULL, NULL] -[] +NULL [NULL, 10, 11] query ? @@ -1599,7 +1599,7 @@ select array_pop_back(column1) from large_arrayspop; [3, 4, 5] [6, 7, 8, NULL] [NULL, NULL] -[] +NULL [NULL, 10, 11] query ? @@ -1609,7 +1609,7 @@ select array_pop_back(arrow_cast(column1, 'LargeList(Int64)')) from large_arrays [3, 4, 5] [6, 7, 8, NULL] [NULL, NULL] -[] +NULL [NULL, 10, 11] ## array_pop_front (aliases: `list_pop_front`) @@ -1817,18 +1817,26 @@ select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0, [1, 2, 3, 4] [h, e, l] # array_slice scalar function #8 (with NULL and positive number) -query error +query ?? select array_slice(make_array(1, 2, 3, 4, 5), NULL, 4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL, 3); +---- +NULL NULL -query error +query ?? select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL, 4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), NULL, 3); +---- +NULL NULL # array_slice scalar function #9 (with positive number and NULL) -query error +query ?? select array_slice(make_array(1, 2, 3, 4, 5), 2, NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, NULL); +---- +NULL NULL -query error +query ?? 
select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, NULL), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3, NULL); +---- +NULL NULL # array_slice scalar function #10 (with zero-zero) query ?? @@ -1842,12 +1850,15 @@ select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0, [] [] # array_slice scalar function #11 (with NULL-NULL) -query error +query ?? select array_slice(make_array(1, 2, 3, 4, 5), NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL); +---- +NULL NULL -query error +query ?? select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), NULL); - +---- +NULL NULL # array_slice scalar function #12 (with zero and negative number) query ?? @@ -1861,18 +1872,26 @@ select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0, [1, 2] [h, e, l] # array_slice scalar function #13 (with negative number and NULL) -query error +query ?? select array_slice(make_array(1, 2, 3, 4, 5), -2, NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, NULL); +---- +NULL NULL -query error +query ?? select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -2, NULL), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3, NULL); +---- +NULL NULL # array_slice scalar function #14 (with NULL and negative number) -query error +query ?? select array_slice(make_array(1, 2, 3, 4, 5), NULL, -4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL, -3); +---- +NULL NULL -query error +query ?? select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL, -4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), NULL, -3); +---- +NULL NULL # array_slice scalar function #15 (with negative indexes) query ?? 
@@ -1982,9 +2001,9 @@ select array_slice(column1, column2, column3) from slices; [12, 13, 14, 15, 16, 17] [] [] -[] -[41, 42, 43, 44, 45, 46] -[55, 56, 57, 58, 59, 60] +NULL +NULL +NULL query ? select array_slice(arrow_cast(column1, 'LargeList(Int64)'), column2, column3) from slices; @@ -1993,9 +2012,9 @@ select array_slice(arrow_cast(column1, 'LargeList(Int64)'), column2, column3) fr [12, 13, 14, 15, 16, 17] [] [] -[] -[41, 42, 43, 44, 45, 46] -[55, 56, 57, 58, 59, 60] +NULL +NULL +NULL # TODO: support NULLS in output instead of `[]` # array_slice with columns and scalars @@ -2006,9 +2025,9 @@ select array_slice(make_array(1, 2, 3, 4, 5), column2, column3), array_slice(col [2] [13, 14, 15, 16, 17] [12, 13, 14, 15] [] [] [21, 22, 23, NULL, 25] [] [33, 34] [] -[4, 5] [] [] -[1, 2, 3, 4, 5] [43, 44, 45, 46] [41, 42, 43, 44, 45] -[5] [NULL, 54, 55, 56, 57, 58, 59, 60] [55] +[4, 5] NULL NULL +NULL [43, 44, 45, 46] NULL +NULL NULL [55] query ??? select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), column2, column3), array_slice(arrow_cast(column1, 'LargeList(Int64)'), 3, column3), array_slice(arrow_cast(column1, 'LargeList(Int64)'), column2, 5) from slices; @@ -2017,9 +2036,9 @@ select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), co [2] [13, 14, 15, 16, 17] [12, 13, 14, 15] [] [] [21, 22, 23, NULL, 25] [] [33, 34] [] -[4, 5] [] [] -[1, 2, 3, 4, 5] [43, 44, 45, 46] [41, 42, 43, 44, 45] -[5] [NULL, 54, 55, 56, 57, 58, 59, 60] [55] +[4, 5] NULL NULL +NULL [43, 44, 45, 46] NULL +NULL NULL [55] # Test issue: https://github.com/apache/datafusion/issues/10425 # `from` may be larger than `to` and `stride` is positive @@ -2036,6 +2055,9 @@ select array_slice(a, -1, 2, 1), array_slice(a, -1, 2), query error DataFusion error: Error during planning: array_slice does not support zero arguments select array_slice(); +query error Failed to coerce arguments +select array_slice(3.5, NULL, NULL); + ## array_any_value (aliases: 
list_any_value) # Testing with empty arguments should result in an error