Skip to content

Commit

Permalink
Implement quantile aggregate functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
2010YOUY01 committed Aug 18, 2023
1 parent 90484bb commit db202fc
Show file tree
Hide file tree
Showing 15 changed files with 1,118 additions and 438 deletions.
25 changes: 22 additions & 3 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ pub enum AggregateFunction {
RegrSYY,
/// Sum of products of pairs of numbers
RegrSXY,
/// Continuous percentile
QuantileCont,
/// Discrete percentile
QuantileDisc,
/// Approximate continuous percentile function
ApproxPercentileCont,
/// Approximate continuous percentile function with weight
Expand Down Expand Up @@ -132,6 +136,8 @@ impl AggregateFunction {
RegrSXX => "REGR_SXX",
RegrSYY => "REGR_SYY",
RegrSXY => "REGR_SXY",
QuantileCont => "QUANTILE_CONT",
QuantileDisc => "QUANTILE_DISC",
ApproxPercentileCont => "APPROX_PERCENTILE_CONT",
ApproxPercentileContWithWeight => "APPROX_PERCENTILE_CONT_WITH_WEIGHT",
ApproxMedian => "APPROX_MEDIAN",
Expand Down Expand Up @@ -191,6 +197,8 @@ impl FromStr for AggregateFunction {
"regr_sxx" => AggregateFunction::RegrSXX,
"regr_syy" => AggregateFunction::RegrSYY,
"regr_sxy" => AggregateFunction::RegrSXY,
"quantile_cont" => AggregateFunction::QuantileCont,
"quantile_disc" => AggregateFunction::QuantileDisc,
// approximate
"approx_distinct" => AggregateFunction::ApproxDistinct,
"approx_median" => AggregateFunction::ApproxMedian,
Expand Down Expand Up @@ -293,9 +301,10 @@ impl AggregateFunction {
AggregateFunction::ApproxPercentileContWithWeight => {
Ok(coerced_data_types[0].clone())
}
AggregateFunction::ApproxMedian | AggregateFunction::Median => {
Ok(coerced_data_types[0].clone())
}
AggregateFunction::ApproxMedian
| AggregateFunction::Median
| AggregateFunction::QuantileCont
| AggregateFunction::QuantileDisc => Ok(coerced_data_types[0].clone()),
AggregateFunction::Grouping => Ok(DataType::Int32),
AggregateFunction::FirstValue | AggregateFunction::LastValue => {
Ok(coerced_data_types[0].clone())
Expand Down Expand Up @@ -380,6 +389,16 @@ impl AggregateFunction {
| AggregateFunction::RegrSXY => {
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
}
AggregateFunction::QuantileCont | AggregateFunction::QuantileDisc => {
// signature: quantile_*(NUMERICS, float64)
Signature::one_of(
NUMERICS
.iter()
.map(|t| TypeSignature::Exact(vec![t.clone(), DataType::Float64]))
.collect(),
Volatility::Immutable,
)
}
AggregateFunction::ApproxPercentileCont => {
// Accept any numeric value paired with a float64 percentile
let with_tdigest_size = NUMERICS.iter().map(|t| {
Expand Down
15 changes: 15 additions & 0 deletions datafusion/expr/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,21 @@ pub fn coerce_types(
}
Ok(input_types.to_vec())
}
AggregateFunction::QuantileCont | AggregateFunction::QuantileDisc => {
let valid_arg0_types = [NUMERICS.to_vec(), vec![DataType::Null]].concat();
let valid_arg1_types = NUMERICS;
let input_types_valid = // number of input already checked before
valid_arg0_types.contains(&input_types[0]) && valid_arg1_types.contains(&input_types[1]);
if !input_types_valid {
return plan_err!(
"The function {:?} does not support inputs of type {:?}, {:?}.",
agg_fun,
input_types[0],
input_types[1]
);
}
Ok(input_types.to_vec())
}
AggregateFunction::ApproxPercentileCont => {
if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) {
return plan_err!(
Expand Down
32 changes: 1 addition & 31 deletions datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

use crate::aggregate::tdigest::TryIntoF64;
use crate::aggregate::tdigest::{TDigest, DEFAULT_MAX_SIZE};
use crate::aggregate::utils::down_cast_any_ref;
use crate::aggregate::utils::{down_cast_any_ref, validate_input_percentile_expr};
use crate::expressions::{format_state_name, Literal};
use crate::{AggregateExpr, PhysicalExpr};
use arrow::{
Expand All @@ -27,7 +27,6 @@ use arrow::{
},
datatypes::{DataType, Field},
};
use datafusion_common::plan_err;
use datafusion_common::DataFusionError;
use datafusion_common::Result;
use datafusion_common::{downcast_value, ScalarValue};
Expand Down Expand Up @@ -131,35 +130,6 @@ impl PartialEq for ApproxPercentileCont {
}
}

/// Extracts the percentile argument from `expr` and checks that it is usable.
///
/// `expr` must downcast to a `Literal` whose value is a non-null `Float32`
/// or `Float64` scalar. A `Float32` value is widened to `f64`.
///
/// # Errors
///
/// * `DataFusionError::Internal` if `expr` is not a `Literal`.
/// * `DataFusionError::NotImplemented` if the literal is any other scalar
///   value; this includes `Float32(None)` / `Float64(None)`, because only the
///   `Some(_)` forms are matched below.
/// * A plan error (via `plan_err!`) if the value is not within `0.0..=1.0`.
fn validate_input_percentile_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<f64> {
// Extract the desired percentile literal
let lit = expr
.as_any()
.downcast_ref::<Literal>()
.ok_or_else(|| {
DataFusionError::Internal(
"desired percentile argument must be float literal".to_string(),
)
})?
.value();
// Only non-null float literals are accepted; everything else falls through
// to the `got` arm and is reported together with its data type.
let percentile = match lit {
ScalarValue::Float32(Some(q)) => *q as f64,
ScalarValue::Float64(Some(q)) => *q,
got => return Err(DataFusionError::NotImplemented(format!(
"Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or Float64 literal (got data type {})",
got.get_datatype()
)))
};

// Ensure the percentile is between 0 and 1.
// The inclusive range check also rejects NaN, since NaN compares false
// against both bounds.
if !(0.0..=1.0).contains(&percentile) {
return plan_err!(
"Percentile value must be between 0.0 and 1.0 inclusive, {percentile} is invalid"
);
}
Ok(percentile)
}

fn validate_input_max_size_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<usize> {
// Extract the desired percentile literal
let lit = expr
Expand Down
21 changes: 21 additions & 0 deletions datafusion/physical-expr/src/aggregate/build_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
//! * Signature: see `Signature`
//! * Return type: a function `(arg_types) -> return_type`. E.g. for min, ([f32]) -> f32, ([f64]) -> f64.

use crate::aggregate::percentile::PercentileInterpolationType;
use crate::aggregate::regr::RegrType;
use crate::{expressions, AggregateExpr, PhysicalExpr, PhysicalSortExpr};
use arrow::datatypes::Schema;
Expand Down Expand Up @@ -329,6 +330,26 @@ pub fn create_aggregate_expr(
fun
)));
}
(AggregateFunction::QuantileCont, false) => Arc::new(expressions::Quantile::new(
name,
PercentileInterpolationType::Continuous,
input_phy_exprs[0].clone(),
input_phy_exprs[1].clone(),
rt_type,
)?),
(AggregateFunction::QuantileDisc, false) => Arc::new(expressions::Quantile::new(
name,
PercentileInterpolationType::Discrete,
input_phy_exprs[0].clone(),
input_phy_exprs[1].clone(),
rt_type,
)?),
(AggregateFunction::QuantileDisc | AggregateFunction::QuantileCont, true) => {
return Err(DataFusionError::NotImplemented(format!(
"{}(DISTINCT) aggregations are not available",
fun
)));
}
(AggregateFunction::ApproxPercentileCont, false) => {
if input_phy_exprs.len() == 2 {
Arc::new(expressions::ApproxPercentileCont::new(
Expand Down
Loading

0 comments on commit db202fc

Please sign in to comment.