diff --git a/datafusion/expr-common/src/type_coercion/aggregates.rs b/datafusion/expr-common/src/type_coercion/aggregates.rs
index e77a072a84f38..55a8843394b51 100644
--- a/datafusion/expr-common/src/type_coercion/aggregates.rs
+++ b/datafusion/expr-common/src/type_coercion/aggregates.rs
@@ -16,31 +16,12 @@
 // under the License.
 
 use crate::signature::TypeSignature;
-use arrow::datatypes::{
-    DataType, FieldRef, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
-    DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DECIMAL32_MAX_PRECISION,
-    DECIMAL32_MAX_SCALE, DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE,
-};
+use arrow::datatypes::{DataType, FieldRef};
 use datafusion_common::{internal_err, plan_err, Result};
 
-pub static STRINGS: &[DataType] =
-    &[DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View];
-
-pub static SIGNED_INTEGERS: &[DataType] = &[
-    DataType::Int8,
-    DataType::Int16,
-    DataType::Int32,
-    DataType::Int64,
-];
-
-pub static UNSIGNED_INTEGERS: &[DataType] = &[
-    DataType::UInt8,
-    DataType::UInt16,
-    DataType::UInt32,
-    DataType::UInt64,
-];
-
+// TODO: remove usage of these (INTEGERS and NUMERICS) in favour of signatures
+// see https://github.com/apache/datafusion/issues/18092
 pub static INTEGERS: &[DataType] = &[
     DataType::Int8,
     DataType::Int16,
@@ -65,24 +46,6 @@ pub static NUMERICS: &[DataType] = &[
     DataType::Float64,
 ];
 
-pub static TIMESTAMPS: &[DataType] = &[
-    DataType::Timestamp(TimeUnit::Second, None),
-    DataType::Timestamp(TimeUnit::Millisecond, None),
-    DataType::Timestamp(TimeUnit::Microsecond, None),
-    DataType::Timestamp(TimeUnit::Nanosecond, None),
-];
-
-pub static DATES: &[DataType] = &[DataType::Date32, DataType::Date64];
-
-pub static BINARYS: &[DataType] = &[DataType::Binary, DataType::LargeBinary];
-
-pub static TIMES: &[DataType] = &[
-    DataType::Time32(TimeUnit::Second),
-    DataType::Time32(TimeUnit::Millisecond),
-    DataType::Time64(TimeUnit::Microsecond),
-    DataType::Time64(TimeUnit::Nanosecond),
-];
-
 /// Validate the length of `input_fields` matches the `signature` for `agg_fun`.
 ///
 /// This method DOES NOT validate the argument fields - only that (at least one,
@@ -144,260 +107,3 @@ pub fn check_arg_count(
     }
     Ok(())
 }
