Skip to content
56 changes: 50 additions & 6 deletions datafusion/expr-common/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
//! and return types of functions in DataFusion.

use std::fmt::Display;
use std::num::NonZeroUsize;

use crate::type_coercion::aggregates::NUMERICS;
use arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
Expand Down Expand Up @@ -236,9 +237,9 @@ pub enum ArrayFunctionSignature {
/// The first argument should be non-list or list, and the second argument should be List/LargeList.
/// The first argument's list dimension should be one dimension less than the second argument's list dimension.
ElementAndArray,
/// Specialized Signature for Array functions of the form (List/LargeList, Index)
/// The first argument should be List/LargeList/FixedSizedList, and the second argument should be Int64.
ArrayAndIndex,
/// Specialized Signature for Array functions of the form (List/LargeList, Index+)
/// The first argument should be List/LargeList/FixedSizeList, and the next n arguments should be Int64.
ArrayAndIndexes(NonZeroUsize),
/// Specialized Signature for Array functions of the form (List/LargeList, Element, Optional Index)
ArrayAndElementAndOptionalIndex,
/// Specialized Signature for ArrayEmpty and similar functions
Expand All @@ -265,8 +266,12 @@ impl Display for ArrayFunctionSignature {
ArrayFunctionSignature::ElementAndArray => {
write!(f, "element, array")
}
ArrayFunctionSignature::ArrayAndIndex => {
write!(f, "array, index")
ArrayFunctionSignature::ArrayAndIndexes(count) => {
write!(f, "array")?;
for _ in 0..count.get() {
write!(f, ", index")?;
}
Ok(())
}
ArrayFunctionSignature::Array => {
write!(f, "array")
Expand Down Expand Up @@ -455,6 +460,15 @@ fn get_data_types(native_type: &NativeType) -> Vec<DataType> {
}
}

/// A function's behavior when the input is Null.
///
/// Carried by the `null_handling` field of [`Signature`]. The [`Signature`]
/// constructors start from [`NullHandling::PassThrough`]; use
/// [`Signature::with_null_handling`] to select a different mode.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)]
pub enum NullHandling {
/// Null inputs are passed into the function implementation.
///
/// The implementation itself decides what a Null argument means.
PassThrough,
/// Any Null input causes the function to return Null.
///
/// NOTE(review): the scalar-function evaluation path appears to apply this
/// shortcut only when an argument is a Null *scalar*; Null entries inside
/// array arguments still reach the function implementation, which must
/// handle them itself. Confirm against `ScalarFunctionExpr::evaluate`.
Propagate,
}

/// Defines the supported argument types ([`TypeSignature`]) and [`Volatility`] for a function.
///
/// DataFusion will automatically coerce (cast) argument types to one of the supported
Expand All @@ -465,6 +479,8 @@ pub struct Signature {
pub type_signature: TypeSignature,
/// The volatility of the function. See [Volatility] for more information.
pub volatility: Volatility,
/// The Null handling of the function. See [NullHandling] for more information.
pub null_handling: NullHandling,
}

impl Signature {
Expand All @@ -473,20 +489,23 @@ impl Signature {
Signature {
type_signature,
volatility,
null_handling: NullHandling::PassThrough,
}
}
/// An arbitrary number of arguments with the same type, from those listed in `common_types`.
pub fn variadic(common_types: Vec<DataType>, volatility: Volatility) -> Self {
Self {
type_signature: TypeSignature::Variadic(common_types),
volatility,
null_handling: NullHandling::PassThrough,
}
}
/// User-defined coercion rules for the function.
pub fn user_defined(volatility: Volatility) -> Self {
Self {
type_signature: TypeSignature::UserDefined,
volatility,
null_handling: NullHandling::PassThrough,
}
}

Expand All @@ -495,6 +514,7 @@ impl Signature {
Self {
type_signature: TypeSignature::Numeric(arg_count),
volatility,
null_handling: NullHandling::PassThrough,
}
}

Expand All @@ -503,6 +523,7 @@ impl Signature {
Self {
type_signature: TypeSignature::String(arg_count),
volatility,
null_handling: NullHandling::PassThrough,
}
}

Expand All @@ -511,6 +532,7 @@ impl Signature {
Self {
type_signature: TypeSignature::VariadicAny,
volatility,
null_handling: NullHandling::PassThrough,
}
}
/// A fixed number of arguments of the same type, from those listed in `valid_types`.
Expand All @@ -522,13 +544,15 @@ impl Signature {
Self {
type_signature: TypeSignature::Uniform(arg_count, valid_types),
volatility,
null_handling: NullHandling::PassThrough,
}
}
/// Exactly matches the types in `exact_types`, in order.
pub fn exact(exact_types: Vec<DataType>, volatility: Volatility) -> Self {
Signature {
type_signature: TypeSignature::Exact(exact_types),
volatility,
null_handling: NullHandling::PassThrough,
}
}
/// Target coerce types in order
Expand All @@ -539,6 +563,7 @@ impl Signature {
Self {
type_signature: TypeSignature::Coercible(target_types),
volatility,
null_handling: NullHandling::PassThrough,
}
}

Expand All @@ -547,13 +572,15 @@ impl Signature {
Self {
type_signature: TypeSignature::Comparable(arg_count),
volatility,
null_handling: NullHandling::PassThrough,
}
}

pub fn nullary(volatility: Volatility) -> Self {
Signature {
type_signature: TypeSignature::Nullary,
volatility,
null_handling: NullHandling::PassThrough,
}
}

Expand All @@ -562,13 +589,15 @@ impl Signature {
Signature {
type_signature: TypeSignature::Any(arg_count),
volatility,
null_handling: NullHandling::PassThrough,
}
}
/// Any one of a list of [TypeSignature]s.
pub fn one_of(type_signatures: Vec<TypeSignature>, volatility: Volatility) -> Self {
Signature {
type_signature: TypeSignature::OneOf(type_signatures),
volatility,
null_handling: NullHandling::PassThrough,
}
}
/// Specialized Signature for ArrayAppend and similar functions
Expand All @@ -578,6 +607,7 @@ impl Signature {
ArrayFunctionSignature::ArrayAndElement,
),
volatility,
null_handling: NullHandling::PassThrough,
}
}
/// Specialized Signature for Array functions with an optional index
Expand All @@ -587,6 +617,7 @@ impl Signature {
ArrayFunctionSignature::ArrayAndElementAndOptionalIndex,
),
volatility,
null_handling: NullHandling::PassThrough,
}
}
/// Specialized Signature for ArrayPrepend and similar functions
Expand All @@ -596,24 +627,37 @@ impl Signature {
ArrayFunctionSignature::ElementAndArray,
),
volatility,
null_handling: NullHandling::PassThrough,
}
}
/// Specialized Signature for ArrayElement and similar functions.
///
/// Shorthand for [`Signature::array_and_indexes`] with exactly one index argument.
pub fn array_and_index(volatility: Volatility) -> Self {
    // `NonZeroUsize::MIN` is the constant 1, so no fallible construction is needed.
    let single_index = NonZeroUsize::MIN;
    Self::array_and_indexes(volatility, single_index)
}
/// Specialized Signature for ArraySlice and similar functions
pub fn array_and_indexes(volatility: Volatility, count: NonZeroUsize) -> Self {
Signature {
type_signature: TypeSignature::ArraySignature(
ArrayFunctionSignature::ArrayAndIndex,
ArrayFunctionSignature::ArrayAndIndexes(count),
),
volatility,
null_handling: NullHandling::PassThrough,
}
}
/// Specialized Signature for ArrayEmpty and similar functions
pub fn array(volatility: Volatility) -> Self {
Signature {
type_signature: TypeSignature::ArraySignature(ArrayFunctionSignature::Array),
volatility,
null_handling: NullHandling::PassThrough,
}
}

