Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move Median to functions-aggregate and Introduce Numeric signature #10644

Merged
merged 7 commits into from
May 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,11 @@ use datafusion_common::{
};
use datafusion_expr::lit;
use datafusion_expr::{
avg, count, max, median, min, stddev, utils::COUNT_STAR_EXPANSION,
avg, count, max, min, stddev, utils::COUNT_STAR_EXPANSION,
TableProviderFilterPushDown, UNNAMED_TABLE,
};
use datafusion_expr::{case, is_null, sum};
use datafusion_functions_aggregate::expr_fn::median;

use async_trait::async_trait;

Expand Down
9 changes: 1 addition & 8 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ pub enum AggregateFunction {
Max,
/// Average
Avg,
/// Median
Median,
/// Approximate distinct function
ApproxDistinct,
/// Aggregation into an array
Expand Down Expand Up @@ -114,7 +112,6 @@ impl AggregateFunction {
Min => "MIN",
Max => "MAX",
Avg => "AVG",
Median => "MEDIAN",
ApproxDistinct => "APPROX_DISTINCT",
ArrayAgg => "ARRAY_AGG",
FirstValue => "FIRST_VALUE",
Expand Down Expand Up @@ -168,7 +165,6 @@ impl FromStr for AggregateFunction {
"count" => AggregateFunction::Count,
"max" => AggregateFunction::Max,
"mean" => AggregateFunction::Avg,
"median" => AggregateFunction::Median,
"min" => AggregateFunction::Min,
"sum" => AggregateFunction::Sum,
"array_agg" => AggregateFunction::ArrayAgg,
Expand Down Expand Up @@ -275,9 +271,7 @@ impl AggregateFunction {
AggregateFunction::ApproxPercentileContWithWeight => {
Ok(coerced_data_types[0].clone())
}
AggregateFunction::ApproxMedian | AggregateFunction::Median => {
Ok(coerced_data_types[0].clone())
}
AggregateFunction::ApproxMedian => Ok(coerced_data_types[0].clone()),
AggregateFunction::Grouping => Ok(DataType::Int32),
AggregateFunction::FirstValue
| AggregateFunction::LastValue
Expand Down Expand Up @@ -335,7 +329,6 @@ impl AggregateFunction {
| AggregateFunction::VariancePop
| AggregateFunction::Stddev
| AggregateFunction::StddevPop
| AggregateFunction::Median
| AggregateFunction::ApproxMedian
| AggregateFunction::FirstValue
| AggregateFunction::LastValue => {
Expand Down
12 changes: 0 additions & 12 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,18 +296,6 @@ pub fn approx_distinct(expr: Expr) -> Expr {
))
}

/// Calculate the median for `expr`.
pub fn median(expr: Expr) -> Expr {
Expr::AggregateFunction(AggregateFunction::new(
aggregate_function::AggregateFunction::Median,
vec![expr],
false,
None,
None,
None,
))
}

/// Calculate an approximation of the median for `expr`.
pub fn approx_median(expr: Expr) -> Expr {
Expr::AggregateFunction(AggregateFunction::new(
Expand Down
3 changes: 3 additions & 0 deletions datafusion/expr/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ pub struct AccumulatorArgs<'a> {

/// The number of arguments the aggregate function takes.
pub args_num: usize,

/// The name of the expression
pub name: &'a str,
}

/// [`StateFieldsArgs`] contains information about the fields that an
Expand Down
14 changes: 14 additions & 0 deletions datafusion/expr/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ pub enum TypeSignature {
OneOf(Vec<TypeSignature>),
/// Specifies Signatures for array functions
ArraySignature(ArrayFunctionSignature),
/// Fixed number of arguments of numeric types.
/// See <https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html#method.is_numeric> to know which type is considered numeric
Numeric(usize),
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -178,6 +181,9 @@ impl TypeSignature {
.collect::<Vec<String>>()
.join(", ")]
}
TypeSignature::Numeric(num) => {
vec![format!("Numeric({})", num)]
}
TypeSignature::Exact(types) => {
vec![Self::join_types(types, ", ")]
}
Expand Down Expand Up @@ -259,6 +265,14 @@ impl Signature {
volatility,
}
}

/// Creates a signature whose type signature is `TypeSignature::Numeric(num)`,
/// i.e. a fixed number (`num`) of arguments of numeric types, with the given
/// `volatility`.
pub fn numeric(num: usize, volatility: Volatility) -> Self {
Self {
type_signature: TypeSignature::Numeric(num),
volatility,
}
}

/// An arbitrary number of arguments of any type.
pub fn variadic_any(volatility: Volatility) -> Self {
Self {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ impl TreeNode for Expr {
Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
fun,
new_args,
false,
distinct,
new_filter,
new_order_by,
null_treatment,
Expand Down
10 changes: 7 additions & 3 deletions datafusion/expr/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,9 +283,9 @@ pub fn coerce_types(
}
Ok(input_types.to_vec())
}
AggregateFunction::Median
| AggregateFunction::FirstValue
| AggregateFunction::LastValue => Ok(input_types.to_vec()),
AggregateFunction::FirstValue | AggregateFunction::LastValue => {
Ok(input_types.to_vec())
}
AggregateFunction::NthValue => Ok(input_types.to_vec()),
AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]),
AggregateFunction::StringAgg => {
Expand Down Expand Up @@ -355,6 +355,10 @@ pub fn check_arg_count(
);
}
}
TypeSignature::UserDefined | TypeSignature::Numeric(_) => {
// User-defined signature is validated in `coerce_types`
// Numeric signature is validated in `get_valid_types`
}
_ => {
return internal_err!(
"Aggregate functions do not support this {signature:?}"
Expand Down
32 changes: 32 additions & 0 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,38 @@ fn get_valid_types(
.iter()
.map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect())
.collect(),
TypeSignature::Numeric(number) => {
if *number < 1 {
return plan_err!(
"The signature expected at least one argument but received {}",
current_types.len()
);
}
if *number != current_types.len() {
return plan_err!(
"The signature expected {} arguments but received {}",
number,
current_types.len()
);
}

let mut valid_type = current_types.first().unwrap().clone();
for t in current_types.iter().skip(1) {
if let Some(coerced_type) =
comparison_binary_numeric_coercion(&valid_type, t)
{
valid_type = coerced_type;
} else {
return plan_err!(
"{} and {} are not coercible to a common numeric type",
valid_type,
t
);
}
}

vec![vec![valid_type; *number]]
}
TypeSignature::Uniform(number, valid_types) => valid_types
.iter()
.map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect())
Expand Down
3 changes: 2 additions & 1 deletion datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,10 @@ impl AggregateUDF {
self.inner.create_groups_accumulator()
}

pub fn coerce_types(&self, _args: &[DataType]) -> Result<Vec<DataType>> {
pub fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
not_impl_err!("coerce_types not implemented for {:?} yet", self.name())
}

/// Do the function rewrite
///
/// See [`AggregateUDFImpl::simplify`] for more details.
Expand Down
1 change: 1 addition & 0 deletions datafusion/functions-aggregate/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ path = "src/lib.rs"

[dependencies]
arrow = { workspace = true }
arrow-schema = { workspace = true }
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for downcast_integer

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the issue that it is not re-exported in arrow?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, some macros have an unavoidable arrow-schema dependency; see apache/arrow-rs#5676.

datafusion-common = { workspace = true }
datafusion-execution = { workspace = true }
datafusion-expr = { workspace = true }
Expand Down
3 changes: 3 additions & 0 deletions datafusion/functions-aggregate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ pub mod macros;

pub mod covariance;
pub mod first_last;
pub mod median;

use datafusion_common::Result;
use datafusion_execution::FunctionRegistry;
Expand All @@ -68,6 +69,7 @@ use std::sync::Arc;
pub mod expr_fn {
pub use super::covariance::covar_samp;
pub use super::first_last::first_value;
pub use super::median::median;
}

/// Registers all enabled packages with a [`FunctionRegistry`]
Expand All @@ -76,6 +78,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
first_last::first_value_udaf(),
covariance::covar_samp_udaf(),
covariance::covar_pop_udaf(),
median::median_udaf(),
];

functions.into_iter().try_for_each(|udf| {
Expand Down
Loading