Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions datafusion/functions-aggregate/src/approx_percentile_cont.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,14 @@ pub fn approx_percentile_cont(
#[user_doc(
doc_section(label = "Approximate Functions"),
description = "Returns the approximate percentile of input values using the t-digest algorithm.",
syntax_example = "approx_percentile_cont(percentile, centroids) WITHIN GROUP (ORDER BY expression)",
syntax_example = "approx_percentile_cont(percentile [, centroids]) WITHIN GROUP (ORDER BY expression)",
sql_example = r#"```sql
> SELECT approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) FROM table_name;
+------------------------------------------------------------------+
| approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) |
+------------------------------------------------------------------+
| 65.0 |
+------------------------------------------------------------------+
> SELECT approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name;
+-----------------------------------------------------------------------+
| approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) |
Expand Down Expand Up @@ -313,7 +319,7 @@ impl AggregateUDFImpl for ApproxPercentileCont {
}
if arg_types.len() == 3 && !arg_types[2].is_integer() {
return plan_err!(
"approx_percentile_cont requires integer max_size input types"
"approx_percentile_cont requires integer centroids input types"
);
}
Ok(arg_types[0].clone())
Expand Down Expand Up @@ -360,6 +366,11 @@ impl ApproxPercentileAccumulator {
}
}

/// Returns the `max_size` reported by this accumulator's underlying digest.
// crate-visible for approx_percentile_cont_with_weight
pub(crate) fn max_size(&self) -> usize {
self.digest.max_size()
}

// public for approx_percentile_cont_with_weight
pub fn merge_digests(&mut self, digests: &[TDigest]) {
let digests = digests.iter().chain(std::iter::once(&self.digest));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,39 +25,66 @@ use arrow::datatypes::FieldRef;
use arrow::{array::ArrayRef, datatypes::DataType};
use datafusion_common::ScalarValue;
use datafusion_common::{not_impl_err, plan_err, Result};
use datafusion_expr::expr::{AggregateFunction, Sort};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::type_coercion::aggregates::NUMERICS;
use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS};
use datafusion_expr::Volatility::Immutable;
use datafusion_expr::{
Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature,
};
use datafusion_functions_aggregate_common::tdigest::{
Centroid, TDigest, DEFAULT_MAX_SIZE,
Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature,
};
use datafusion_functions_aggregate_common::tdigest::{Centroid, TDigest};
use datafusion_macros::user_doc;

use crate::approx_percentile_cont::{ApproxPercentileAccumulator, ApproxPercentileCont};