-
-/// Function return type of a sum
-pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> {
-    match arg_type {
-        DataType::Int64 => Ok(DataType::Int64),
-        DataType::UInt64 => Ok(DataType::UInt64),
-        DataType::Float64 => Ok(DataType::Float64),
-        DataType::Decimal32(precision, scale) => {
-            // in the spark, the result type is DECIMAL(min(38,precision+10), s)
-            // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
-            let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10);
-            Ok(DataType::Decimal32(new_precision, *scale))
-        }
-        DataType::Decimal64(precision, scale) => {
-            // in the spark, the result type is DECIMAL(min(38,precision+10), s)
-            // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
-            let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10);
-            Ok(DataType::Decimal64(new_precision, *scale))
-        }
-        DataType::Decimal128(precision, scale) => {
-            // In the spark, the result type is DECIMAL(min(38,precision+10), s)
-            // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
-            let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
-            Ok(DataType::Decimal128(new_precision, *scale))
-        }
-        DataType::Decimal256(precision, scale) => {
-            // In the spark, the result type is DECIMAL(min(38,precision+10), s)
-            // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
-            let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
-            Ok(DataType::Decimal256(new_precision, *scale))
-        }
-        other => plan_err!("SUM does not support type \"{other:?}\""),
-    }
-}
-
-/// Function return type of variance
-pub fn variance_return_type(arg_type: &DataType) -> Result<DataType> {
-    if NUMERICS.contains(arg_type) {
-        Ok(DataType::Float64)
-    } else {
-        plan_err!("VAR does not support {arg_type}")
-    }
-}
-
-/// Function return type of covariance
-pub fn covariance_return_type(arg_type: &DataType) -> Result<DataType> {
-    if NUMERICS.contains(arg_type) {
-        Ok(DataType::Float64)
-    } else {
-        plan_err!("COVAR does not support {arg_type}")
-    }
-}
-
-/// Function return type of correlation
-pub fn correlation_return_type(arg_type: &DataType) -> Result<DataType> {
-    if NUMERICS.contains(arg_type) {
-        Ok(DataType::Float64)
-    } else {
-        plan_err!("CORR does not support {arg_type}")
-    }
-}
-
-/// Function return type of an average
-pub fn avg_return_type(func_name: &str, arg_type: &DataType) -> Result<DataType> {
-    match arg_type {
-        DataType::Decimal32(precision, scale) => {
-            // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
-            // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
-            let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 4);
-            let new_scale = DECIMAL32_MAX_SCALE.min(*scale + 4);
-            Ok(DataType::Decimal32(new_precision, new_scale))
-        }
-        DataType::Decimal64(precision, scale) => {
-            // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
-            // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
-            let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 4);
-            let new_scale = DECIMAL64_MAX_SCALE.min(*scale + 4);
-            Ok(DataType::Decimal64(new_precision, new_scale))
-        }
-        DataType::Decimal128(precision, scale) => {
-            // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
-            // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
-            let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 4);
-            let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4);
-            Ok(DataType::Decimal128(new_precision, new_scale))
-        }
-        DataType::Decimal256(precision, scale) => {
-            // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
-            // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
-            let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 4);
-            let new_scale = DECIMAL256_MAX_SCALE.min(*scale + 4);
-            Ok(DataType::Decimal256(new_precision, new_scale))
-        }
-        DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
-        arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64),
-        DataType::Dictionary(_, dict_value_type) => {
-            avg_return_type(func_name, dict_value_type.as_ref())
-        }
-        other => plan_err!("{func_name} does not support {other:?}"),
-    }
-}
-
-/// Internal sum type of an average
-pub fn avg_sum_type(arg_type: &DataType) -> Result<DataType> {
-    match arg_type {
-        DataType::Decimal32(precision, scale) => {
-            // In the spark, the sum type of avg is DECIMAL(min(38,precision+10), s)
-            let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10);
-            Ok(DataType::Decimal32(new_precision, *scale))
-        }
-        DataType::Decimal64(precision, scale) => {
-            // In the spark, the sum type of avg is DECIMAL(min(38,precision+10), s)
-            let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10);
-            Ok(DataType::Decimal64(new_precision, *scale))
-        }
-        DataType::Decimal128(precision, scale) => {
-            // In the spark, the sum type of avg is DECIMAL(min(38,precision+10), s)
-            let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
-            Ok(DataType::Decimal128(new_precision, *scale))
-        }
-        DataType::Decimal256(precision, scale) => {
-            // In Spark the sum type of avg is DECIMAL(min(38,precision+10), s)
-            let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
-            Ok(DataType::Decimal256(new_precision, *scale))
-        }
-        DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
-        arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64),
-        DataType::Dictionary(_, dict_value_type) => {
-            avg_sum_type(dict_value_type.as_ref())
-        }
-        other => plan_err!("AVG does not support {other:?}"),
-    }
-}
-
-pub fn is_sum_support_arg_type(arg_type: &DataType) -> bool {
-    match arg_type {
-        DataType::Dictionary(_, dict_value_type) => {
-            is_sum_support_arg_type(dict_value_type.as_ref())
-        }
-        _ => matches!(
-            arg_type,
-            arg_type if NUMERICS.contains(arg_type)
-                || matches!(arg_type, DataType::Decimal32(_, _) | DataType::Decimal64(_, _) |DataType::Decimal128(_, _) | DataType::Decimal256(_, _))
-        ),
-    }
-}
-
-pub fn is_avg_support_arg_type(arg_type: &DataType) -> bool {
-    match arg_type {
-        DataType::Dictionary(_, dict_value_type) => {
-            is_avg_support_arg_type(dict_value_type.as_ref())
-        }
-        _ => matches!(
-            arg_type,
-            arg_type if NUMERICS.contains(arg_type)
-                || matches!(arg_type, DataType::Decimal32(_, _) | DataType::Decimal64(_, _) |DataType::Decimal128(_, _) | DataType::Decimal256(_, _))
-        ),
-    }
-}
-
-pub fn is_variance_support_arg_type(arg_type: &DataType) -> bool {
-    matches!(
-        arg_type,
-        arg_type if NUMERICS.contains(arg_type)
-    )
-}
-
-pub fn is_covariance_support_arg_type(arg_type: &DataType) -> bool {
-    matches!(
-        arg_type,
-        arg_type if NUMERICS.contains(arg_type)
-    )
-}
-
-pub fn is_correlation_support_arg_type(arg_type: &DataType) -> bool {
-    matches!(
-        arg_type,
-        arg_type if NUMERICS.contains(arg_type)
-    )
-}
-
-pub fn is_integer_arg_type(arg_type: &DataType) -> bool {
-    arg_type.is_integer()
-}
-
-pub fn coerce_avg_type(func_name: &str, arg_types: &[DataType]) -> Result<Vec<DataType>> {
-    // Supported types smallint, int, bigint, real, double precision, decimal, or interval
-    // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
-    fn coerced_type(func_name: &str, data_type: &DataType) -> Result<DataType> {
-        match &data_type {
-            DataType::Decimal32(p, s) => Ok(DataType::Decimal32(*p, *s)),
-            DataType::Decimal64(p, s) => Ok(DataType::Decimal64(*p, *s)),
-            DataType::Decimal128(p, s) => Ok(DataType::Decimal128(*p, *s)),
-            DataType::Decimal256(p, s) => Ok(DataType::Decimal256(*p, *s)),
-            d if d.is_numeric() => Ok(DataType::Float64),
-            DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
-            DataType::Dictionary(_, v) => coerced_type(func_name, v.as_ref()),
-            _ => {
-                plan_err!(
-                    "The function {:?} does not support inputs of type {}.",
-                    func_name,
-                    data_type
-                )
-            }
-        }
-    }
-    Ok(vec![coerced_type(func_name, &arg_types[0])?])
-}
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    #[test]
-    fn test_variance_return_data_type() -> Result<()> {
-        let data_type = DataType::Float64;
-        let result_type = variance_return_type(&data_type)?;
-        assert_eq!(DataType::Float64, result_type);
-
-        let data_type = DataType::Decimal128(36, 10);
-        assert!(variance_return_type(&data_type).is_err());
-        Ok(())
-    }
-
-    #[test]
-    fn test_sum_return_data_type() -> Result<()> {
-        let data_type = DataType::Decimal128(10, 5);
-        let result_type = sum_return_type(&data_type)?;
-        assert_eq!(DataType::Decimal128(20, 5), result_type);
-
-        let data_type = DataType::Decimal128(36, 10);
-        let result_type = sum_return_type(&data_type)?;
-        assert_eq!(DataType::Decimal128(38, 10), result_type);
-        Ok(())
-    }
-
-    #[test]
-    fn test_covariance_return_data_type() -> Result<()> {
-        let data_type = DataType::Float64;
-        let result_type = covariance_return_type(&data_type)?;
-        assert_eq!(DataType::Float64, result_type);
-
-        let data_type = DataType::Decimal128(36, 10);
-        assert!(covariance_return_type(&data_type).is_err());
-        Ok(())
-    }
-
-    #[test]
-    fn test_correlation_return_data_type() -> Result<()> {
-        let data_type = DataType::Float64;
-        let result_type = correlation_return_type(&data_type)?;
-        assert_eq!(DataType::Float64, result_type);
-
-        let data_type = DataType::Decimal128(36, 10);
-        assert!(correlation_return_type(&data_type).is_err());
-        Ok(())
-    }
-}
diff --git a/datafusion/expr/src/test/function_stub.rs b/datafusion/expr/src/test/function_stub.rs
index 41bc645058079..8609afeae6018 100644
--- a/datafusion/expr/src/test/function_stub.rs
+++ b/datafusion/expr/src/test/function_stub.rs
@@ -22,13 +22,15 @@ use std::any::Any;
 
 use arrow::datatypes::{
-    DataType, FieldRef, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
-    DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION,
+    DataType, FieldRef, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
+    DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DECIMAL32_MAX_PRECISION,
+    DECIMAL32_MAX_SCALE, DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE,
 };
