From 62d2ea4d03ded1a5cb5e2238646ed69e7cdb9b82 Mon Sep 17 00:00:00 2001 From: Yijun Zhao Date: Mon, 23 Oct 2023 19:13:40 +0800 Subject: [PATCH] add quantile_tdigest_weighted agg func --- .../aggregate-quantile-tdigest-weighted.md | 63 ++++ .../aggregate-quantile-tdigest.md | 2 +- .../10-aggregate-functions/index.md | 69 ++-- .../aggregates/aggregate_quantile_tdigest.rs | 33 +- .../aggregate_quantile_tdigest_weighted.rs | 306 ++++++++++++++++++ .../functions/src/aggregates/aggregator.rs | 10 + src/query/functions/src/aggregates/mod.rs | 1 + .../functions/tests/it/aggregates/agg.rs | 16 + .../tests/it/aggregates/testdata/agg.txt | 22 ++ .../02_0000_function_aggregate_mix | 10 + 10 files changed, 482 insertions(+), 50 deletions(-) create mode 100644 docs/doc/15-sql-functions/10-aggregate-functions/aggregate-quantile-tdigest-weighted.md create mode 100644 src/query/functions/src/aggregates/aggregate_quantile_tdigest_weighted.rs diff --git a/docs/doc/15-sql-functions/10-aggregate-functions/aggregate-quantile-tdigest-weighted.md b/docs/doc/15-sql-functions/10-aggregate-functions/aggregate-quantile-tdigest-weighted.md new file mode 100644 index 0000000000000..11951d5d9a5aa --- /dev/null +++ b/docs/doc/15-sql-functions/10-aggregate-functions/aggregate-quantile-tdigest-weighted.md @@ -0,0 +1,63 @@ +--- +title: QUANTILE_TDIGEST_WEIGHTED +--- +import FunctionDescription from '@site/src/components/FunctionDescription'; + + + +Computes an approximate quantile of a numeric data sequence using the [t-digest](https://github.com/tdunning/t-digest/blob/master/docs/t-digest-paper/histo.pdf) algorithm. +This function takes into account the weight of each sequence member. Memory consumption is **log(n)**, where **n** is a number of values. + +:::caution +NULL values are not included in the calculation. +::: + +## Syntax + +```sql +QUANTILE_TDIGEST_WEIGHTED([, , ...])(, ) +``` + +## Arguments + +| Arguments | Description | +|-----------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------| +| `` | A level of quantile represents a constant floating-point number ranging from 0 to 1. It is recommended to use a level value in the range of [0.01, 0.99]. | +| `` | Any numerical expression | +| `` | Any unsigned integer expression. Weight is a number of value occurrences. | + +## Return Type + +Returns either a Float64 value or an array of Float64 values, depending on the number of quantile levels specified. + +## Example + +```sql +-- Create a table and insert sample data +CREATE TABLE sales_data ( + id INT, + sales_person_id INT, + sales_amount FLOAT +); + +INSERT INTO sales_data (id, sales_person_id, sales_amount) +VALUES (1, 1, 5000), + (2, 2, 5500), + (3, 3, 6000), + (4, 4, 6500), + (5, 5, 7000); + +SELECT QUANTILE_TDIGEST_WEIGHTED(0.5)(sales_amount, 1) AS median_sales_amount +FROM sales_data; + +median_sales_amount| +-------------------+ + 6000.0| + +SELECT QUANTILE_TDIGEST_WEIGHTED(0.5, 0.8)(sales_amount, 1) +FROM sales_data; + +quantile_tdigest_weighted(0.5, 0.8)(sales_amount)| +-------------------------------------------------+ +[6000.0,7000.0] | +``` \ No newline at end of file diff --git a/docs/doc/15-sql-functions/10-aggregate-functions/aggregate-quantile-tdigest.md b/docs/doc/15-sql-functions/10-aggregate-functions/aggregate-quantile-tdigest.md index 84531b71c6c6e..1e0ff7799a12d 100644 --- a/docs/doc/15-sql-functions/10-aggregate-functions/aggregate-quantile-tdigest.md +++ b/docs/doc/15-sql-functions/10-aggregate-functions/aggregate-quantile-tdigest.md @@ -7,7 +7,7 @@ import FunctionDescription from '@site/src/components/FunctionDescription'; Computes an approximate quantile of a numeric data sequence using the [t-digest](https://github.com/tdunning/t-digest/blob/master/docs/t-digest-paper/histo.pdf) algorithm. -:::note +:::caution NULL values are not included in the calculation. ::: diff --git a/docs/doc/15-sql-functions/10-aggregate-functions/index.md b/docs/doc/15-sql-functions/10-aggregate-functions/index.md index e277b4d57be89..fd419f5f130f9 100644 --- a/docs/doc/15-sql-functions/10-aggregate-functions/index.md +++ b/docs/doc/15-sql-functions/10-aggregate-functions/index.md @@ -6,37 +6,38 @@ Aggregate functions are essential tools in SQL that allow you to perform calcula These functions help you extract and summarize data from databases to gain valuable insights. -| Function Name | What It Does | -|-------------------------------------------------------------|---------------------------------------------------------------------------| -| [ANY](aggregate-any.md) | Checks if any row meets the specified condition | -| [APPROX_COUNT_DISTINCT](aggregate-approx-count-distinct.md) | Estimates the number of distinct values with HyperLogLog | -| [ARG_MAX](aggregate-arg-max.md) | Finds the arg value for the maximum val value | -| [ARG_MIN](aggregate-arg-min.md) | Finds the arg value for the minimum val value | -| [AVG_IF](aggregate-avg-if.md) | Calculates the average for rows meeting a condition | -| [ARRAY_AGG](aggregate-array-agg.md) | Converts all the values of a column to an Array | -| [AVG](aggregate-avg.md) | Calculates the average value of a specific column | -| [COUNT_DISTINCT](aggregate-count-distinct.md) | Counts the number of distinct values in a column | -| [COUNT_IF](aggregate-count-if.md) | Counts rows meeting a specified condition | -| [COUNT](aggregate-count.md) | Counts the number of rows that meet certain criteria | -| [COVAR_POP](aggregate-covar-pop.md) | Returns the population covariance of a set of number pairs | -| [COVAR_SAMP](aggregate-covar-samp.md) | Returns the sample covariance of a set of number pairs | -| [GROUP_ARRAY_MOVING_AVG](aggregate-group-array-moving-avg.md) | Returns an array with elements calculates the moving average of input values | -| [GROUP_ARRAY_MOVING_SUM](aggregate-group-array-moving-sum.md) | Returns an array with elements calculates the moving sum of input values | -| [KURTOSIS](aggregate-kurtosis.md) | Calculates the excess kurtosis of a set of values | -| [MAX_IF](aggregate-max-if.md) | Finds the maximum value for rows meeting a condition | -| [MAX](aggregate-max.md) | Finds the largest value in a specific column | -| [MEDIAN](aggregate-median.md) | Calculates the median value of a specific column | -| [MEDIAN_TDIGEST](aggregate-median-tdigest.md) | Calculates the median value of a specific column using t-digest algorithm | -| [MIN_IF](aggregate-min-if.md) | Finds the minimum value for rows meeting a condition | -| [MIN](aggregate-min.md) | Finds the smallest value in a specific column | -| [QUANTILE_CONT](aggregate-quantile-cont.md) | Calculates the interpolated quantile for a specific column | -| [QUANTILE_DISC](aggregate-quantile-disc.md) | Calculates the quantile for a specific column | -| [QUANTILE_TDIGEST](aggregate-quantile-tdigest.md) | Calculates the quantile using t-digest algorithm | -| [RETENTION](aggregate-retention.md) | Calculates retention for a set of events | -| [SKEWNESS](aggregate-skewness.md) | Calculates the skewness of a set of values | -| [STDDEV_POP](aggregate-stddev-pop.md) | Calculates the population standard deviation of a column | -| [STDDEV_SAMP](aggregate-stddev-samp.md) | Calculates the sample standard deviation of a column | -| [STRING_AGG](aggregate-string-agg.md) | Converts all the non-NULL values to String, separated by the delimiter | -| [SUM_IF](aggregate-sum-if.md) | Adds up the values meeting a condition of a specific column | -| [SUM](aggregate-sum.md) | Adds up the values of a specific column | -| [WINDOW_FUNNEL](aggregate-windowfunnel.md) | Analyzes user behavior in a time-ordered sequence of events | +| Function Name | What It Does | +|---------------------------------------------------------------------|------------------------------------------------------------------------------| +| [ANY](aggregate-any.md) | Checks if any row meets the specified condition | +| [APPROX_COUNT_DISTINCT](aggregate-approx-count-distinct.md) | Estimates the number of distinct values with HyperLogLog | +| [ARG_MAX](aggregate-arg-max.md) | Finds the arg value for the maximum val value | +| [ARG_MIN](aggregate-arg-min.md) | Finds the arg value for the minimum val value | +| [AVG_IF](aggregate-avg-if.md) | Calculates the average for rows meeting a condition | +| [ARRAY_AGG](aggregate-array-agg.md) | Converts all the values of a column to an Array | +| [AVG](aggregate-avg.md) | Calculates the average value of a specific column | +| [COUNT_DISTINCT](aggregate-count-distinct.md) | Counts the number of distinct values in a column | +| [COUNT_IF](aggregate-count-if.md) | Counts rows meeting a specified condition | +| [COUNT](aggregate-count.md) | Counts the number of rows that meet certain criteria | +| [COVAR_POP](aggregate-covar-pop.md) | Returns the population covariance of a set of number pairs | +| [COVAR_SAMP](aggregate-covar-samp.md) | Returns the sample covariance of a set of number pairs | +| [GROUP_ARRAY_MOVING_AVG](aggregate-group-array-moving-avg.md) | Returns an array with elements calculates the moving average of input values | +| [GROUP_ARRAY_MOVING_SUM](aggregate-group-array-moving-sum.md) | Returns an array with elements calculates the moving sum of input values | +| [KURTOSIS](aggregate-kurtosis.md) | Calculates the excess kurtosis of a set of values | +| [MAX_IF](aggregate-max-if.md) | Finds the maximum value for rows meeting a condition | +| [MAX](aggregate-max.md) | Finds the largest value in a specific column | +| [MEDIAN](aggregate-median.md) | Calculates the median value of a specific column | +| [MEDIAN_TDIGEST](aggregate-median-tdigest.md) | Calculates the median value of a specific column using t-digest algorithm | +| [MIN_IF](aggregate-min-if.md) | Finds the minimum value for rows meeting a condition | +| [MIN](aggregate-min.md) | Finds the smallest value in a specific column | +| [QUANTILE_CONT](aggregate-quantile-cont.md) | Calculates the interpolated quantile for a specific column | +| [QUANTILE_DISC](aggregate-quantile-disc.md) | Calculates the quantile for a specific column | +| [QUANTILE_TDIGEST](aggregate-quantile-tdigest.md) | Calculates the quantile using t-digest algorithm | +| [QUANTILE_TDIGEST_WEIGHTED](aggregate-quantile-tdigest-weighted.md) | Calculates the quantile with weighted using t-digest algorithm | +| [RETENTION](aggregate-retention.md) | Calculates retention for a set of events | +| [SKEWNESS](aggregate-skewness.md) | Calculates the skewness of a set of values | +| [STDDEV_POP](aggregate-stddev-pop.md) | Calculates the population standard deviation of a column | +| [STDDEV_SAMP](aggregate-stddev-samp.md) | Calculates the sample standard deviation of a column | +| [STRING_AGG](aggregate-string-agg.md) | Converts all the non-NULL values to String, separated by the delimiter | +| [SUM_IF](aggregate-sum-if.md) | Adds up the values meeting a condition of a specific column | +| [SUM](aggregate-sum.md) | Adds up the values of a specific column | +| [WINDOW_FUNNEL](aggregate-windowfunnel.md) | Analyzes user behavior in a time-ordered sequence of events | diff --git a/src/query/functions/src/aggregates/aggregate_quantile_tdigest.rs b/src/query/functions/src/aggregates/aggregate_quantile_tdigest.rs index b767315ca10c0..2e531d3f25f7f 100644 --- a/src/query/functions/src/aggregates/aggregate_quantile_tdigest.rs +++ b/src/query/functions/src/aggregates/aggregate_quantile_tdigest.rs @@ -47,11 +47,11 @@ use crate::aggregates::AggregateFunctionRef; use crate::aggregates::StateAddr; use crate::BUILTIN_FUNCTIONS; -const MEDIAN: u8 = 0; -const QUANTILE: u8 = 1; +pub(crate) const MEDIAN: u8 = 0; +pub(crate) const QUANTILE: u8 = 1; #[derive(Serialize, Deserialize)] -struct QuantileTDigestState { +pub(crate) struct QuantileTDigestState { epsilon: u32, max_centroids: usize, @@ -67,7 +67,7 @@ struct QuantileTDigestState { } impl QuantileTDigestState { - fn new() -> Self { + pub(crate) fn new() -> Self { Self { epsilon: 100u32, max_centroids: 2048, @@ -82,17 +82,17 @@ impl QuantileTDigestState { } } - fn add(&mut self, other: f64) { + pub(crate) fn add(&mut self, other: f64, weight: Option) { if self.unmerged_weights.len() + self.weights.len() >= self.max_centroids - 1 { self.compress(); } - self.unmerged_weights.push(1f64); + self.unmerged_weights.push(weight.unwrap_or(1) as f64); self.unmerged_means.push(other); self.unmerged_total_weight += 1f64; } - fn merge(&mut self, rhs: &mut Self) -> Result<()> { + pub(crate) fn merge(&mut self, rhs: &mut Self) -> Result<()> { if rhs.len() == 0 { return Ok(()); } @@ -107,7 +107,11 @@ impl QuantileTDigestState { Ok(()) } - fn merge_result(&mut self, builder: &mut ColumnBuilder, levels: Vec) -> Result<()> { + pub(crate) fn merge_result( + &mut self, + builder: &mut ColumnBuilder, + levels: Vec, + ) -> Result<()> { if levels.len() > 1 { let builder = match builder { ColumnBuilder::Array(box b) => b, @@ -126,7 +130,7 @@ impl QuantileTDigestState { Ok(()) } - fn quantile(&mut self, level: f64) -> f64 { + pub(crate) fn quantile(&mut self, level: f64) -> f64 { self.compress(); if self.weights.is_empty() { return 0f64; @@ -317,13 +321,13 @@ where T: Number + AsPrimitive Some(bitmap) => { for (value, is_valid) in column.iter().zip(bitmap.iter()) { if is_valid { - state.add(value.as_()); + state.add(value.as_(), None); } } } None => { for value in column.iter() { - state.add(value.as_()); + state.add(value.as_(), None); } } } @@ -335,7 +339,7 @@ where T: Number + AsPrimitive let v = NumberType::::index_column(&column, row); if let Some(v) = v { let state = place.get::(); - state.add(v.as_()) + state.add(v.as_(), None) } Ok(()) } @@ -350,8 +354,7 @@ where T: Number + AsPrimitive column.iter().zip(places.iter()).for_each(|(v, place)| { let addr = place.next(offset); let state = addr.get::(); - let v = v.as_(); - state.add(v) + state.add(v.as_(), None) }); Ok(()) } @@ -489,7 +492,7 @@ pub fn try_create_aggregate_quantile_tdigest_function( } _ => Err(ErrorCode::BadDataValueType(format!( - "{} does not support type '{:?}'", + "{} just support numeric type, but got '{:?}'", display_name, arguments[0] ))), }) diff --git a/src/query/functions/src/aggregates/aggregate_quantile_tdigest_weighted.rs b/src/query/functions/src/aggregates/aggregate_quantile_tdigest_weighted.rs new file mode 100644 index 0000000000000..a946d1807ded7 --- /dev/null +++ b/src/query/functions/src/aggregates/aggregate_quantile_tdigest_weighted.rs @@ -0,0 +1,306 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::alloc::Layout; +use std::fmt::Display; +use std::fmt::Formatter; +use std::marker::PhantomData; +use std::sync::Arc; + +use common_arrow::arrow::bitmap::Bitmap; +use common_exception::ErrorCode; +use common_exception::Result; +use common_expression::type_check::check_number; +use common_expression::types::number::*; +use common_expression::types::*; +use common_expression::with_number_mapped_type; +use common_expression::with_unsigned_number_mapped_type; +use common_expression::Column; +use common_expression::ColumnBuilder; +use common_expression::Expr; +use common_expression::FunctionContext; +use common_expression::Scalar; +use num_traits::AsPrimitive; + +use super::deserialize_state; +use super::serialize_state; +use crate::aggregates::aggregate_function_factory::AggregateFunctionDescription; +use crate::aggregates::aggregate_quantile_tdigest::QuantileTDigestState; +use crate::aggregates::aggregate_quantile_tdigest::MEDIAN; +use crate::aggregates::aggregate_quantile_tdigest::QUANTILE; +use crate::aggregates::assert_binary_arguments; +use crate::aggregates::assert_params; +use crate::aggregates::AggregateFunction; +use crate::aggregates::AggregateFunctionRef; +use crate::aggregates::StateAddr; +use crate::BUILTIN_FUNCTIONS; + +#[derive(Clone)] +pub struct AggregateQuantileTDigestWeightedFunction { + display_name: String, + return_type: DataType, + levels: Vec, + _arguments: Vec, + _t0: PhantomData, + _t1: PhantomData, +} + +impl Display for AggregateQuantileTDigestWeightedFunction +where + T0: Number + AsPrimitive, + T1: Number + AsPrimitive, +{ + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.display_name) + } +} + +impl AggregateFunction for AggregateQuantileTDigestWeightedFunction +where + T0: Number + AsPrimitive, + T1: Number + AsPrimitive, +{ + fn name(&self) -> &str { + "AggregateQuantileDiscFunction" + } + fn return_type(&self) -> Result { + Ok(self.return_type.clone()) + } + fn init_state(&self, place: StateAddr) { + place.write(QuantileTDigestState::new) + } + fn state_layout(&self) -> Layout { + Layout::new::() + } + fn accumulate( + &self, + place: StateAddr, + columns: &[Column], + validity: Option<&Bitmap>, + _input_rows: usize, + ) -> Result<()> { + let column = NumberType::::try_downcast_column(&columns[0]).unwrap(); + let weighted = NumberType::::try_downcast_column(&columns[1]).unwrap(); + let state = place.get::(); + match validity { + Some(bitmap) => { + for ((value, weight), is_valid) in + column.iter().zip(weighted.iter()).zip(bitmap.iter()) + { + if is_valid { + state.add(value.as_(), Some(weight.as_())); + } + } + } + None => { + for (value, weight) in column.iter().zip(weighted.iter()) { + state.add(value.as_(), Some(weight.as_())); + } + } + } + + Ok(()) + } + fn accumulate_row(&self, place: StateAddr, columns: &[Column], row: usize) -> Result<()> { + let column = NumberType::::try_downcast_column(&columns[0]).unwrap(); + let weighted = NumberType::::try_downcast_column(&columns[1]).unwrap(); + let value = unsafe { column.get_unchecked(row) }; + let weight = unsafe { weighted.get_unchecked(row) }; + + let state = place.get::(); + state.add(value.as_(), Some(weight.as_())); + Ok(()) + } + fn accumulate_keys( + &self, + places: &[StateAddr], + offset: usize, + columns: &[Column], + _input_rows: usize, + ) -> Result<()> { + let column = NumberType::::try_downcast_column(&columns[0]).unwrap(); + let weighted = NumberType::::try_downcast_column(&columns[1]).unwrap(); + column + .iter() + .zip(weighted.iter()) + .zip(places.iter()) + .for_each(|((value, weight), place)| { + let addr = place.next(offset); + let state = addr.get::(); + state.add(value.as_(), Some(weight.as_())) + }); + Ok(()) + } + fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { + let state = place.get::(); + serialize_state(writer, state) + } + + fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + let state = place.get::(); + let mut rhs: QuantileTDigestState = deserialize_state(reader)?; + state.merge(&mut rhs) + } + + fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { + let state = place.get::(); + let other = rhs.get::(); + state.merge(other) + } + + fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { + let state = place.get::(); + state.merge_result(builder, self.levels.clone()) + } + + fn need_manual_drop_state(&self) -> bool { + true + } + + unsafe fn drop_state(&self, place: StateAddr) { + let state = place.get::(); + std::ptr::drop_in_place(state); + } +} + +impl AggregateQuantileTDigestWeightedFunction +where + T0: Number + AsPrimitive, + T1: Number + AsPrimitive, +{ + fn try_create( + display_name: &str, + return_type: DataType, + params: Vec, + arguments: Vec, + ) -> Result> { + let levels = if params.len() == 1 { + let level: F64 = check_number( + None, + &FunctionContext::default(), + &Expr::::Cast { + span: None, + is_try: false, + expr: Box::new(Expr::Constant { + span: None, + scalar: params[0].clone(), + data_type: params[0].as_ref().infer_data_type(), + }), + dest_type: DataType::Number(NumberDataType::Float64), + }, + &BUILTIN_FUNCTIONS, + )?; + let level = level.0; + if !(0.0..=1.0).contains(&level) { + return Err(ErrorCode::BadDataValueType(format!( + "level range between [0, 1], got: {:?}", + level + ))); + } + vec![level] + } else if params.is_empty() { + vec![0.5f64] + } else { + let mut levels = Vec::with_capacity(params.len()); + for param in params { + let level: F64 = check_number( + None, + &FunctionContext::default(), + &Expr::::Cast { + span: None, + is_try: false, + expr: Box::new(Expr::Constant { + span: None, + scalar: param.clone(), + data_type: param.as_ref().infer_data_type(), + }), + dest_type: DataType::Number(NumberDataType::Float64), + }, + &BUILTIN_FUNCTIONS, + )?; + let level = level.0; + if !(0.0..=1.0).contains(&level) { + return Err(ErrorCode::BadDataValueType(format!( + "level range between [0, 1], got: {:?} in levels", + level + ))); + } + levels.push(level); + } + levels + }; + let func = AggregateQuantileTDigestWeightedFunction:: { + display_name: display_name.to_string(), + return_type, + levels, + _arguments: arguments, + _t0: PhantomData, + _t1: PhantomData, + }; + Ok(Arc::new(func)) + } +} + +pub fn try_create_aggregate_quantile_tdigest_weighted_function( + display_name: &str, + params: Vec, + arguments: Vec, +) -> Result { + if TYPE == MEDIAN { + assert_params(display_name, params.len(), 0)?; + } + + assert_binary_arguments(display_name, arguments.len())?; + with_number_mapped_type!(|NUM_TYPE_0| match &arguments[0] { + DataType::Number(NumberDataType::NUM_TYPE_0) => { + let return_type = if params.len() > 1 { + DataType::Array(Box::new(DataType::Number(NumberDataType::Float64))) + } else { + DataType::Number(NumberDataType::Float64) + }; + + with_unsigned_number_mapped_type!(|NUM_TYPE_1| match &arguments[1] { + DataType::Number(NumberDataType::NUM_TYPE_1) => { + AggregateQuantileTDigestWeightedFunction::::try_create( + display_name, + return_type, + params, + arguments, + ) + } + _ => Err(ErrorCode::BadDataValueType(format!( + "weight just support unsigned integer type, but got '{:?}'", + arguments[1] + ))), + }) + } + + _ => Err(ErrorCode::BadDataValueType(format!( + "{} just support numeric type, but got '{:?}'", + display_name, arguments[0] + ))), + }) +} + +pub fn aggregate_quantile_tdigest_weighted_function_desc() -> AggregateFunctionDescription { + AggregateFunctionDescription::creator(Box::new( + try_create_aggregate_quantile_tdigest_weighted_function::, + )) +} + +pub fn aggregate_median_tdigest_weighted_function_desc() -> AggregateFunctionDescription { + AggregateFunctionDescription::creator(Box::new( + try_create_aggregate_quantile_tdigest_weighted_function::, + )) +} diff --git a/src/query/functions/src/aggregates/aggregator.rs b/src/query/functions/src/aggregates/aggregator.rs index a9e453afa8252..06b0d3ce10095 100644 --- a/src/query/functions/src/aggregates/aggregator.rs +++ b/src/query/functions/src/aggregates/aggregator.rs @@ -46,6 +46,8 @@ use crate::aggregates::aggregate_quantile_cont::aggregate_quantile_cont_function use crate::aggregates::aggregate_quantile_disc::aggregate_quantile_disc_function_desc; use crate::aggregates::aggregate_quantile_tdigest::aggregate_median_tdigest_function_desc; use crate::aggregates::aggregate_quantile_tdigest::aggregate_quantile_tdigest_function_desc; +use crate::aggregates::aggregate_quantile_tdigest_weighted::aggregate_median_tdigest_weighted_function_desc; +use crate::aggregates::aggregate_quantile_tdigest_weighted::aggregate_quantile_tdigest_weighted_function_desc; use crate::aggregates::aggregate_retention::aggregate_retention_function_desc; use crate::aggregates::aggregate_skewness::aggregate_skewness_function_desc; use crate::aggregates::aggregate_string_agg::aggregate_string_agg_function_desc; @@ -80,8 +82,16 @@ impl Aggregators { "quantile_tdigest", aggregate_quantile_tdigest_function_desc(), ); + factory.register( + "quantile_tdigest_weighted", + aggregate_quantile_tdigest_weighted_function_desc(), + ); factory.register("median", aggregate_median_function_desc()); factory.register("median_tdigest", aggregate_median_tdigest_function_desc()); + factory.register( + "median_tdigest_weighted", + aggregate_median_tdigest_weighted_function_desc(), + ); factory.register("window_funnel", aggregate_window_funnel_function_desc()); factory.register( "approx_count_distinct", diff --git a/src/query/functions/src/aggregates/mod.rs b/src/query/functions/src/aggregates/mod.rs index 96c46c392d0f0..2e405415b4c54 100644 --- a/src/query/functions/src/aggregates/mod.rs +++ b/src/query/functions/src/aggregates/mod.rs @@ -36,6 +36,7 @@ mod aggregate_null_result; mod aggregate_quantile_cont; mod aggregate_quantile_disc; mod aggregate_quantile_tdigest; +mod aggregate_quantile_tdigest_weighted; mod aggregate_retention; mod aggregate_scalar_state; mod aggregate_skewness; diff --git a/src/query/functions/tests/it/aggregates/agg.rs b/src/query/functions/tests/it/aggregates/agg.rs index 5d82a090d4377..5a8443cd0c8d7 100644 --- a/src/query/functions/tests/it/aggregates/agg.rs +++ b/src/query/functions/tests/it/aggregates/agg.rs @@ -59,6 +59,7 @@ fn test_agg() { test_agg_quantile_disc(file, eval_aggr); test_agg_quantile_cont(file, eval_aggr); test_agg_quantile_tdigest(file, eval_aggr); + test_agg_quantile_tdigest_weighted(file, eval_aggr); test_agg_median(file, eval_aggr); test_agg_median_tdigest(file, eval_aggr); test_agg_array_agg(file, eval_aggr); @@ -625,6 +626,21 @@ fn test_agg_quantile_tdigest(file: &mut impl Write, simulator: impl AggregationS ); } +fn test_agg_quantile_tdigest_weighted(file: &mut impl Write, simulator: impl AggregationSimulator) { + run_agg_ast( + file, + "quantile_tdigest_weighted(0.8)(a, b)", + get_example().as_slice(), + simulator, + ); + run_agg_ast( + file, + "quantile_tdigest_weighted(0.8)(x_null, b)", + get_example().as_slice(), + simulator, + ); +} + fn test_agg_group_array_moving_avg(file: &mut impl Write, simulator: impl AggregationSimulator) { run_agg_ast( file, diff --git a/src/query/functions/tests/it/aggregates/testdata/agg.txt b/src/query/functions/tests/it/aggregates/testdata/agg.txt index 46f098ea951ac..93816e4d7cbe7 100644 --- a/src/query/functions/tests/it/aggregates/testdata/agg.txt +++ b/src/query/functions/tests/it/aggregates/testdata/agg.txt @@ -847,6 +847,28 @@ evaluation (internal): +--------+-------------------------------------------------------------------------+ +ast: quantile_tdigest_weighted(0.8)(a, b) +evaluation (internal): ++--------+-----------------------------------------------------------------+ +| Column | Data | ++--------+-----------------------------------------------------------------+ +| a | Int64([4, 3, 2, 1]) | +| b | UInt64([1, 2, 3, 4]) | +| Output | NullableColumn { column: Float64([0]), validity: [0b_______1] } | ++--------+-----------------------------------------------------------------+ + + +ast: quantile_tdigest_weighted(0.8)(x_null, b) +evaluation (internal): ++--------+-------------------------------------------------------------------------+ +| Column | Data | ++--------+-------------------------------------------------------------------------+ +| b | UInt64([1, 2, 3, 4]) | +| x_null | NullableColumn { column: UInt64([1, 2, 3, 4]), validity: [0b____0011] } | +| Output | NullableColumn { column: Float64([0]), validity: [0b_______1] } | ++--------+-------------------------------------------------------------------------+ + + ast: median(a) evaluation (internal): +--------+-------------------------------------------------------------------+ diff --git a/tests/sqllogictests/suites/query/02_function/02_0000_function_aggregate_mix b/tests/sqllogictests/suites/query/02_function/02_0000_function_aggregate_mix index 62d74fdf39dd6..379cf5086fe4e 100644 --- a/tests/sqllogictests/suites/query/02_function/02_0000_function_aggregate_mix +++ b/tests/sqllogictests/suites/query/02_function/02_0000_function_aggregate_mix @@ -250,6 +250,16 @@ SELECT quantile_tdigest(0, 0.5, 0.6, 1)(number) from numbers_mt(10000) ---- [0.0,4999.5,5999.5,9999.0] +query F +SELECT quantile_tdigest_weighted(0.6)(number, 1) from numbers_mt(10000) +---- +5999.5 + +query T +SELECT quantile_tdigest_weighted(0, 0.5, 0.6, 1)(number, 1) from numbers_mt(10000) +---- +[0.0,4999.5,5999.5,9999.0] + query T SELECT list(number) from numbers_mt(10) ----