/// Returns an equivalent Signature, with null_handling set to the input.
pub fn with_null_handling(mut self, null_handling: NullHandling) -> Self {
self.null_handling = null_handling;
self
}
}

#[cfg(test)]
Expand Down
4 changes: 2 additions & 2 deletions datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ pub use datafusion_expr_common::columnar_value::ColumnarValue;
pub use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator};
pub use datafusion_expr_common::operator::Operator;
pub use datafusion_expr_common::signature::{
ArrayFunctionSignature, Signature, TypeSignature, TypeSignatureClass, Volatility,
TIMEZONE_WILDCARD,
ArrayFunctionSignature, NullHandling, Signature, TypeSignature, TypeSignatureClass,
Volatility, TIMEZONE_WILDCARD,
};
pub use datafusion_expr_common::type_coercion::binary;
pub use expr::{
Expand Down
13 changes: 10 additions & 3 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -671,13 +671,20 @@ fn get_valid_types(
ArrayFunctionSignature::ElementAndArray => {
array_append_or_prepend_valid_types(current_types, false)?
}
ArrayFunctionSignature::ArrayAndIndex => {
if current_types.len() != 2 {
ArrayFunctionSignature::ArrayAndIndexes(count) => {
if current_types.len() != count.get() + 1 {
return Ok(vec![vec![]]);
}
array(&current_types[0]).map_or_else(
|| vec![vec![]],
|array_type| vec![vec![array_type, DataType::Int64]],
|array_type| {
let mut inner = Vec::with_capacity(count.get() + 1);
inner.push(array_type);
for _ in 0..count.get() {
inner.push(DataType::Int64);
}
vec![inner]
},
)
}
ArrayFunctionSignature::ArrayAndElementAndOptionalIndex => {
Expand Down
29 changes: 24 additions & 5 deletions datafusion/functions-nested/src/extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,13 @@ use datafusion_common::cast::as_list_array;
use datafusion_common::{
exec_err, internal_datafusion_err, plan_err, DataFusionError, Result,
};
use datafusion_expr::Expr;
use datafusion_expr::{ArrayFunctionSignature, Expr, TypeSignature};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
ColumnarValue, Documentation, NullHandling, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_macros::user_doc;
use std::any::Any;
use std::num::NonZeroUsize;
use std::sync::Arc;

use crate::utils::make_scalar_function;
Expand Down Expand Up @@ -330,7 +331,27 @@ pub(super) struct ArraySlice {
impl ArraySlice {
pub fn new() -> Self {
Self {
signature: Signature::variadic_any(Volatility::Immutable),
signature: Signature::one_of(
vec![
TypeSignature::ArraySignature(
ArrayFunctionSignature::ArrayAndIndexes(
NonZeroUsize::new(1).expect("1 is non-zero"),
),
),
TypeSignature::ArraySignature(
ArrayFunctionSignature::ArrayAndIndexes(
NonZeroUsize::new(2).expect("2 is non-zero"),
),
),
TypeSignature::ArraySignature(
ArrayFunctionSignature::ArrayAndIndexes(
NonZeroUsize::new(3).expect("3 is non-zero"),
),
),
],
Volatility::Immutable,
)
.with_null_handling(NullHandling::Propagate),
aliases: vec![String::from("list_slice")],
}
}
Expand Down Expand Up @@ -430,8 +451,6 @@ fn array_slice_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
}
LargeList(_) => {
let array = as_large_list_array(&args[0])?;
let from_array = as_int64_array(&args[1])?;
let to_array = as_int64_array(&args[2])?;
general_array_slice::<i64>(array, from_array, to_array, stride)
}
_ => exec_err!("array_slice does not support type: {:?}", array_data_type),
Expand Down
5 changes: 3 additions & 2 deletions datafusion/functions-nested/src/flatten.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ use datafusion_common::cast::{
};
use datafusion_common::{exec_err, Result};
use datafusion_expr::{
ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature,
TypeSignature, Volatility,
ArrayFunctionSignature, ColumnarValue, Documentation, NullHandling, ScalarUDFImpl,
Signature, TypeSignature, Volatility,
};
use datafusion_macros::user_doc;
use std::any::Any;
Expand Down Expand Up @@ -80,6 +80,7 @@ impl Flatten {
ArrayFunctionSignature::RecursiveArray,
),
volatility: Volatility::Immutable,
null_handling: NullHandling::PassThrough,
},
aliases: vec![],
}
Expand Down
10 changes: 10 additions & 0 deletions datafusion/physical-expr/src/scalar_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf;
use datafusion_expr::{
expr_vec_fmt, ColumnarValue, Expr, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDF,
};
use datafusion_expr_common::signature::NullHandling;

/// Physical expression of a scalar function
#[derive(Eq, PartialEq, Hash)]
Expand Down Expand Up @@ -186,6 +187,15 @@ impl PhysicalExpr for ScalarFunctionExpr {
.map(|e| e.evaluate(batch))
.collect::<Result<Vec<_>>>()?;

if self.fun.signature().null_handling == NullHandling::Propagate
&& args.iter().any(
|arg| matches!(arg, ColumnarValue::Scalar(scalar) if scalar.is_null()),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not super confident about this check, how should ColumnarValue::Arrays be treated?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I think I understand this now. If the function is called with a single set of arguments then each arg will be a ColumnarValue::Scalar. However, if the function is called on a batch of arguments, then each arg will be a ColumnarValue::Array of all the arguments. So this does not work in the batch case.

What we'd really like is to identify all indexes, i, s.t. one of the args at index i is Null. Then somehow skip all rows at the identified indexes and immediately return Null for those. That seems a little tricky because it looks like we pass the entire ArrayRef to the function implementation.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I think I understand this now. If the function is called with a single set of arguments then each arg will be a ColumnarValue::Scalar. However, if the function is called on a batch of arguments, then each arg will be a ColumnarValue::Array of all the arguments. So this does not work in the batch case.

What we'd really like is to identify all indexes, i, s.t. one of the args at index i is Null. Then somehow skip all rows at the identified indexes and immediately return Null for those. That seems a little tricky because it looks like we pass the entire ArrayRef to the function implementation.

I don't think we need to peek at the nulls in the column case; those should be handled by logic specific to each function. For the scalar case, since most scalar functions return null if any one of the args is null, it is beneficial to introduce another null-handling method. It is just a convenience method that is nice to have, not a must-have solution for null handling, since nulls can be handled in 'invoke' as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is just a convenience method that is nice to have, not a must-have solution for null handling, since nulls can be handled in 'invoke' as well.

If someone forgets to handle nulls in invoke, then don't we run the risk of accidentally returning different results depending on if the function was called with scalar arguments or with a batch of arguments?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, I'm not sure I understand the point of the check here, if the invoke implementation also has to handle nulls, but maybe I'm misunderstanding what you're saying.

Copy link
Contributor

@jayzhan211 jayzhan211 Feb 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scalar_function(null, null, ...) and scalar_function(column_contains_null, ...). We can only handle nulls here, but not column_contains_null, because we don't know the data in the column.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm sorry, I don't think I understand your response. So I should leave this code block here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes.

)
{
let null_value = ScalarValue::try_from(&self.return_type)?;
return Ok(ColumnarValue::Scalar(null_value));
}

let input_empty = args.is_empty();
let input_all_scalar = args
.iter()
Expand Down
Loading