+use datafusion_common::plan_err;
 use datafusion_common::{exec_err, not_impl_err, utils::take_function_args, Result};
 
-use crate::type_coercion::aggregates::{avg_return_type, coerce_avg_type, NUMERICS};
+use crate::type_coercion::aggregates::NUMERICS;
 use crate::Volatility::Immutable;
 use crate::{
     expr::AggregateFunction,
@@ -488,8 +490,61 @@ impl AggregateUDFImpl for Avg {
         &self.signature
     }
 
+    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+        let [args] = take_function_args(self.name(), arg_types)?;
+
+        // Supported types: smallint, int, bigint, real, double precision, decimal, or interval
+        // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html
+        fn coerced_type(data_type: &DataType) -> Result<DataType> {
+            match &data_type {
+                DataType::Decimal32(p, s) => Ok(DataType::Decimal32(*p, *s)),
+                DataType::Decimal64(p, s) => Ok(DataType::Decimal64(*p, *s)),
+                DataType::Decimal128(p, s) => Ok(DataType::Decimal128(*p, *s)),
+                DataType::Decimal256(p, s) => Ok(DataType::Decimal256(*p, *s)),
+                d if d.is_numeric() => Ok(DataType::Float64),
+                DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
+                DataType::Dictionary(_, v) => coerced_type(v.as_ref()),
+                _ => {
+                    plan_err!("Avg does not support inputs of type {data_type}.")
+                }
+            }
+        }
+        Ok(vec![coerced_type(args)?])
+    }
+
     fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