make_udaf_expr_and_func!(
create_func!(
ApproxPercentileContWithWeight,
approx_percentile_cont_with_weight,
expression weight percentile,
"Computes the approximate percentile continuous with weight of a set of numbers",
approx_percentile_cont_with_weight_udaf
);

/// Computes the approximate percentile continuous with weight of a set of numbers
pub fn approx_percentile_cont_with_weight(
order_by: Sort,
Copy link
Contributor

@jcsherin jcsherin Aug 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://docs.rs/datafusion/latest/datafusion/functions_aggregate/approx_percentile_cont_with_weight/fn.approx_percentile_cont_with_weight.html

The first argument has changed from Expr in the current API to Sort.

pub fn approx_percentile_cont_with_weight(
    expression: Expr,
    weight: Expr,
    percentile: Expr,
) -> Expr

Shouldn't the order_by be an Expr?

Copy link
Contributor

@jcsherin jcsherin Aug 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Never mind 👍, I see that you have made the type narrower.

#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
pub struct Sort {
/// The expression to sort on
pub expr: Expr,
/// The direction of the sort
pub asc: bool,
/// Whether to put Nulls before all other data values
pub nulls_first: bool,
}

weight: Expr,
percentile: Expr,
centroids: Option<Expr>,
) -> Expr {
let expr = order_by.expr.clone();

let args = if let Some(centroids) = centroids {
vec![expr, weight, percentile, centroids]
} else {
vec![expr, weight, percentile]
};

Expr::AggregateFunction(AggregateFunction::new_udf(
approx_percentile_cont_with_weight_udaf(),
args,
false,
None,
vec![order_by],
None,
))
}

/// APPROX_PERCENTILE_CONT_WITH_WEIGHT aggregate expression
#[user_doc(
doc_section(label = "Approximate Functions"),
description = "Returns the weighted approximate percentile of input values using the t-digest algorithm.",
syntax_example = "approx_percentile_cont_with_weight(weight, percentile) WITHIN GROUP (ORDER BY expression)",
syntax_example = "approx_percentile_cont_with_weight(weight, percentile [, centroids]) WITHIN GROUP (ORDER BY expression)",
sql_example = r#"```sql
> SELECT approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) FROM table_name;
+---------------------------------------------------------------------------------------------+
| approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) |
+---------------------------------------------------------------------------------------------+
| 78.5 |
+---------------------------------------------------------------------------------------------+
> SELECT approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name;
+--------------------------------------------------------------------------------------------------+
| approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) |
+--------------------------------------------------------------------------------------------------+
| 78.5 |
+--------------------------------------------------------------------------------------------------+
```"#,
standard_argument(name = "expression", prefix = "The"),
argument(
Expand All @@ -67,6 +94,10 @@ make_udaf_expr_and_func!(
argument(
name = "percentile",
description = "Percentile to compute. Must be a float value between 0 and 1 (inclusive)."
),
argument(
name = "centroids",
description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory."
)
)]
pub struct ApproxPercentileContWithWeight {
Expand All @@ -91,21 +122,26 @@ impl Default for ApproxPercentileContWithWeight {
impl ApproxPercentileContWithWeight {
/// Create a new [`ApproxPercentileContWithWeight`] aggregate function.
pub fn new() -> Self {
let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1));
// Accept any numeric value paired with weight and float64 percentile
for num in NUMERICS {
variants.push(TypeSignature::Exact(vec![
num.clone(),
num.clone(),
DataType::Float64,
]));
// Additionally accept an integer number of centroids for T-Digest
for int in INTEGERS {
variants.push(TypeSignature::Exact(vec![
num.clone(),
num.clone(),
DataType::Float64,
int.clone(),
]));
}
}
Self {
signature: Signature::one_of(
// Accept any numeric value paired with a float64 percentile
NUMERICS
.iter()
.map(|t| {
TypeSignature::Exact(vec![
t.clone(),
t.clone(),
DataType::Float64,
])
})
.collect(),
Immutable,
),
signature: Signature::one_of(variants, Immutable),
approx_percentile_cont: ApproxPercentileCont::new(),
}
}
Expand Down Expand Up @@ -138,6 +174,11 @@ impl AggregateUDFImpl for ApproxPercentileContWithWeight {
if arg_types[2] != DataType::Float64 {
return plan_err!("approx_percentile_cont_with_weight requires float64 percentile input types");
}
if arg_types.len() == 4 && !arg_types[3].is_integer() {
return plan_err!(
"approx_percentile_cont_with_weight requires integer centroids input types"
);
}
Ok(arg_types[0].clone())
}

Expand All @@ -148,17 +189,25 @@ impl AggregateUDFImpl for ApproxPercentileContWithWeight {
);
}

if acc_args.exprs.len() != 3 {
if acc_args.exprs.len() != 3 && acc_args.exprs.len() != 4 {
return plan_err!(
"approx_percentile_cont_with_weight requires three arguments: value, weight, percentile"
"approx_percentile_cont_with_weight requires three or four arguments: value, weight, percentile[, centroids]"
);
}

let sub_args = AccumulatorArgs {
exprs: &[
Arc::clone(&acc_args.exprs[0]),
Arc::clone(&acc_args.exprs[2]),
],
exprs: if acc_args.exprs.len() == 4 {
&[
Arc::clone(&acc_args.exprs[0]), // value
Arc::clone(&acc_args.exprs[2]), // percentile
Arc::clone(&acc_args.exprs[3]), // centroids
]
} else {
&[
Arc::clone(&acc_args.exprs[0]), // value
Arc::clone(&acc_args.exprs[2]), // percentile
]
},
..acc_args
};
let approx_percentile_cont_accumulator =
Expand Down Expand Up @@ -244,7 +293,7 @@ impl Accumulator for ApproxPercentileWithWeightAccumulator {
let mut digests: Vec<TDigest> = vec![];
for (mean, weight) in means_f64.iter().zip(weights_f64.iter()) {
digests.push(TDigest::new_with_centroid(
DEFAULT_MAX_SIZE,
self.approx_percentile_cont_accumulator.max_size(),
Centroid::new(*mean, *weight),
))
}
Expand Down
13 changes: 12 additions & 1 deletion datafusion/proto/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -982,7 +982,18 @@ async fn roundtrip_expr_api() -> Result<()> {
approx_median(lit(2)),
approx_percentile_cont(lit(2).sort(true, false), lit(0.5), None),
approx_percentile_cont(lit(2).sort(true, false), lit(0.5), Some(lit(50))),
approx_percentile_cont_with_weight(lit(2), lit(1), lit(0.5)),
approx_percentile_cont_with_weight(
lit(2).sort(true, false),
lit(1),
lit(0.5),
None,
),
approx_percentile_cont_with_weight(
lit(2).sort(true, false),
lit(1),
lit(0.5),
Some(lit(50)),
),
grouping(lit(1)),
bit_and(lit(2)),
bit_or(lit(2)),
Expand Down
10 changes: 10 additions & 0 deletions datafusion/sqllogictest/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1840,6 +1840,16 @@ c 123
d 124
e 115

# approx_percentile_cont_with_weight with centroids
query TI
SELECT c1, approx_percentile_cont_with_weight(c2, 0.95, 200) WITHIN GROUP (ORDER BY c3) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this test pass if the third argument is 100?

pub const DEFAULT_MAX_SIZE: usize = 100;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

100 and 200 make no difference in terms of this test result. I changed it here just to make sure the function with the new arg can compile. I can add more tests here, though.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense. Maybe keep the original test intact and then add this test with the centroids argument as a new one?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

----
a 74
b 68
c 123
d 124
e 115

# csv_query_sum_crossjoin
query TTI
SELECT a.c1, b.c1, SUM(a.c2) FROM aggregate_test_100 as a CROSS JOIN aggregate_test_100 as b GROUP BY a.c1, b.c1 ORDER BY a.c1, b.c1
Expand Down
42 changes: 21 additions & 21 deletions docs/source/user-guide/expressions.md
Original file line number Diff line number Diff line change
Expand Up @@ -285,27 +285,27 @@ select log(-1), log(0), sqrt(-1);

## Aggregate Functions

| Syntax | Description |
| ----------------------------------------------------------------- | --------------------------------------------------------------------------------------- |
| avg(expr) | Сalculates the average value for `expr`. |
| approx_distinct(expr) | Calculates an approximate count of the number of distinct values for `expr`. |
| approx_median(expr) | Calculates an approximation of the median for `expr`. |
| approx_percentile_cont(expr, percentile) | Calculates an approximation of the specified `percentile` for `expr`. |
| approx_percentile_cont_with_weight(expr, weight_expr, percentile) | Calculates an approximation of the specified `percentile` for `expr` and `weight_expr`. |
| bit_and(expr) | Computes the bitwise AND of all non-null input values for `expr`. |
| bit_or(expr) | Computes the bitwise OR of all non-null input values for `expr`. |
| bit_xor(expr) | Computes the bitwise exclusive OR of all non-null input values for `expr`. |
| bool_and(expr) | Returns true if all non-null input values (`expr`) are true, otherwise false. |
| bool_or(expr) | Returns true if any non-null input value (`expr`) is true, otherwise false. |
| count(expr) | Returns the number of rows for `expr`. |
| count_distinct | Creates an expression to represent the count(distinct) aggregate function |
| cube(exprs) | Creates a grouping set for all combination of `exprs` |
| grouping_set(exprs) | Create a grouping set. |
| max(expr) | Finds the maximum value of `expr`. |
| median(expr) | Сalculates the median of `expr`. |
| min(expr) | Finds the minimum value of `expr`. |
| rollup(exprs) | Creates a grouping set for rollup sets. |
| sum(expr) | Сalculates the sum of `expr`. |
| Syntax | Description |
| ------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------- |
| avg(expr) | Calculates the average value for `expr`. |
| approx_distinct(expr) | Calculates an approximate count of the number of distinct values for `expr`. |
| approx_median(expr) | Calculates an approximation of the median for `expr`. |
| approx_percentile_cont(expr, percentile [, centroids]) | Calculates an approximation of the specified `percentile` for `expr`. Optional `centroids` parameter controls accuracy (default: 100). |
| approx_percentile_cont_with_weight(expr, weight_expr, percentile [, centroids]) | Calculates an approximation of the specified `percentile` for `expr` and `weight_expr`. Optional `centroids` parameter controls accuracy (default: 100). |
| bit_and(expr) | Computes the bitwise AND of all non-null input values for `expr`. |
| bit_or(expr) | Computes the bitwise OR of all non-null input values for `expr`. |
| bit_xor(expr) | Computes the bitwise exclusive OR of all non-null input values for `expr`. |
| bool_and(expr) | Returns true if all non-null input values (`expr`) are true, otherwise false. |
| bool_or(expr) | Returns true if any non-null input value (`expr`) is true, otherwise false. |
| count(expr) | Returns the number of rows for `expr`. |
| count_distinct | Creates an expression to represent the count(distinct) aggregate function |
| cube(exprs) | Creates a grouping set for all combination of `exprs` |
| grouping_set(exprs) | Create a grouping set. |
| max(expr) | Finds the maximum value of `expr`. |
| median(expr) | Calculates the median of `expr`. |
| min(expr) | Finds the minimum value of `expr`. |
| rollup(exprs) | Creates a grouping set for rollup sets. |
| sum(expr) | Calculates the sum of `expr`. |

## Aggregate Function Builder

Expand Down
Loading