Skip to content

Commit

Permalink
Moving min and max to new API and removing from protobuf
Browse files Browse the repository at this point in the history
  • Loading branch information
edmondop committed Jun 19, 2024
1 parent 4109f58 commit 552f52b
Show file tree
Hide file tree
Showing 22 changed files with 955 additions and 199 deletions.
4 changes: 2 additions & 2 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ use datafusion_common::{
};
use datafusion_expr::lit;
use datafusion_expr::{
avg, max, min, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown,
avg,utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown,
UNNAMED_TABLE,
};
use datafusion_expr::{case, is_null};
use datafusion_functions_aggregate::expr_fn::{count, median, stddev, sum};
use datafusion_functions_aggregate::expr_fn::{count,max, median,min, stddev, sum};

use async_trait::async_trait;

Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ use datafusion_execution::runtime_env::RuntimeEnv;
use datafusion_expr::expr::{GroupingSet, Sort};
use datafusion_expr::var_provider::{VarProvider, VarType};
use datafusion_expr::{
array_agg, avg, cast, col, exists, expr, in_subquery, lit, max, out_ref_col,
array_agg, avg, cast, col, exists, expr, in_subquery, lit, out_ref_col,
placeholder, scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame,
WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
};
use datafusion_functions_aggregate::expr_fn::{count, sum};
use datafusion_functions_aggregate::expr_fn::{count, max, sum};