-        avg_return_type(self.name(), &arg_types[0])
+        match &arg_types[0] {
+            DataType::Decimal32(precision, scale) => {
+                // In Spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
+                // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
+                let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 4);
+                let new_scale = DECIMAL32_MAX_SCALE.min(*scale + 4);
+                Ok(DataType::Decimal32(new_precision, new_scale))
+            }
+            DataType::Decimal64(precision, scale) => {
+                // In Spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
+                // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
+                let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 4);
+                let new_scale = DECIMAL64_MAX_SCALE.min(*scale + 4);
+                Ok(DataType::Decimal64(new_precision, new_scale))
+            }
+            DataType::Decimal128(precision, scale) => {
+                // In Spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
+                // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
+                let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 4);
+                let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4);
+                Ok(DataType::Decimal128(new_precision, new_scale))
+            }
+            DataType::Decimal256(precision, scale) => {
+                // In Spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
+                // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
+                let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 4);
+                let new_scale = DECIMAL256_MAX_SCALE.min(*scale + 4);
+                Ok(DataType::Decimal256(new_precision, new_scale))
+            }
+            DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
+            _ => Ok(DataType::Float64),
+        }
     }
 
     fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
@@ -503,8 +558,4 @@ impl AggregateUDFImpl for Avg {
     fn aliases(&self) -> &[String] {
         &self.aliases
     }
-
-    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
-        coerce_avg_type(self.name(), arg_types)
-    }
 }
diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs
index d007163e7c08f..11960779ed18c 100644
--- a/datafusion/functions-aggregate/src/average.rs
+++ b/datafusion/functions-aggregate/src/average.rs
@@ -27,14 +27,15 @@ use arrow::datatypes::{
     i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, Decimal32Type,
     Decimal64Type, DecimalType, DurationMicrosecondType, DurationMillisecondType,
     DurationNanosecondType, DurationSecondType, Field, FieldRef, Float64Type, TimeUnit,
-    UInt64Type, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
-    DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION,
+    UInt64Type, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION,
+    DECIMAL256_MAX_SCALE, DECIMAL32_MAX_PRECISION, DECIMAL32_MAX_SCALE,
+    DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE,
 };
+use datafusion_common::plan_err;
 use datafusion_common::{
     exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue,
 };
 use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
-use datafusion_expr::type_coercion::aggregates::{avg_return_type, coerce_avg_type};
 use datafusion_expr::utils::format_state_name;
 use datafusion_expr::Volatility::Immutable;
 use datafusion_expr::{
@@ -125,8 +126,61 @@ impl AggregateUDFImpl for Avg {
         &self.signature
     }
 
+    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+        let [args] = take_function_args(self.name(), arg_types)?;
+
+        // Supported types: smallint, int, bigint, real, double precision, decimal, or interval
+        // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html
+        fn coerced_type(data_type: &DataType) -> Result<DataType> {
+            match &data_type {
+                DataType::Decimal32(p, s) => Ok(DataType::Decimal32(*p, *s)),
+                DataType::Decimal64(p, s) => Ok(DataType::Decimal64(*p, *s)),
+                DataType::Decimal128(p, s) => Ok(DataType::Decimal128(*p, *s)),
+                DataType::Decimal256(p, s) => Ok(DataType::Decimal256(*p, *s)),
+                d if d.is_numeric() => Ok(DataType::Float64),
+                DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
+                DataType::Dictionary(_, v) => coerced_type(v.as_ref()),
+                _ => {
+                    plan_err!("Avg does not support inputs of type {data_type}.")
+                }
+            }
+        }
+        Ok(vec![coerced_type(args)?])
+    }
+
     fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
