From d162015fe2f8bf8f538ddc1af08f91f043cf329a Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Fri, 24 May 2024 09:29:51 +0800 Subject: [PATCH 1/7] introduce median udaf Signed-off-by: jayzhan211 --- datafusion-cli/Cargo.lock | 1 + datafusion/expr/src/tree_node.rs | 2 +- .../expr/src/type_coercion/aggregates.rs | 3 + datafusion/expr/src/udaf.rs | 29 +- datafusion/functions-aggregate/Cargo.toml | 1 + datafusion/functions-aggregate/src/lib.rs | 3 + datafusion/functions-aggregate/src/median.rs | 310 ++++++++++++++++++ .../optimizer/src/analyzer/type_coercion.rs | 2 +- .../src/single_distinct_to_groupby.rs | 50 +++ 9 files changed, 397 insertions(+), 4 deletions(-) create mode 100644 datafusion/functions-aggregate/src/median.rs diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 99af80bf9df2..e659e62d7ac7 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1287,6 +1287,7 @@ name = "datafusion-functions-aggregate" version = "38.0.0" dependencies = [ "arrow", + "arrow-schema", "datafusion-common", "datafusion-execution", "datafusion-expr", diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index 31ca4c40942b..c5f1694c1138 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -332,7 +332,7 @@ impl TreeNode for Expr { Ok(Expr::AggregateFunction(AggregateFunction::new_udf( fun, new_args, - false, + distinct, new_filter, new_order_by, null_treatment, diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 57c0b6f4edc5..e8cd6740be2c 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -355,6 +355,9 @@ pub fn check_arg_count( ); } } + TypeSignature::UserDefined => { + // User-defined functions are not validated here + } _ => { return internal_err!( "Aggregate functions do not support this {signature:?}" diff --git a/datafusion/expr/src/udaf.rs 
b/datafusion/expr/src/udaf.rs index 4fd8d51679f0..b8df94c950fe 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -193,9 +193,11 @@ impl AggregateUDF { self.inner.create_groups_accumulator() } - pub fn coerce_types(&self, _args: &[DataType]) -> Result> { - not_impl_err!("coerce_types not implemented for {:?} yet", self.name()) + /// See [`AggregateUDFImpl::coerce_types`] for more details. + pub fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + self.inner.coerce_types(arg_types) } + /// Do the function rewrite /// /// See [`AggregateUDFImpl::simplify`] for more details. @@ -389,6 +391,29 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { fn reverse_expr(&self) -> ReversedUDAF { ReversedUDAF::NotSupported } + + /// Coerce arguments of a function call to types that the function can evaluate. + /// + /// This function is only called if [`AggregateUDFImpl::signature`] returns [`crate::TypeSignature::UserDefined`]. Most + /// UDFs should return one of the other variants of `TypeSignature` which handle common + /// cases + /// + /// See the [type coercion module](crate::type_coercion) + /// documentation for more details on type coercion + /// + /// For example, if your function requires a floating point arguments, but the user calls + /// it like `my_func(1::int)` (aka with `1` as an integer), coerce_types could return `[DataType::Float64]` + /// to ensure the argument was cast to `1::double` + /// + /// # Parameters + /// * `arg_types`: The argument types of the arguments this function with + /// + /// # Return value + /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call + /// arguments to these specific types. 
+ fn coerce_types(&self, _arg_types: &[DataType]) -> Result> { + not_impl_err!("Function {} does not implement coerce_types", self.name()) + } } pub enum ReversedUDAF { diff --git a/datafusion/functions-aggregate/Cargo.toml b/datafusion/functions-aggregate/Cargo.toml index f97647565364..696bbaece9e6 100644 --- a/datafusion/functions-aggregate/Cargo.toml +++ b/datafusion/functions-aggregate/Cargo.toml @@ -39,6 +39,7 @@ path = "src/lib.rs" [dependencies] arrow = { workspace = true } +arrow-schema = { workspace = true } datafusion-common = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index e76a43e39899..3e80174eec33 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -57,6 +57,7 @@ pub mod macros; pub mod covariance; pub mod first_last; +pub mod median; use datafusion_common::Result; use datafusion_execution::FunctionRegistry; @@ -68,6 +69,7 @@ use std::sync::Arc; pub mod expr_fn { pub use super::covariance::covar_samp; pub use super::first_last::first_value; + pub use super::median::median; } /// Registers all enabled packages with a [`FunctionRegistry`] @@ -76,6 +78,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { first_last::first_value_udaf(), covariance::covar_samp_udaf(), covariance::covar_pop_udaf(), + median::median_udaf(), ]; functions.into_iter().try_for_each(|udf| { diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs new file mode 100644 index 000000000000..fa5a302833bb --- /dev/null +++ b/datafusion/functions-aggregate/src/median.rs @@ -0,0 +1,310 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::HashSet; +use std::fmt::Formatter; +use std::{fmt::Debug, sync::Arc}; + +use arrow::array::{downcast_integer, ArrowNumericType}; +use arrow::{ + array::{ArrayRef, AsArray}, + datatypes::{ + DataType, Decimal128Type, Decimal256Type, Field, Float16Type, Float32Type, + Float64Type, + }, +}; + +use arrow::array::Array; +use arrow::array::ArrowNativeTypeOp; +use arrow::datatypes::ArrowNativeType; + +use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::function::StateFieldsArgs; +use datafusion_expr::{ + function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl, + Signature, Volatility, +}; +use datafusion_physical_expr_common::aggregate::utils::Hashable; + +make_udaf_expr_and_func!( + Median, + median, + expression, + "Computes the median of a set of numbers", + median_udaf +); + +pub struct Median { + signature: Signature, + aliases: Vec, +} + +impl Debug for Median { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("Median") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for Median { + fn default() -> Self { + Self::new() + } +} + +impl Median { + pub fn new() -> Self { + Self { + aliases: vec!["median".to_string()], + signature: Signature::user_defined(Volatility::Immutable), // signature: 
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable) + } + } +} + +impl AggregateUDFImpl for Median { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "MEDIAN" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() != 1 { + return exec_err!("Median takes exactly one argument"); + } + + if arg_types[0].is_numeric() { + Ok(vec![arg_types[0].clone()]) + } else { + exec_err!("Median only accepts numeric types") + } + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + //Intermediate state is a list of the elements we have collected so far + let field = Field::new("item", args.input_type.clone(), true); + let state_name = if args.is_distinct { + "distinct_median" + } else { + "median" + }; + + Ok(vec![Field::new( + format_state_name(args.name, state_name), + DataType::List(Arc::new(field)), + true, + )]) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + macro_rules! helper { + ($t:ty, $dt:expr) => { + if acc_args.is_distinct { + Ok(Box::new(DistinctMedianAccumulator::<$t> { + data_type: $dt.clone(), + distinct_values: HashSet::new(), + })) + } else { + Ok(Box::new(MedianAccumulator::<$t> { + data_type: $dt.clone(), + all_values: vec![], + })) + } + }; + } + + let dt = acc_args.input_type; + downcast_integer! 
{ + dt => (helper, dt), + DataType::Float16 => helper!(Float16Type, dt), + DataType::Float32 => helper!(Float32Type, dt), + DataType::Float64 => helper!(Float64Type, dt), + DataType::Decimal128(_, _) => helper!(Decimal128Type, dt), + DataType::Decimal256(_, _) => helper!(Decimal256Type, dt), + _ => Err(DataFusionError::NotImplemented(format!( + "MedianAccumulator not supported for {} with {}", + // acc_args.name, + "name", + dt, + ))), + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// The median accumulator accumulates the raw input values +/// as `ScalarValue`s +/// +/// The intermediate state is represented as a List of scalar values updated by +/// `merge_batch` and a `Vec` of `ArrayRef` that are converted to scalar values +/// in the final evaluation step so that we avoid expensive conversions and +/// allocations during `update_batch`. +struct MedianAccumulator { + data_type: DataType, + all_values: Vec, +} + +impl std::fmt::Debug for MedianAccumulator { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "MedianAccumulator({})", self.data_type) + } +} + +impl Accumulator for MedianAccumulator { + fn state(&mut self) -> Result> { + let all_values = self + .all_values + .iter() + .map(|x| ScalarValue::new_primitive::(Some(*x), &self.data_type)) + .collect::>>()?; + + let arr = ScalarValue::new_list(&all_values, &self.data_type); + Ok(vec![ScalarValue::List(arr)]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = values[0].as_primitive::(); + self.all_values.reserve(values.len() - values.null_count()); + self.all_values.extend(values.iter().flatten()); + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let array = states[0].as_list::(); + for v in array.iter().flatten() { + self.update_batch(&[v])? 
+ } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let d = std::mem::take(&mut self.all_values); + let median = calculate_median::(d); + ScalarValue::new_primitive::(median, &self.data_type) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + + self.all_values.capacity() * std::mem::size_of::() + } +} + +/// The distinct median accumulator accumulates the raw input values +/// as `ScalarValue`s +/// +/// The intermediate state is represented as a List of scalar values updated by +/// `merge_batch` and a `Vec` of `ArrayRef` that are converted to scalar values +/// in the final evaluation step so that we avoid expensive conversions and +/// allocations during `update_batch`. +struct DistinctMedianAccumulator { + data_type: DataType, + distinct_values: HashSet>, +} + +impl std::fmt::Debug for DistinctMedianAccumulator { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "DistinctMedianAccumulator({})", self.data_type) + } +} + +impl Accumulator for DistinctMedianAccumulator { + fn state(&mut self) -> Result> { + let all_values = self + .distinct_values + .iter() + .map(|x| ScalarValue::new_primitive::(Some(x.0), &self.data_type)) + .collect::>>()?; + + let arr = ScalarValue::new_list(&all_values, &self.data_type); + Ok(vec![ScalarValue::List(arr)]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + + let array = values[0].as_primitive::(); + match array.nulls().filter(|x| x.null_count() > 0) { + Some(n) => { + for idx in n.valid_indices() { + self.distinct_values.insert(Hashable(array.value(idx))); + } + } + None => array.values().iter().for_each(|x| { + self.distinct_values.insert(Hashable(*x)); + }), + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let array = states[0].as_list::(); + for v in array.iter().flatten() { + self.update_batch(&[v])? 
+ } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let d = std::mem::take(&mut self.distinct_values) + .into_iter() + .map(|v| v.0) + .collect::>(); + let median = calculate_median::(d); + ScalarValue::new_primitive::(median, &self.data_type) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + + self.distinct_values.capacity() * std::mem::size_of::() + } +} + +fn calculate_median( + mut values: Vec, +) -> Option { + let cmp = |x: &T::Native, y: &T::Native| x.compare(*y); + + let len = values.len(); + if len == 0 { + None + } else if len % 2 == 0 { + let (low, high, _) = values.select_nth_unstable_by(len / 2, cmp); + let (_, low, _) = low.select_nth_unstable_by(low.len() - 1, cmp); + let median = low.add_wrapping(*high).div_wrapping(T::Native::usize_as(2)); + Some(median) + } else { + let (_, median, _) = values.select_nth_unstable_by(len / 2, cmp); + Some(*median) + } +} diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 3d08bd6c7e42..69be344cb753 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -402,7 +402,7 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { expr::AggregateFunction::new_udf( fun, new_expr, - false, + distinct, filter, order_by, null_treatment, diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 4b1f9a0d1401..27449c8dd5c4 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -257,6 +257,56 @@ impl OptimizerRule for SingleDistinctToGroupBy { ))) } } + Expr::AggregateFunction(AggregateFunction { + func_def: AggregateFunctionDefinition::UDF(udf), + args, + distinct, + .. 
+ }) => { + if distinct { + if args.len() != 1 { + return internal_err!("DISTINCT aggregate should have exactly one argument"); + } + let mut args = args; + let arg = args.swap_remove(0); + + if group_fields_set.insert(arg.display_name()?) { + inner_group_exprs + .push(arg.alias(SINGLE_DISTINCT_ALIAS)); + } + Ok(Expr::AggregateFunction(AggregateFunction::new_udf( + udf, + vec![col(SINGLE_DISTINCT_ALIAS)], + false, // intentional to remove distinct here + None, + None, + None, + ))) + // if the aggregate function is not distinct, we need to rewrite it like two phase aggregation + } else { + index += 1; + let alias_str = format!("alias{}", index); + inner_aggr_exprs.push( + Expr::AggregateFunction(AggregateFunction::new_udf( + udf.clone(), + args, + false, + None, + None, + None, + )) + .alias(&alias_str), + ); + Ok(Expr::AggregateFunction(AggregateFunction::new_udf( + udf.clone(), + vec![col(&alias_str)], + false, + None, + None, + None, + ))) + } + } _ => Ok(aggr_expr), }) .collect::>>()?; From c9f6386da6c74588e149fc7224b27e93c368682d Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Fri, 24 May 2024 19:22:27 +0800 Subject: [PATCH 2/7] rm agg median Signed-off-by: jayzhan211 --- datafusion/core/src/dataframe/mod.rs | 3 ++- datafusion/expr/src/aggregate_function.rs | 9 +-------- datafusion/expr/src/expr_fn.rs | 12 ------------ datafusion/expr/src/type_coercion/aggregates.rs | 6 +++--- datafusion/physical-expr/src/aggregate/build_in.rs | 6 ------ datafusion/proto/proto/datafusion.proto | 2 +- datafusion/proto/src/generated/pbjson.rs | 3 --- datafusion/proto/src/generated/prost.rs | 4 +--- datafusion/proto/src/logical_plan/from_proto.rs | 1 - datafusion/proto/src/logical_plan/to_proto.rs | 2 -- datafusion/proto/src/physical_plan/to_proto.rs | 4 +--- 11 files changed, 9 insertions(+), 43 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index d4626134acbf..e1656a22b1a4 100644 --- a/datafusion/core/src/dataframe/mod.rs 
+++ b/datafusion/core/src/dataframe/mod.rs @@ -50,10 +50,11 @@ use datafusion_common::{ }; use datafusion_expr::lit; use datafusion_expr::{ - avg, count, max, median, min, stddev, utils::COUNT_STAR_EXPANSION, + avg, count, max, min, stddev, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE, }; use datafusion_expr::{case, is_null, sum}; +use datafusion_functions_aggregate::expr_fn::median; use async_trait::async_trait; diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index 0a7607498c61..f251969ca618 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -43,8 +43,6 @@ pub enum AggregateFunction { Max, /// Average Avg, - /// Median - Median, /// Approximate distinct function ApproxDistinct, /// Aggregation into an array @@ -114,7 +112,6 @@ impl AggregateFunction { Min => "MIN", Max => "MAX", Avg => "AVG", - Median => "MEDIAN", ApproxDistinct => "APPROX_DISTINCT", ArrayAgg => "ARRAY_AGG", FirstValue => "FIRST_VALUE", @@ -168,7 +165,6 @@ impl FromStr for AggregateFunction { "count" => AggregateFunction::Count, "max" => AggregateFunction::Max, "mean" => AggregateFunction::Avg, - "median" => AggregateFunction::Median, "min" => AggregateFunction::Min, "sum" => AggregateFunction::Sum, "array_agg" => AggregateFunction::ArrayAgg, @@ -275,9 +271,7 @@ impl AggregateFunction { AggregateFunction::ApproxPercentileContWithWeight => { Ok(coerced_data_types[0].clone()) } - AggregateFunction::ApproxMedian | AggregateFunction::Median => { - Ok(coerced_data_types[0].clone()) - } + AggregateFunction::ApproxMedian => Ok(coerced_data_types[0].clone()), AggregateFunction::Grouping => Ok(DataType::Int32), AggregateFunction::FirstValue | AggregateFunction::LastValue @@ -335,7 +329,6 @@ impl AggregateFunction { | AggregateFunction::VariancePop | AggregateFunction::Stddev | AggregateFunction::StddevPop - | AggregateFunction::Median | AggregateFunction::ApproxMedian | 
AggregateFunction::FirstValue | AggregateFunction::LastValue => { diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 2a2bb75f1884..8c9d3c7885b0 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -296,18 +296,6 @@ pub fn approx_distinct(expr: Expr) -> Expr { )) } -/// Calculate the median for `expr`. -pub fn median(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::Median, - vec![expr], - false, - None, - None, - None, - )) -} - /// Calculate an approximation of the median for `expr`. pub fn approx_median(expr: Expr) -> Expr { Expr::AggregateFunction(AggregateFunction::new( diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index e8cd6740be2c..8c365fcfda77 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -283,9 +283,9 @@ pub fn coerce_types( } Ok(input_types.to_vec()) } - AggregateFunction::Median - | AggregateFunction::FirstValue - | AggregateFunction::LastValue => Ok(input_types.to_vec()), + AggregateFunction::FirstValue | AggregateFunction::LastValue => { + Ok(input_types.to_vec()) + } AggregateFunction::NthValue => Ok(input_types.to_vec()), AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]), AggregateFunction::StringAgg => { diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 145e7feadf8c..18252ea370eb 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -332,12 +332,6 @@ pub fn create_aggregate_expr( "APPROX_MEDIAN(DISTINCT) aggregations are not available" ); } - (AggregateFunction::Median, distinct) => Arc::new(expressions::Median::new( - input_phy_exprs[0].clone(), - name, - data_type, - distinct, - )), (AggregateFunction::FirstValue, _) => Arc::new( 
expressions::FirstValue::new( input_phy_exprs[0].clone(), diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 73e751c616ac..434ec9f81f15 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -551,7 +551,7 @@ enum AggregateFunction { APPROX_MEDIAN = 15; APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16; GROUPING = 17; - MEDIAN = 18; + // MEDIAN = 18; BIT_AND = 19; BIT_OR = 20; BIT_XOR = 21; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 77ba0808fb77..86a5975c8bb8 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -437,7 +437,6 @@ impl serde::Serialize for AggregateFunction { Self::ApproxMedian => "APPROX_MEDIAN", Self::ApproxPercentileContWithWeight => "APPROX_PERCENTILE_CONT_WITH_WEIGHT", Self::Grouping => "GROUPING", - Self::Median => "MEDIAN", Self::BitAnd => "BIT_AND", Self::BitOr => "BIT_OR", Self::BitXor => "BIT_XOR", @@ -483,7 +482,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "APPROX_MEDIAN", "APPROX_PERCENTILE_CONT_WITH_WEIGHT", "GROUPING", - "MEDIAN", "BIT_AND", "BIT_OR", "BIT_XOR", @@ -558,7 +556,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "APPROX_MEDIAN" => Ok(AggregateFunction::ApproxMedian), "APPROX_PERCENTILE_CONT_WITH_WEIGHT" => Ok(AggregateFunction::ApproxPercentileContWithWeight), "GROUPING" => Ok(AggregateFunction::Grouping), - "MEDIAN" => Ok(AggregateFunction::Median), "BIT_AND" => Ok(AggregateFunction::BitAnd), "BIT_OR" => Ok(AggregateFunction::BitOr), "BIT_XOR" => Ok(AggregateFunction::BitXor), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index a175987f1994..cb2de710075a 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2848,7 +2848,7 @@ pub enum AggregateFunction { ApproxMedian = 15, ApproxPercentileContWithWeight 
= 16, Grouping = 17, - Median = 18, + /// MEDIAN = 18; BitAnd = 19, BitOr = 20, BitXor = 21, @@ -2895,7 +2895,6 @@ impl AggregateFunction { "APPROX_PERCENTILE_CONT_WITH_WEIGHT" } AggregateFunction::Grouping => "GROUPING", - AggregateFunction::Median => "MEDIAN", AggregateFunction::BitAnd => "BIT_AND", AggregateFunction::BitOr => "BIT_OR", AggregateFunction::BitXor => "BIT_XOR", @@ -2937,7 +2936,6 @@ impl AggregateFunction { Some(Self::ApproxPercentileContWithWeight) } "GROUPING" => Some(Self::Grouping), - "MEDIAN" => Some(Self::Median), "BIT_AND" => Some(Self::BitAnd), "BIT_OR" => Some(Self::BitOr), "BIT_XOR" => Some(Self::BitXor), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index b6f72f6773a2..00c62fc32b98 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -450,7 +450,6 @@ impl From for AggregateFunction { } protobuf::AggregateFunction::ApproxMedian => Self::ApproxMedian, protobuf::AggregateFunction::Grouping => Self::Grouping, - protobuf::AggregateFunction::Median => Self::Median, protobuf::AggregateFunction::FirstValueAgg => Self::FirstValue, protobuf::AggregateFunction::LastValueAgg => Self::LastValue, protobuf::AggregateFunction::NthValueAgg => Self::NthValue, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 91f7411e911a..f2ee679ac129 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -386,7 +386,6 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { } AggregateFunction::ApproxMedian => Self::ApproxMedian, AggregateFunction::Grouping => Self::Grouping, - AggregateFunction::Median => Self::Median, AggregateFunction::FirstValue => Self::FirstValueAgg, AggregateFunction::LastValue => Self::LastValueAgg, AggregateFunction::NthValue => Self::NthValueAgg, @@ -697,7 +696,6 @@ pub fn 
serialize_expr( protobuf::AggregateFunction::ApproxMedian } AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping, - AggregateFunction::Median => protobuf::AggregateFunction::Median, AggregateFunction::FirstValue => { protobuf::AggregateFunction::FirstValueAgg } diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index c6b94a934f23..d3badee3efff 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -27,7 +27,7 @@ use datafusion::physical_plan::expressions::{ ArrayAgg, Avg, BinaryExpr, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, CaseExpr, CastExpr, Column, Correlation, Count, CumeDist, DistinctArrayAgg, DistinctBitXor, DistinctCount, DistinctSum, FirstValue, Grouping, InListExpr, IsNotNullExpr, - IsNullExpr, LastValue, Literal, Max, Median, Min, NegativeExpr, NotExpr, NthValue, + IsNullExpr, LastValue, Literal, Max, Min, NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, RankType, Regr, RegrType, RowNumber, Stddev, StddevPop, StringAgg, Sum, TryCastExpr, Variance, VariancePop, WindowShift, @@ -318,8 +318,6 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { protobuf::AggregateFunction::ApproxPercentileContWithWeight } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::ApproxMedian - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::Median } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::FirstValueAgg } else if aggr_expr.downcast_ref::().is_some() { From 6e8e18dce73f234421499e5d4c40e9070f6870d4 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Fri, 24 May 2024 20:45:29 +0800 Subject: [PATCH 3/7] rm old median Signed-off-by: jayzhan211 --- .../physical-expr/src/aggregate/median.rs | 297 ------------------ datafusion/physical-expr/src/aggregate/mod.rs | 1 - .../physical-expr/src/expressions/mod.rs | 1 - 
.../physical-plan/src/aggregates/mod.rs | 27 +- 4 files changed, 20 insertions(+), 306 deletions(-) delete mode 100644 datafusion/physical-expr/src/aggregate/median.rs diff --git a/datafusion/physical-expr/src/aggregate/median.rs b/datafusion/physical-expr/src/aggregate/median.rs deleted file mode 100644 index ee0fce3fabe7..000000000000 --- a/datafusion/physical-expr/src/aggregate/median.rs +++ /dev/null @@ -1,297 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! # Median - -use crate::aggregate::utils::{down_cast_any_ref, Hashable}; -use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::array::{Array, ArrayRef}; -use arrow::datatypes::{DataType, Field}; -use arrow_array::cast::AsArray; -use arrow_array::{downcast_integer, ArrowNativeTypeOp, ArrowNumericType}; -use arrow_buffer::ArrowNativeType; -use datafusion_common::{DataFusionError, Result, ScalarValue}; -use datafusion_expr::Accumulator; -use std::any::Any; -use std::collections::HashSet; -use std::fmt::Formatter; -use std::sync::Arc; - -/// MEDIAN aggregate expression. 
If using the non-distinct variation, then this uses a -/// lot of memory because all values need to be stored in memory before a result can be -/// computed. If an approximation is sufficient then APPROX_MEDIAN provides a much more -/// efficient solution. -/// -/// If using the distinct variation, the memory usage will be similarly high if the -/// cardinality is high as it stores all distinct values in memory before computing the -/// result, but if cardinality is low then memory usage will also be lower. -#[derive(Debug)] -pub struct Median { - name: String, - expr: Arc, - data_type: DataType, - distinct: bool, -} - -impl Median { - /// Create a new MEDIAN aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - distinct: bool, - ) -> Self { - Self { - name: name.into(), - expr, - data_type, - distinct, - } - } -} - -impl AggregateExpr for Median { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, self.data_type.clone(), true)) - } - - fn create_accumulator(&self) -> Result> { - use arrow_array::types::*; - macro_rules! helper { - ($t:ty, $dt:expr) => { - if self.distinct { - Ok(Box::new(DistinctMedianAccumulator::<$t> { - data_type: $dt.clone(), - distinct_values: HashSet::new(), - })) - } else { - Ok(Box::new(MedianAccumulator::<$t> { - data_type: $dt.clone(), - all_values: vec![], - })) - } - }; - } - let dt = &self.data_type; - downcast_integer! 
{ - dt => (helper, dt), - DataType::Float16 => helper!(Float16Type, dt), - DataType::Float32 => helper!(Float32Type, dt), - DataType::Float64 => helper!(Float64Type, dt), - DataType::Decimal128(_, _) => helper!(Decimal128Type, dt), - DataType::Decimal256(_, _) => helper!(Decimal256Type, dt), - _ => Err(DataFusionError::NotImplemented(format!( - "MedianAccumulator not supported for {} with {}", - self.name(), - self.data_type - ))), - } - } - - fn state_fields(&self) -> Result> { - //Intermediate state is a list of the elements we have collected so far - let field = Field::new("item", self.data_type.clone(), true); - let data_type = DataType::List(Arc::new(field)); - let state_name = if self.distinct { - "distinct_median" - } else { - "median" - }; - - Ok(vec![Field::new( - format_state_name(&self.name, state_name), - data_type, - true, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for Median { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.expr.eq(&x.expr) - && self.distinct == x.distinct - }) - .unwrap_or(false) - } -} - -/// The median accumulator accumulates the raw input values -/// as `ScalarValue`s -/// -/// The intermediate state is represented as a List of scalar values updated by -/// `merge_batch` and a `Vec` of `ArrayRef` that are converted to scalar values -/// in the final evaluation step so that we avoid expensive conversions and -/// allocations during `update_batch`. 
-struct MedianAccumulator { - data_type: DataType, - all_values: Vec, -} - -impl std::fmt::Debug for MedianAccumulator { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "MedianAccumulator({})", self.data_type) - } -} - -impl Accumulator for MedianAccumulator { - fn state(&mut self) -> Result> { - let all_values = self - .all_values - .iter() - .map(|x| ScalarValue::new_primitive::(Some(*x), &self.data_type)) - .collect::>>()?; - - let arr = ScalarValue::new_list(&all_values, &self.data_type); - Ok(vec![ScalarValue::List(arr)]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = values[0].as_primitive::(); - self.all_values.reserve(values.len() - values.null_count()); - self.all_values.extend(values.iter().flatten()); - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - let array = states[0].as_list::(); - for v in array.iter().flatten() { - self.update_batch(&[v])? - } - Ok(()) - } - - fn evaluate(&mut self) -> Result { - let d = std::mem::take(&mut self.all_values); - let median = calculate_median::(d); - ScalarValue::new_primitive::(median, &self.data_type) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - + self.all_values.capacity() * std::mem::size_of::() - } -} - -/// The distinct median accumulator accumulates the raw input values -/// as `ScalarValue`s -/// -/// The intermediate state is represented as a List of scalar values updated by -/// `merge_batch` and a `Vec` of `ArrayRef` that are converted to scalar values -/// in the final evaluation step so that we avoid expensive conversions and -/// allocations during `update_batch`. 
-struct DistinctMedianAccumulator { - data_type: DataType, - distinct_values: HashSet>, -} - -impl std::fmt::Debug for DistinctMedianAccumulator { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "DistinctMedianAccumulator({})", self.data_type) - } -} - -impl Accumulator for DistinctMedianAccumulator { - fn state(&mut self) -> Result> { - let all_values = self - .distinct_values - .iter() - .map(|x| ScalarValue::new_primitive::(Some(x.0), &self.data_type)) - .collect::>>()?; - - let arr = ScalarValue::new_list(&all_values, &self.data_type); - Ok(vec![ScalarValue::List(arr)]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if values.is_empty() { - return Ok(()); - } - - let array = values[0].as_primitive::(); - match array.nulls().filter(|x| x.null_count() > 0) { - Some(n) => { - for idx in n.valid_indices() { - self.distinct_values.insert(Hashable(array.value(idx))); - } - } - None => array.values().iter().for_each(|x| { - self.distinct_values.insert(Hashable(*x)); - }), - } - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - let array = states[0].as_list::(); - for v in array.iter().flatten() { - self.update_batch(&[v])? 
- } - Ok(()) - } - - fn evaluate(&mut self) -> Result { - let d = std::mem::take(&mut self.distinct_values) - .into_iter() - .map(|v| v.0) - .collect::>(); - let median = calculate_median::(d); - ScalarValue::new_primitive::(median, &self.data_type) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - + self.distinct_values.capacity() * std::mem::size_of::() - } -} - -fn calculate_median( - mut values: Vec, -) -> Option { - let cmp = |x: &T::Native, y: &T::Native| x.compare(*y); - - let len = values.len(); - if len == 0 { - None - } else if len % 2 == 0 { - let (low, high, _) = values.select_nth_unstable_by(len / 2, cmp); - let (_, low, _) = low.select_nth_unstable_by(low.len() - 1, cmp); - let median = low.add_wrapping(*high).div_wrapping(T::Native::usize_as(2)); - Some(median) - } else { - let (_, median, _) = values.select_nth_unstable_by(len / 2, cmp); - Some(*median) - } -} diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 93ecf0655e51..039c8814e987 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -39,7 +39,6 @@ pub(crate) mod count; pub(crate) mod count_distinct; pub(crate) mod covariance; pub(crate) mod grouping; -pub(crate) mod median; pub(crate) mod nth_value; pub(crate) mod string_agg; #[macro_use] diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 980297b8b433..a7921800fccd 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -53,7 +53,6 @@ pub use crate::aggregate::correlation::Correlation; pub use crate::aggregate::count::Count; pub use crate::aggregate::count_distinct::DistinctCount; pub use crate::aggregate::grouping::Grouping; -pub use crate::aggregate::median::Median; pub use crate::aggregate::min_max::{Max, Min}; pub use crate::aggregate::min_max::{MaxAccumulator, MinAccumulator}; 
pub use crate::aggregate::nth_value::NthValueAgg; diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 21608db40d56..cf31c2990b7d 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -1203,8 +1203,9 @@ mod tests { use datafusion_execution::config::SessionConfig; use datafusion_execution::memory_pool::FairSpillPool; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; + use datafusion_functions_aggregate::median::median_udaf; use datafusion_physical_expr::expressions::{ - lit, ApproxDistinct, Count, FirstValue, LastValue, Median, OrderSensitiveArrayAgg, + lit, ApproxDistinct, Count, FirstValue, LastValue, OrderSensitiveArrayAgg, }; use datafusion_physical_expr::{reverse_order_bys, PhysicalSortExpr}; @@ -1773,6 +1774,22 @@ mod tests { check_grouping_sets(input, true).await } + // Median(a) + fn test_median_agg_expr(schema: &Schema) -> Result> { + let args = vec![col("a", schema)?]; + let fun = median_udaf(); + datafusion_physical_expr_common::aggregate::create_aggregate_expr( + &fun, + &args, + &[], + &[], + schema, + "MEDIAN(a)", + false, + false, + ) + } + #[tokio::test] async fn test_oom() -> Result<()> { let input: Arc = Arc::new(TestYieldingExec::new(true)); @@ -1792,12 +1809,8 @@ mod tests { }; // something that allocates within the aggregator - let aggregates_v0: Vec> = vec![Arc::new(Median::new( - col("a", &input_schema)?, - "MEDIAN(a)".to_string(), - DataType::UInt32, - false, - ))]; + let aggregates_v0: Vec> = + vec![test_median_agg_expr(&input_schema)?]; // use slow-path in `hash.rs` let aggregates_v1: Vec> = From b4ff8ae6290355777cc01c4e17a5f788f157bdbb Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Fri, 24 May 2024 21:55:46 +0800 Subject: [PATCH 4/7] introduce numeric signature Signed-off-by: jayzhan211 --- datafusion/expr/src/function.rs | 3 ++ datafusion/expr/src/signature.rs | 13 ++++++++ 
.../expr/src/type_coercion/aggregates.rs | 5 +-- .../expr/src/type_coercion/functions.rs | 32 +++++++++++++++++++ datafusion/expr/src/udaf.rs | 28 ++-------------- datafusion/functions-aggregate/src/median.rs | 19 ++--------- .../physical-expr-common/src/aggregate/mod.rs | 2 ++ 7 files changed, 58 insertions(+), 44 deletions(-) diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 714cfa1af671..eb748ed2711a 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -82,6 +82,9 @@ pub struct AccumulatorArgs<'a> { /// The number of arguments the aggregate function takes. pub args_num: usize, + + /// The name of the expression + pub name: &'a str, } /// [`StateFieldsArgs`] contains information about the fields that an diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr/src/signature.rs index 63b030f0b748..a2b5c0f93b57 100644 --- a/datafusion/expr/src/signature.rs +++ b/datafusion/expr/src/signature.rs @@ -119,6 +119,8 @@ pub enum TypeSignature { OneOf(Vec), /// Specifies Signatures for array functions ArraySignature(ArrayFunctionSignature), + /// Fixed number of arguments of numeric types + Numeric(usize), } #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -178,6 +180,9 @@ impl TypeSignature { .collect::>() .join(", ")] } + TypeSignature::Numeric(num) => { + vec![format!("Numeric({})", num)] + } TypeSignature::Exact(types) => { vec![Self::join_types(types, ", ")] } @@ -259,6 +264,14 @@ impl Signature { volatility, } } + + pub fn numeric(num: usize, volatility: Volatility) -> Self { + Self { + type_signature: TypeSignature::Numeric(num), + volatility, + } + } + /// An arbitrary number of arguments of any type. 
pub fn variadic_any(volatility: Volatility) -> Self { Self { diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 8c365fcfda77..ce4a2a709842 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -355,8 +355,9 @@ pub fn check_arg_count( ); } } - TypeSignature::UserDefined => { - // User-defined functions are not validated here + TypeSignature::UserDefined | TypeSignature::Numeric(_) => { + // User-defined signature is validated in `coerce_types` + // Numeric signature is validated in `get_valid_types` } _ => { return internal_err!( diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 583d75e1ccfc..b41ec109103d 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -322,6 +322,38 @@ fn get_valid_types( .iter() .map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect()) .collect(), + TypeSignature::Numeric(number) => { + if *number < 1 { + return plan_err!( + "The signature expected at least one argument but received {}", + current_types.len() + ); + } + if *number != current_types.len() { + return plan_err!( + "The signature expected {} arguments but received {}", + number, + current_types.len() + ); + } + + let mut valid_type = current_types.first().unwrap().clone(); + for t in current_types.iter().skip(1) { + if let Some(coerced_type) = + comparison_binary_numeric_coercion(&valid_type, t) + { + valid_type = coerced_type; + } else { + return plan_err!( + "{} and {} are not coercible to a common numeric type", + valid_type, + t + ); + } + } + + vec![vec![valid_type; *number]] + } TypeSignature::Uniform(number, valid_types) => valid_types .iter() .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect()) diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index
b8df94c950fe..b620a897bcc9 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -193,9 +193,8 @@ impl AggregateUDF { self.inner.create_groups_accumulator() } - /// See [`AggregateUDFImpl::coerce_types`] for more details. - pub fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - self.inner.coerce_types(arg_types) + pub fn coerce_types(&self, _arg_types: &[DataType]) -> Result> { + not_impl_err!("coerce_types not implemented for {:?} yet", self.name()) } /// Do the function rewrite @@ -391,29 +390,6 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { fn reverse_expr(&self) -> ReversedUDAF { ReversedUDAF::NotSupported } - - /// Coerce arguments of a function call to types that the function can evaluate. - /// - /// This function is only called if [`AggregateUDFImpl::signature`] returns [`crate::TypeSignature::UserDefined`]. Most - /// UDFs should return one of the other variants of `TypeSignature` which handle common - /// cases - /// - /// See the [type coercion module](crate::type_coercion) - /// documentation for more details on type coercion - /// - /// For example, if your function requires a floating point arguments, but the user calls - /// it like `my_func(1::int)` (aka with `1` as an integer), coerce_types could return `[DataType::Float64]` - /// to ensure the argument was cast to `1::double` - /// - /// # Parameters - /// * `arg_types`: The argument types of the arguments this function with - /// - /// # Return value - /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call - /// arguments to these specific types. 
- fn coerce_types(&self, _arg_types: &[DataType]) -> Result> { - not_impl_err!("Function {} does not implement coerce_types", self.name()) - } } pub enum ReversedUDAF { diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index fa5a302833bb..6758c6432e0e 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -32,7 +32,7 @@ use arrow::array::Array; use arrow::array::ArrowNativeTypeOp; use arrow::datatypes::ArrowNativeType; -use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; +use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::{ function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl, @@ -72,7 +72,7 @@ impl Median { pub fn new() -> Self { Self { aliases: vec!["median".to_string()], - signature: Signature::user_defined(Volatility::Immutable), // signature: Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable) + signature: Signature::numeric(1, Volatility::Immutable), } } } @@ -90,18 +90,6 @@ impl AggregateUDFImpl for Median { &self.signature } - fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - if arg_types.len() != 1 { - return exec_err!("Median takes exactly one argument"); - } - - if arg_types[0].is_numeric() { - Ok(vec![arg_types[0].clone()]) - } else { - exec_err!("Median only accepts numeric types") - } - } - fn return_type(&self, arg_types: &[DataType]) -> Result { Ok(arg_types[0].clone()) } @@ -149,8 +137,7 @@ impl AggregateUDFImpl for Median { DataType::Decimal256(_, _) => helper!(Decimal256Type, dt), _ => Err(DataFusionError::NotImplemented(format!( "MedianAccumulator not supported for {} with {}", - // acc_args.name, - "name", + acc_args.name, dt, ))), } diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index 
4ef0d58046f8..4e9414bc5a11 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -219,6 +219,7 @@ impl AggregateExpr for AggregateFunctionExpr { is_distinct: self.is_distinct, input_type: &self.input_type, args_num: self.args.len(), + name: &self.name, }; self.fun.accumulator(acc_args) @@ -292,6 +293,7 @@ impl AggregateExpr for AggregateFunctionExpr { is_distinct: self.is_distinct, input_type: &self.input_type, args_num: self.args.len(), + name: &self.name, }; self.fun.groups_accumulator_supported(args) } From f055acbe9281a0b2d1c5cbd2a342c2430a95d6f5 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sat, 25 May 2024 20:54:17 +0800 Subject: [PATCH 5/7] address comment Signed-off-by: jayzhan211 --- datafusion/expr/src/signature.rs | 3 ++- datafusion/functions-aggregate/src/median.rs | 8 ++++++ .../sqllogictest/test_files/aggregate.slt | 27 +++++++++++++++++++ 3 files changed, 37 insertions(+), 1 deletion(-) diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr/src/signature.rs index a2b5c0f93b57..3486c38c2b22 100644 --- a/datafusion/expr/src/signature.rs +++ b/datafusion/expr/src/signature.rs @@ -119,7 +119,8 @@ pub enum TypeSignature { OneOf(Vec), /// Specifies Signatures for array functions ArraySignature(ArrayFunctionSignature), - /// Fixed number of arguments of numeric types + /// Fixed number of arguments of numeric types. + /// See https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html#method.is_numeric to know which type is considered numeric Numeric(usize), } diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index 6758c6432e0e..b3fb05d7fcf7 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -48,6 +48,14 @@ make_udaf_expr_and_func!( median_udaf ); +/// MEDIAN aggregate expression. 
If using the non-distinct variation, then this uses a +/// lot of memory because all values need to be stored in memory before a result can be +/// computed. If an approximation is sufficient then APPROX_MEDIAN provides a much more +/// efficient solution. +/// +/// If using the distinct variation, the memory usage will be similarly high if the +/// cardinality is high as it stores all distinct values in memory before computing the +/// result, but if cardinality is low then memory usage will also be lower. pub struct Median { signature: Signature, aliases: Vec, diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index c2478e543735..2a220ea0a89d 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -871,6 +871,33 @@ select median(distinct c) from t; statement ok drop table t; +# optimize distinct median to group by +statement ok +create table t(c int) as values (1), (1), (1), (1), (2), (2), (3), (3); + +query TT +explain select median(distinct c) from t; +---- +logical_plan +01)Projection: MEDIAN(alias1) AS MEDIAN(DISTINCT t.c) +02)--Aggregate: groupBy=[[]], aggr=[[MEDIAN(alias1)]] +03)----Aggregate: groupBy=[[t.c AS alias1]], aggr=[[]] +04)------TableScan: t projection=[c] +physical_plan +01)ProjectionExec: expr=[MEDIAN(alias1)@0 as MEDIAN(DISTINCT t.c)] +02)--AggregateExec: mode=Final, gby=[], aggr=[MEDIAN(alias1)] +03)----CoalescePartitionsExec +04)------AggregateExec: mode=Partial, gby=[], aggr=[MEDIAN(alias1)] +05)--------AggregateExec: mode=FinalPartitioned, gby=[alias1@0 as alias1], aggr=[] +06)----------CoalesceBatchesExec: target_batch_size=8192 +07)------------RepartitionExec: partitioning=Hash([alias1@0], 4), input_partitions=4 +08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +09)----------------AggregateExec: mode=Partial, gby=[c@0 as alias1], aggr=[] +10)------------------MemoryExec: 
partitions=1, partition_sizes=[1] + +statement ok +drop table t; + # median_multi # test case for https://github.com/apache/datafusion/issues/3105 # has an intermediate grouping From 2b2582ae86eac7b99dc3e4a46de2463b73d7464d Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sat, 25 May 2024 21:08:07 +0800 Subject: [PATCH 6/7] fix doc Signed-off-by: jayzhan211 --- datafusion/expr/src/signature.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr/src/signature.rs index 3486c38c2b22..33f643eb2dc2 100644 --- a/datafusion/expr/src/signature.rs +++ b/datafusion/expr/src/signature.rs @@ -120,7 +120,7 @@ pub enum TypeSignature { /// Specifies Signatures for array functions ArraySignature(ArrayFunctionSignature), /// Fixed number of arguments of numeric types. - /// See https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html#method.is_numeric to know which type is considered numeric + /// See <https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html#method.is_numeric> to know which type is considered numeric Numeric(usize), } From 371290d14267a4bb292cd98662aa1ce6b33d3fde Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 26 May 2024 07:44:13 +0800 Subject: [PATCH 7/7] add proto roundtrip Signed-off-by: jayzhan211 --- datafusion/proto/tests/cases/roundtrip_logical_plan.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 6e819ef5bf46..d83d6cd1c297 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -33,6 +33,7 @@ use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::execution::FunctionRegistry; use datafusion::functions_aggregate::covariance::{covar_pop, covar_samp}; use datafusion::functions_aggregate::expr_fn::first_value; +use datafusion::functions_aggregate::median::median; use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory,
TestTableProvider}; use datafusion_common::config::{FormatOptions, TableOptions}; @@ -624,6 +625,7 @@ async fn roundtrip_expr_api() -> Result<()> { first_value(vec![lit(1)], false, None, None, None), covar_samp(lit(1.5), lit(2.2)), covar_pop(lit(1.5), lit(2.2)), + median(lit(2)), ]; // ensure expressions created with the expr api can be round tripped