#[tokio::test]
async fn test_count_wildcard_on_sort() -> Result<()> {
Expand Down
7 changes: 4 additions & 3 deletions datafusion/core/tests/fuzz_cases/window_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,11 @@ use datafusion_common_runtime::SpawnedTask;
use datafusion_expr::type_coercion::aggregates::coerce_types;
use datafusion_expr::type_coercion::functions::data_types_with_aggregate_udf;
use datafusion_expr::{
AggregateFunction, BuiltInWindowFunction, WindowFrame, WindowFrameBound,
BuiltInWindowFunction, WindowFrame, WindowFrameBound,
WindowFrameUnits, WindowFunctionDefinition,
};
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_functions_aggregate::min_max::{max_udaf, min_udaf};
use datafusion_functions_aggregate::sum::sum_udaf;
use datafusion_physical_expr::expressions::{cast, col, lit};
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
Expand Down Expand Up @@ -360,14 +361,14 @@ fn get_random_function(
window_fn_map.insert(
"min",
(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min),
WindowFunctionDefinition::AggregateUDF(min_udaf()),
vec![arg.clone()],
),
);
window_fn_map.insert(
"max",
(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
WindowFunctionDefinition::AggregateUDF(max_udaf()),
vec![arg.clone()],
),
);
Expand Down
25 changes: 0 additions & 25 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,6 @@ use strum_macros::EnumIter;
// https://datafusion.apache.org/contributor-guide/index.html#how-to-add-a-new-aggregate-function
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, EnumIter)]
pub enum AggregateFunction {
/// Minimum
Min,
/// Maximum
Max,
/// Average
Avg,
/// Aggregation into an array
Expand All @@ -57,8 +53,6 @@ impl AggregateFunction {
pub fn name(&self) -> &str {
use AggregateFunction::*;
match self {
Min => "MIN",
Max => "MAX",
Avg => "AVG",
ArrayAgg => "ARRAY_AGG",
NthValue => "NTH_VALUE",
Expand All @@ -84,9 +78,7 @@ impl FromStr for AggregateFunction {
"avg" => AggregateFunction::Avg,
"bool_and" => AggregateFunction::BoolAnd,
"bool_or" => AggregateFunction::BoolOr,
"max" => AggregateFunction::Max,
"mean" => AggregateFunction::Avg,
"min" => AggregateFunction::Min,
"array_agg" => AggregateFunction::ArrayAgg,
"nth_value" => AggregateFunction::NthValue,
// statistical
Expand Down Expand Up @@ -123,11 +115,6 @@ impl AggregateFunction {
})?;

match self {
AggregateFunction::Max | AggregateFunction::Min => {
// For min and max agg function, the returned type is same as input type.
// The coerced_data_types is same with input_types.
Ok(coerced_data_types[0].clone())
}
AggregateFunction::BoolAnd | AggregateFunction::BoolOr => {
Ok(DataType::Boolean)
}
Expand Down Expand Up @@ -167,18 +154,6 @@ impl AggregateFunction {
AggregateFunction::Grouping | AggregateFunction::ArrayAgg => {
Signature::any(1, Volatility::Immutable)
}
AggregateFunction::Min | AggregateFunction::Max => {
let valid = STRINGS
.iter()
.chain(NUMERICS.iter())
.chain(TIMESTAMPS.iter())
.chain(DATES.iter())
.chain(TIMES.iter())
.chain(BINARYS.iter())
.cloned()
.collect::<Vec<_>>();
Signature::uniform(1, valid, Volatility::Immutable)
}
AggregateFunction::BoolAnd | AggregateFunction::BoolOr => {
Signature::uniform(1, vec![DataType::Boolean], Volatility::Immutable)
}
Expand Down
12 changes: 0 additions & 12 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2255,18 +2255,6 @@ mod test {

#[test]
fn test_find_df_window_function() {
assert_eq!(
find_df_window_func("max"),
Some(WindowFunctionDefinition::AggregateFunction(
aggregate_function::AggregateFunction::Max
))
);
assert_eq!(
find_df_window_func("min"),
Some(WindowFunctionDefinition::AggregateFunction(
aggregate_function::AggregateFunction::Min
))
);
assert_eq!(
find_df_window_func("avg"),
Some(WindowFunctionDefinition::AggregateFunction(
Expand Down
24 changes: 0 additions & 24 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,30 +145,6 @@ pub fn not(expr: Expr) -> Expr {
expr.not()
}

/// Create an expression to represent the min() aggregate function
pub fn min(expr: Expr) -> Expr {
Expr::AggregateFunction(AggregateFunction::new(
aggregate_function::AggregateFunction::Min,
vec![expr],
false,
None,
None,
None,
))
}

/// Create an expression to represent the max() aggregate function
pub fn max(expr: Expr) -> Expr {
Expr::AggregateFunction(AggregateFunction::new(
aggregate_function::AggregateFunction::Max,
vec![expr],
false,
None,
None,
None,
))
}

/// Create an expression to represent the array_agg() aggregate function
pub fn array_agg(expr: Expr) -> Expr {
Expr::AggregateFunction(AggregateFunction::new(
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/expr_rewriter/order_by.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ mod test {
use arrow::datatypes::{DataType, Field, Schema};

use crate::{
avg, cast, col, lit, logical_plan::builder::LogicalTableSource, min, try_cast,
avg, cast, col, lit, logical_plan::builder::LogicalTableSource, try_cast,
LogicalPlanBuilder,
};

Expand Down
33 changes: 1 addition & 32 deletions datafusion/expr/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
// specific language governing permissions and limitations
// under the License.

use std::ops::Deref;

use crate::{AggregateFunction, Signature, TypeSignature};

Expand Down Expand Up @@ -96,11 +95,6 @@ pub fn coerce_types(

match agg_fun {
AggregateFunction::ArrayAgg => Ok(input_types.to_vec()),
AggregateFunction::Min | AggregateFunction::Max => {
// min and max support the dictionary data type
// unpack the dictionary to get the value
get_min_max_result_type(input_types)
}
AggregateFunction::Avg => {
// Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
// smallint, int, bigint, real, double precision, decimal, or interval
Expand Down Expand Up @@ -208,22 +202,6 @@ pub fn check_arg_count(
Ok(())
}

fn get_min_max_result_type(input_types: &[DataType]) -> Result<Vec<DataType>> {
// make sure that the input types only has one element.
assert_eq!(input_types.len(), 1);
// min and max support the dictionary data type
// unpack the dictionary to get the value
match &input_types[0] {
DataType::Dictionary(_, dict_value_type) => {
// TODO add checker, if the value type is complex data type
Ok(vec![dict_value_type.deref().clone()])
}
// TODO add checker for datatype which min and max supported
// For example, the `Struct` and `Map` type are not supported in the MIN and MAX function
_ => Ok(input_types.to_vec()),
}
}

/// function return type of a sum
pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> {
match arg_type {
Expand Down Expand Up @@ -380,13 +358,6 @@ mod tests {

#[test]
fn test_aggregate_coerce_types() {
// test input args with error number input types
let fun = AggregateFunction::Min;
let input_types = vec![DataType::Int64, DataType::Int32];
let signature = fun.signature();
let result = coerce_types(&fun, &input_types, &signature);
assert_eq!("Error during planning: The function MIN expects 1 arguments, but 2 were provided", result.unwrap_err().strip_backtrace());

let fun = AggregateFunction::Avg;
// test input args is invalid data type for avg
let input_types = vec![DataType::Utf8];
Expand All @@ -397,12 +368,10 @@ mod tests {
result.unwrap_err().strip_backtrace()
);

// test count, array_agg, approx_distinct, min, max.
// test count, array_agg, approx_distinct.
// the coerced types is same with input types
let funs = vec![
AggregateFunction::ArrayAgg,
AggregateFunction::Min,
AggregateFunction::Max,
];
let input_types = vec![
vec![DataType::Int32],
Expand Down
8 changes: 4 additions & 4 deletions datafusion/expr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1268,23 +1268,23 @@ mod tests {
#[test]
fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> {
let max1 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
WindowFunctionDefinition::AggregateUDF(max_udaf()),
vec![col("name")],
vec![],
vec![],
WindowFrame::new(None),
None,
));
let max2 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
WindowFunctionDefinition::AggregateUDF(max_udaf()),
vec![col("name")],
vec![],
vec![],
WindowFrame::new(None),
None,
));
let min3 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min),
WindowFunctionDefinition::AggregateUDF(min_udaf()),
vec![col("name")],
vec![],
vec![],
Expand Down Expand Up @@ -1371,7 +1371,7 @@ mod tests {
fn test_find_sort_exprs() -> Result<()> {
let exprs = &[
Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
WindowFunctionDefinition::AggregateUDF(max_udaf()),
vec![col("name")],
vec![],
vec![
Expand Down
4 changes: 4 additions & 0 deletions datafusion/functions-aggregate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ pub mod covariance;
pub mod first_last;
pub mod hyperloglog;
pub mod median;
pub mod min_max;
pub mod regr;
pub mod stddev;
pub mod sum;
Expand Down Expand Up @@ -96,6 +97,8 @@ pub mod expr_fn {
pub use super::first_last::first_value;
pub use super::first_last::last_value;
pub use super::median::median;
pub use super::min_max::max;
pub use super::min_max::min;
pub use super::regr::regr_avgx;
pub use super::regr::regr_avgy;
pub use super::regr::regr_count;
Expand All @@ -120,6 +123,7 @@ pub fn all_default_aggregate_functions() -> Vec<Arc<AggregateUDF>> {
covariance::covar_samp_udaf(),
sum::sum_udaf(),
covariance::covar_pop_udaf(),
min_max::max_udaf(),
median::median_udaf(),
count::count_udaf(),
regr::regr_slope_udaf(),
Expand Down
Loading

0 comments on commit 552f52b

Please sign in to comment.