-        avg_return_type(self.name(), &arg_types[0])
+        match &arg_types[0] {
+            DataType::Decimal32(precision, scale) => {
+                // In Spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
+                // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
+                let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 4);
+                let new_scale = DECIMAL32_MAX_SCALE.min(*scale + 4);
+                Ok(DataType::Decimal32(new_precision, new_scale))
+            }
+            DataType::Decimal64(precision, scale) => {
+                // In Spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
+                // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
+                let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 4);
+                let new_scale = DECIMAL64_MAX_SCALE.min(*scale + 4);
+                Ok(DataType::Decimal64(new_precision, new_scale))
+            }
+            DataType::Decimal128(precision, scale) => {
+                // In Spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
+                // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
+                let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 4);
+                let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4);
+                Ok(DataType::Decimal128(new_precision, new_scale))
+            }
+            DataType::Decimal256(precision, scale) => {
+                // In Spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
+                // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
+                let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 4);
+                let new_scale = DECIMAL256_MAX_SCALE.min(*scale + 4);
+                Ok(DataType::Decimal256(new_precision, new_scale))
+            }
+            DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
+            _ => Ok(DataType::Float64),
+        }
     }
 
     fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
@@ -452,11 +506,6 @@ impl AggregateUDFImpl for Avg {
         ReversedUDAF::Identical
     }
 
-    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
-        let [args] = take_function_args(self.name(), arg_types)?;
-        coerce_avg_type(self.name(), std::slice::from_ref(args))
-    }
-
     fn documentation(&self) -> Option<&Documentation> {
         self.doc()
     }
diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs
index 4f282301ce5bd..31d776950d60a 100644
--- a/datafusion/functions-aggregate/src/lib.rs
+++ b/datafusion/functions-aggregate/src/lib.rs
@@ -207,13 +207,7 @@ mod tests {
     #[test]
     fn test_no_duplicate_name() -> Result<()> {
         let mut names = HashSet::new();
-        let migrated_functions = ["array_agg", "count", "max", "min"];
         for func in all_default_aggregate_functions() {
-            // TODO: remove this
-            // These functions are in intermediate migration state, skip them
-            if migrated_functions.contains(&func.name().to_lowercase().as_str()) {
-                continue;
-            }
             assert!(
                 names.insert(func.name().to_string().to_lowercase()),
                 "duplicate function name: {}",
diff --git a/datafusion/functions-window/src/nth_value.rs b/datafusion/functions-window/src/nth_value.rs
index 329d8aa5ab178..1ba6ad5ce0d49 100644
--- a/datafusion/functions-window/src/nth_value.rs
+++ b/datafusion/functions-window/src/nth_value.rs
@@ -40,39 +40,28 @@ use std::hash::Hash;
 use std::ops::Range;
 use std::sync::{Arc, LazyLock};
 
-get_or_init_udwf!(
+define_udwf_and_expr!(
     First,
     first_value,
-    "returns the first value in the window frame",
+    [arg],
+    "Returns the first value in the window frame",
     NthValue::first
 );
-get_or_init_udwf!(
+define_udwf_and_expr!(
     Last,
     last_value,
-    "returns the last value in the window frame",
+    [arg],
+    "Returns the last value in the window frame",
     NthValue::last
 );
 get_or_init_udwf!(
     NthValue,
     nth_value,
-    "returns the nth value in the window frame",
+    "Returns the nth value in the window frame",
     NthValue::nth
 );
 
-/// Create an expression to represent the `first_value` window function
-///
-pub fn first_value(arg: datafusion_expr::Expr) -> datafusion_expr::Expr {
-    first_value_udwf().call(vec![arg])
-}
-
-/// Create an expression to represent the `last_value` window function
-///
-pub fn last_value(arg: datafusion_expr::Expr) -> datafusion_expr::Expr {
-    last_value_udwf().call(vec![arg])
-}
-
 /// Create an expression to represent the `nth_value` window function
-///
 pub fn nth_value(arg: datafusion_expr::Expr, n: i64) -> datafusion_expr::Expr {
     nth_value_udwf().call(vec![arg, n.lit()])
 }
diff --git a/datafusion/functions-window/src/ntile.rs b/datafusion/functions-window/src/ntile.rs
index d188db3bbf59e..008caaa848aab 100644
--- a/datafusion/functions-window/src/ntile.rs
+++ b/datafusion/functions-window/src/ntile.rs
@@ -25,8 +25,7 @@ use datafusion_common::arrow::array::{ArrayRef, UInt64Array};
 use datafusion_common::arrow::datatypes::{DataType, Field};
 use datafusion_common::{exec_datafusion_err, exec_err, Result};
 use datafusion_expr::{
-    Documentation, Expr, LimitEffect, PartitionEvaluator, Signature, Volatility,
-    WindowUDFImpl,
+    Documentation, LimitEffect, PartitionEvaluator, Signature, Volatility, WindowUDFImpl,
 };
 use datafusion_functions_window_common::field;
 use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
@@ -37,16 +36,13 @@ use std::any::Any;
 use std::fmt::Debug;
 use std::sync::Arc;
 
-get_or_init_udwf!(
+define_udwf_and_expr!(
     Ntile,
     ntile,
-    "integer ranging from 1 to the argument value, dividing the partition as equally as possible"
+    [arg],
+    "Integer ranging from 1 to the argument value, dividing the partition as equally as possible."
 );
 
-pub fn ntile(arg: Expr) -> Expr {
-    ntile_udwf().call(vec![arg])
-}
-
 #[user_doc(
     doc_section(label = "Ranking Functions"),
     description = "Integer ranging from 1 to the argument value, dividing the partition as equally as possible",
diff --git a/datafusion/spark/src/function/aggregate/avg.rs b/datafusion/spark/src/function/aggregate/avg.rs
index a22561ba8b9ca..65736815fec5c 100644
--- a/datafusion/spark/src/function/aggregate/avg.rs
+++ b/datafusion/spark/src/function/aggregate/avg.rs
@@ -25,41 +25,38 @@ use arrow::array::{
 use arrow::compute::sum;
 use arrow::datatypes::{DataType, Field, FieldRef};
 use datafusion_common::utils::take_function_args;
-use datafusion_common::{not_impl_err, Result, ScalarValue};
+use datafusion_common::{not_impl_err, plan_err, Result, ScalarValue};
 use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
-use datafusion_expr::type_coercion::aggregates::coerce_avg_type;
 use datafusion_expr::utils::format_state_name;
 use datafusion_expr::Volatility::Immutable;
 use datafusion_expr::{
-    type_coercion::aggregates::avg_return_type, Accumulator, AggregateUDFImpl, EmitTo,
-    GroupsAccumulator, ReversedUDAF, Signature,
+    Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature,
 };
 use std::{any::Any, sync::Arc};
-use DataType::*;
 
 /// AVG aggregate expression
 /// Spark average aggregate expression. Differs from standard DataFusion average aggregate
 /// in that it uses an `i64` for the count (DataFusion version uses `u64`); also there is ANSI mode
 /// support planned in the future for Spark version.
+// TODO: see if this can be deduplicated with the DataFusion version
+// https://github.com/apache/datafusion/issues/17964
 #[derive(Debug, Clone, PartialEq, Eq, Hash)]
 pub struct SparkAvg {
-    name: String,
     signature: Signature,
-    input_data_type: DataType,
-    result_data_type: DataType,
+}
+
+impl Default for SparkAvg {
+    fn default() -> Self {
+        Self::new()
+    }
 }
 
 impl SparkAvg {
     /// Implement AVG aggregate function
-    pub fn new(name: impl Into<String>, data_type: DataType) -> Self {
-        let result_data_type = avg_return_type("avg", &data_type).unwrap();
-
+    pub fn new() -> Self {
         Self {
-            name: name.into(),
             signature: Signature::user_defined(Immutable),
-            input_data_type: data_type,
-            result_data_type,
        }
    }
 }
@@ -69,63 +66,87 @@ impl AggregateUDFImpl for SparkAvg {
         self
     }
 
-    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
+    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+        let [args] = take_function_args(self.name(), arg_types)?;
+
+        fn coerced_type(data_type: &DataType) -> Result<DataType> {
+            match &data_type {
+                d if d.is_numeric() => Ok(DataType::Float64),
+                DataType::Dictionary(_, v) => coerced_type(v.as_ref()),
+                _ => {
+                    plan_err!("Avg does not support inputs of type {data_type}.")
+                }
+            }
+        }
+        Ok(vec![coerced_type(args)?])
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Float64)
+    }
+
+    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
+        if acc_args.is_distinct {
+            return not_impl_err!("DistinctAvgAccumulator");
+        }
+
+        let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
+
         // instantiate specialized accumulator based for the type
-        match (&self.input_data_type, &self.result_data_type) {
-            (Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
-            _ => not_impl_err!(
-                "AvgAccumulator for ({} --> {})",
-                self.input_data_type,
-                self.result_data_type
-            ),
+        match (&data_type, &acc_args.return_type()) {
+            (DataType::Float64, DataType::Float64) => {
+                Ok(Box::<AvgAccumulator>::default())
+            }
+            (dt, return_type) => {
+                not_impl_err!("AvgAccumulator for ({dt} --> {return_type})")
+            }
         }
     }
 
-    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
+    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
         Ok(vec![
             Arc::new(Field::new(
-                format_state_name(&self.name, "sum"),
-                self.input_data_type.clone(),
+                format_state_name(self.name(), "sum"),
+                args.input_fields[0].data_type().clone(),
                 true,
             )),
             Arc::new(Field::new(
-                format_state_name(&self.name, "count"),
-                Int64,
+                format_state_name(self.name(), "count"),
+                DataType::Int64,
                 true,
             )),
         ])
     }
 
     fn name(&self) -> &str {
-        &self.name
+        "avg"
     }
 
     fn reverse_expr(&self) -> ReversedUDAF {
         ReversedUDAF::Identical
     }
 
-    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
-        true
+    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
+        !args.is_distinct
     }
 
     fn create_groups_accumulator(
         &self,
-        _args: AccumulatorArgs,
+        args: AccumulatorArgs,
     ) -> Result<Box<dyn GroupsAccumulator>> {
+        let data_type = args.exprs[0].data_type(args.schema)?;
+
         // instantiate specialized accumulator based for the type
-        match (&self.input_data_type, &self.result_data_type) {
-            (Float64, Float64) => {
+        match (&data_type, args.return_type()) {
+            (DataType::Float64, DataType::Float64) => {
                 Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(
-                    &self.input_data_type,
+                    args.return_field.data_type(),
                     |sum: f64, count: i64| Ok(sum / count as f64),
                 )))
             }
-
-            _ => not_impl_err!(
-                "AvgGroupsAccumulator for ({} --> {})",
-                self.input_data_type,
-                self.result_data_type
-            ),
+            (dt, return_type) => {
+                not_impl_err!("AvgGroupsAccumulator for ({dt} --> {return_type})")
+            }
         }
     }
 
@@ -136,15 +157,6 @@ impl AggregateUDFImpl for SparkAvg {
     fn signature(&self) -> &Signature {
         &self.signature
     }
-
-    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
-        avg_return_type(self.name(), &arg_types[0])
-    }
-
-    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
-        let [arg] = take_function_args(self.name(), arg_types)?;
-        coerce_avg_type(self.name(), std::slice::from_ref(arg))
-    }
 }
 
 /// An accumulator to compute the average
diff --git a/datafusion/spark/src/function/aggregate/mod.rs b/datafusion/spark/src/function/aggregate/mod.rs
index 54001d28da6b4..d765d9c82f068 100644
--- a/datafusion/spark/src/function/aggregate/mod.rs
+++ b/datafusion/spark/src/function/aggregate/mod.rs
@@ -15,7 +15,6 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use arrow::datatypes::DataType;
 use datafusion_expr::AggregateUDF;
 use std::sync::Arc;
 
@@ -26,11 +25,9 @@ pub mod expr_fn {
     export_functions!((avg, "Returns the average value of a given column", arg1));
 }
 
+// TODO: try using something like datafusion_functions_aggregate::create_func!()
 pub fn avg() -> Arc<AggregateUDF> {
-    Arc::new(AggregateUDF::new_from_impl(avg::SparkAvg::new(
-        "avg",
-        DataType::Float64,
-    )))
+    Arc::new(AggregateUDF::new_from_impl(avg::SparkAvg::new()))
 }
 
 pub fn functions() -> Vec<Arc<AggregateUDF>> {