diff --git a/datafusion/functions-aggregate/src/percentile_cont.rs b/datafusion/functions-aggregate/src/percentile_cont.rs index d6c8eabb459e6..a4e8332626b00 100644 --- a/datafusion/functions-aggregate/src/percentile_cont.rs +++ b/datafusion/functions-aggregate/src/percentile_cont.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use std::fmt::{Debug, Formatter}; +use std::fmt::Debug; use std::mem::{size_of, size_of_val}; use std::sync::Arc; @@ -26,27 +26,27 @@ use arrow::buffer::{OffsetBuffer, ScalarBuffer}; use arrow::{ array::{Array, ArrayRef, AsArray}, datatypes::{ - ArrowNativeType, DataType, Decimal32Type, Decimal64Type, Decimal128Type, - Decimal256Type, Field, FieldRef, Float16Type, Float32Type, Float64Type, + ArrowNativeType, DataType, Field, FieldRef, Float16Type, Float32Type, Float64Type, }, }; use arrow::array::ArrowNativeTypeOp; +use datafusion_common::internal_err; +use datafusion_common::types::{NativeType, logical_float64}; +use datafusion_functions_aggregate_common::noop_accumulator::NoopAccumulator; use crate::min_max::{max_udaf, min_udaf}; use datafusion_common::{ - DataFusionError, Result, ScalarValue, assert_eq_or_internal_err, - internal_datafusion_err, plan_err, utils::take_function_args, + Result, ScalarValue, internal_datafusion_err, utils::take_function_args, }; -use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature, - Volatility, + Accumulator, AggregateUDFImpl, Coercion, Documentation, Expr, Signature, + TypeSignatureClass, Volatility, }; use datafusion_expr::{EmitTo, GroupsAccumulator}; use datafusion_expr::{ - expr::{AggregateFunction, Cast, Sort}, + expr::{AggregateFunction, Sort}, function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs}, simplify::SimplifyInfo, }; @@ -121,21 +121,12 @@ An alternate syntax is also supported: /// If using the distinct variation, the memory usage will be similarly high if the /// cardinality is high as it stores all distinct values in memory before computing the /// result, but if cardinality is low then memory usage will also be lower. -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct PercentileCont { signature: Signature, aliases: Vec, } -impl Debug for PercentileCont { - fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { - f.debug_struct("PercentileCont") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for PercentileCont { fn default() -> Self { Self::new() @@ -144,76 +135,27 @@ impl Default for PercentileCont { impl PercentileCont { pub fn new() -> Self { - let mut variants = Vec::with_capacity(NUMERICS.len()); - // Accept any numeric value paired with a float64 percentile - for num in NUMERICS { - variants.push(TypeSignature::Exact(vec![num.clone(), DataType::Float64])); - } Self { - signature: Signature::one_of(variants, Volatility::Immutable) - .with_parameter_names(vec!["expr".to_string(), "percentile".to_string()]) - .expect("valid parameter names for percentile_cont"), + signature: Signature::coercible( + vec![ + Coercion::new_implicit( + TypeSignatureClass::Float, + vec![TypeSignatureClass::Numeric], + NativeType::Float64, + ), + Coercion::new_implicit( + TypeSignatureClass::Native(logical_float64()), + vec![TypeSignatureClass::Numeric], + NativeType::Float64, + ), + ], + Volatility::Immutable, + ) + .with_parameter_names(vec!["expr", "percentile"]) + .unwrap(), aliases: vec![String::from("quantile_cont")], } } - - fn create_accumulator(&self, args: &AccumulatorArgs) -> Result> { - let percentile = validate_percentile_expr(&args.exprs[1], "PERCENTILE_CONT")?; - - let is_descending = args - .order_bys - .first() - .map(|sort_expr| sort_expr.options.descending) - .unwrap_or(false); - - let percentile = if is_descending { - 1.0 - percentile - } else { - percentile - }; - - macro_rules! helper { - ($t:ty, $dt:expr) => { - if args.is_distinct { - Ok(Box::new(DistinctPercentileContAccumulator::<$t> { - data_type: $dt.clone(), - distinct_values: GenericDistinctBuffer::new($dt), - percentile, - })) - } else { - Ok(Box::new(PercentileContAccumulator::<$t> { - data_type: $dt.clone(), - all_values: vec![], - percentile, - })) - } - }; - } - - let input_dt = args.exprs[0].data_type(args.schema)?; - match input_dt { - // For integer types, use Float64 internally since percentile_cont returns Float64 - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 => helper!(Float64Type, DataType::Float64), - DataType::Float16 => helper!(Float16Type, input_dt), - DataType::Float32 => helper!(Float32Type, input_dt), - DataType::Float64 => helper!(Float64Type, input_dt), - DataType::Decimal32(_, _) => helper!(Decimal32Type, input_dt), - DataType::Decimal64(_, _) => helper!(Decimal64Type, input_dt), - DataType::Decimal128(_, _) => helper!(Decimal128Type, input_dt), - DataType::Decimal256(_, _) => helper!(Decimal256Type, input_dt), - _ => Err(DataFusionError::NotImplemented(format!( - "PercentileContAccumulator not supported for {} with {}", - args.name, input_dt, - ))), - } - } } impl AggregateUDFImpl for PercentileCont { @@ -234,53 +176,26 @@ impl AggregateUDFImpl for PercentileCont { } fn return_type(&self, arg_types: &[DataType]) -> Result { - if !arg_types[0].is_numeric() { - return plan_err!("percentile_cont requires numeric input types"); - } - // PERCENTILE_CONT performs linear interpolation and should return a float type - // For integer inputs, return Float64 (matching PostgreSQL/DuckDB behavior) - // For float inputs, preserve the float type match &arg_types[0] { - DataType::Float16 | DataType::Float32 | DataType::Float64 => { - Ok(arg_types[0].clone()) - } - DataType::Decimal32(_, _) - | DataType::Decimal64(_, _) - | DataType::Decimal128(_, _) - | DataType::Decimal256(_, _) => Ok(arg_types[0].clone()), - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 => Ok(DataType::Float64), - // Shouldn't happen due to signature check, but just in case - dt => plan_err!( - "percentile_cont does not support input type {}, must be numeric", - dt - ), + DataType::Null => Ok(DataType::Float64), + dt => Ok(dt.clone()), } } fn state_fields(&self, args: StateFieldsArgs) -> Result> { - //Intermediate state is a list of the elements we have collected so far let input_type = args.input_fields[0].data_type().clone(); - // For integer types, we store as Float64 internally - let storage_type = match &input_type { - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 => DataType::Float64, - _ => input_type, - }; + if input_type.is_null() { + return Ok(vec![ + Field::new( + format_state_name(args.name, self.name()), + DataType::Null, + true, + ) + .into(), + ]); + } - let field = Field::new_list_field(storage_type, true); + let field = Field::new_list_field(input_type, true); let state_name = if args.is_distinct { "distinct_percentile_cont" } else { @@ -297,70 +212,65 @@ impl AggregateUDFImpl for PercentileCont { ]) } - fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - self.create_accumulator(&acc_args) + fn accumulator(&self, args: AccumulatorArgs) -> Result> { + let percentile = get_percentile(&args)?; + + let input_dt = args.expr_fields[0].data_type(); + if input_dt.is_null() { + return Ok(Box::new(NoopAccumulator::new(ScalarValue::Float64(None)))); + } + + if args.is_distinct { + match input_dt { + DataType::Float16 => Ok(Box::new(DistinctPercentileContAccumulator::< + Float16Type, + >::new(percentile))), + DataType::Float32 => Ok(Box::new(DistinctPercentileContAccumulator::< + Float32Type, + >::new(percentile))), + DataType::Float64 => Ok(Box::new(DistinctPercentileContAccumulator::< + Float64Type, + >::new(percentile))), + dt => internal_err!("Unsupported datatype for percentile cont: {dt}"), + } + } else { + match input_dt { + DataType::Float16 => Ok(Box::new( + PercentileContAccumulator::::new(percentile), + )), + DataType::Float32 => Ok(Box::new( + PercentileContAccumulator::::new(percentile), + )), + DataType::Float64 => Ok(Box::new( + PercentileContAccumulator::::new(percentile), + )), + dt => internal_err!("Unsupported datatype for percentile cont: {dt}"), + } + } } fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { - !args.is_distinct + !args.is_distinct && !args.expr_fields[0].data_type().is_null() } fn create_groups_accumulator( &self, args: AccumulatorArgs, ) -> Result> { - let num_args = args.exprs.len(); - assert_eq_or_internal_err!( - num_args, - 2, - "percentile_cont should have 2 args, but found num args:{}", - num_args - ); - - let percentile = validate_percentile_expr(&args.exprs[1], "PERCENTILE_CONT")?; + let percentile = get_percentile(&args)?; - let is_descending = args - .order_bys - .first() - .map(|sort_expr| sort_expr.options.descending) - .unwrap_or(false); - - let percentile = if is_descending { - 1.0 - percentile - } else { - percentile - }; - - macro_rules! helper { - ($t:ty, $dt:expr) => { - Ok(Box::new(PercentileContGroupsAccumulator::<$t>::new( - $dt, percentile, - ))) - }; - } - - let input_dt = args.exprs[0].data_type(args.schema)?; + let input_dt = args.expr_fields[0].data_type(); match input_dt { - // For integer types, use Float64 internally since percentile_cont returns Float64 - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 => helper!(Float64Type, DataType::Float64), - DataType::Float16 => helper!(Float16Type, input_dt), - DataType::Float32 => helper!(Float32Type, input_dt), - DataType::Float64 => helper!(Float64Type, input_dt), - DataType::Decimal32(_, _) => helper!(Decimal32Type, input_dt), - DataType::Decimal64(_, _) => helper!(Decimal64Type, input_dt), - DataType::Decimal128(_, _) => helper!(Decimal128Type, input_dt), - DataType::Decimal256(_, _) => helper!(Decimal256Type, input_dt), - _ => Err(DataFusionError::NotImplemented(format!( - "PercentileContGroupsAccumulator not supported for {} with {}", - args.name, input_dt, - ))), + DataType::Float16 => Ok(Box::new(PercentileContGroupsAccumulator::< + Float16Type, + >::new(percentile))), + DataType::Float32 => Ok(Box::new(PercentileContGroupsAccumulator::< + Float32Type, + >::new(percentile))), + DataType::Float64 => Ok(Box::new(PercentileContGroupsAccumulator::< + Float64Type, + >::new(percentile))), + dt => internal_err!("Unsupported datatype for percentile cont: {dt}"), } } @@ -379,21 +289,42 @@ impl AggregateUDFImpl for PercentileCont { } } -#[derive(Clone, Copy)] -enum PercentileRewriteTarget { - Min, - Max, +fn get_percentile(args: &AccumulatorArgs) -> Result { + let percentile = validate_percentile_expr(&args.exprs[1], "PERCENTILE_CONT")?; + + let is_descending = args + .order_bys + .first() + .map(|sort_expr| sort_expr.options.descending) + .unwrap_or(false); + + let percentile = if is_descending { + 1.0 - percentile + } else { + percentile + }; + + Ok(percentile) } -#[expect(clippy::needless_pass_by_value)] fn simplify_percentile_cont_aggregate( aggregate_function: AggregateFunction, info: &dyn SimplifyInfo, ) -> Result { - let original_expr = Expr::AggregateFunction(aggregate_function.clone()); - let params = &aggregate_function.params; + enum PercentileRewriteTarget { + Min, + Max, + } + let params = &aggregate_function.params; let [value, percentile] = take_function_args("percentile_cont", ¶ms.args)?; + // + // For simplicity we don't bother with null types (otherwise we'd need to + // cast the return type) + let input_type = info.get_data_type(value)?; + if input_type.is_null() { + return Ok(Expr::AggregateFunction(aggregate_function)); + } let is_descending = params .order_by @@ -401,43 +332,24 @@ fn simplify_percentile_cont_aggregate( .map(|sort| !sort.asc) .unwrap_or(false); - let rewrite_target = match extract_percentile_literal(percentile) { - Some(0.0) => { + let rewrite_target = match percentile { + Expr::Literal(ScalarValue::Float64(Some(0.0)), _) => { if is_descending { PercentileRewriteTarget::Max } else { PercentileRewriteTarget::Min } } - Some(1.0) => { + Expr::Literal(ScalarValue::Float64(Some(1.0)), _) => { if is_descending { PercentileRewriteTarget::Min } else { PercentileRewriteTarget::Max } } - _ => return Ok(original_expr), + _ => return Ok(Expr::AggregateFunction(aggregate_function)), }; - let input_type = match info.get_data_type(value) { - Ok(data_type) => data_type, - Err(_) => return Ok(original_expr), - }; - - let expected_return_type = - match percentile_cont_udaf().return_type(std::slice::from_ref(&input_type)) { - Ok(data_type) => data_type, - Err(_) => return Ok(original_expr), - }; - - let mut agg_arg = value.clone(); - if expected_return_type != input_type { - // min/max return the same type as their input. percentile_cont widens - // integers to Float64 (and preserves float/decimal types), so ensure the - // rewritten aggregate sees an input of the final return type. - agg_arg = Expr::Cast(Cast::new(Box::new(agg_arg), expected_return_type.clone())); - } - let udaf = match rewrite_target { PercentileRewriteTarget::Min => min_udaf(), PercentileRewriteTarget::Max => max_udaf(), @@ -445,7 +357,7 @@ fn simplify_percentile_cont_aggregate( let rewritten = Expr::AggregateFunction(AggregateFunction::new_udf( udaf, - vec![agg_arg], + vec![value.clone()], params.distinct, params.filter.clone(), vec![], @@ -454,13 +366,6 @@ fn simplify_percentile_cont_aggregate( Ok(rewritten) } -fn extract_percentile_literal(expr: &Expr) -> Option { - match expr { - Expr::Literal(ScalarValue::Float64(Some(value)), _) => Some(*value), - _ => None, - } -} - /// The percentile_cont accumulator accumulates the raw input values /// as native types. /// @@ -468,23 +373,22 @@ fn extract_percentile_literal(expr: &Expr) -> Option { /// `merge_batch` and a `Vec` of native values that are converted to scalar values /// in the final evaluation step so that we avoid expensive conversions and /// allocations during `update_batch`. -struct PercentileContAccumulator { - data_type: DataType, +#[derive(Debug)] +struct PercentileContAccumulator { all_values: Vec, percentile: f64, } -impl Debug for PercentileContAccumulator { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!( - f, - "PercentileContAccumulator({}, percentile={})", - self.data_type, self.percentile - ) +impl PercentileContAccumulator { + fn new(percentile: f64) -> Self { + Self { + all_values: vec![], + percentile, + } } } -impl Accumulator for PercentileContAccumulator { +impl Accumulator for PercentileContAccumulator { fn state(&mut self) -> Result> { // Convert `all_values` to `ListArray` and return a single List ScalarValue @@ -496,12 +400,11 @@ impl Accumulator for PercentileContAccumulator { let values_array = PrimitiveArray::::new( ScalarBuffer::from(std::mem::take(&mut self.all_values)), None, - ) - .with_data_type(self.data_type.clone()); + ); // Build the result list array let list_array = ListArray::new( - Arc::new(Field::new_list_field(self.data_type.clone(), true)), + Arc::new(Field::new_list_field(T::DATA_TYPE, true)), offsets, Arc::new(values_array), None, @@ -511,14 +414,7 @@ impl Accumulator for PercentileContAccumulator { } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - // Cast to target type if needed (e.g., integer to Float64) - let values = if values[0].data_type() != &self.data_type { - arrow::compute::cast(&values[0], &self.data_type)? - } else { - Arc::clone(&values[0]) - }; - - let values = values.as_primitive::(); + let values = values[0].as_primitive::(); self.all_values.reserve(values.len() - values.null_count()); self.all_values.extend(values.iter().flatten()); Ok(()) @@ -526,16 +422,14 @@ impl Accumulator for PercentileContAccumulator { fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { let array = states[0].as_list::(); - for v in array.iter().flatten() { - self.update_batch(&[v])? - } + self.update_batch(&[array.value(0)])?; Ok(()) } fn evaluate(&mut self) -> Result { let d = std::mem::take(&mut self.all_values); let value = calculate_percentile::(d, self.percentile); - ScalarValue::new_primitive::(value, &self.data_type) + ScalarValue::new_primitive::(value, &T::DATA_TYPE) } fn size(&self) -> usize { @@ -551,16 +445,14 @@ impl Accumulator for PercentileContAccumulator { /// will be actually organized as a `Vec>`. #[derive(Debug)] struct PercentileContGroupsAccumulator { - data_type: DataType, group_values: Vec>, percentile: f64, } impl PercentileContGroupsAccumulator { - pub fn new(data_type: DataType, percentile: f64) -> Self { + fn new(percentile: f64) -> Self { Self { - data_type, - group_values: Vec::new(), + group_values: vec![], percentile, } } @@ -579,14 +471,7 @@ impl GroupsAccumulator // For ordered-set aggregates, we only care about the ORDER BY column (first element) // The percentile parameter is already stored in self.percentile - // Cast to target type if needed (e.g., integer to Float64) - let values_array = if values[0].data_type() != &self.data_type { - arrow::compute::cast(&values[0], &self.data_type)? - } else { - Arc::clone(&values[0]) - }; - - let values = values_array.as_primitive::(); + let values = values[0].as_primitive::(); // Push the `not nulls + not filtered` row into its group self.group_values.resize(total_num_groups, Vec::new()); @@ -649,12 +534,11 @@ impl GroupsAccumulator let flatten_group_values = emit_group_values.into_iter().flatten().collect::>(); let group_values_array = - PrimitiveArray::::new(ScalarBuffer::from(flatten_group_values), None) - .with_data_type(self.data_type.clone()); + PrimitiveArray::::new(ScalarBuffer::from(flatten_group_values), None); // Build the result list array let result_list_array = ListArray::new( - Arc::new(Field::new_list_field(self.data_type.clone(), true)), + Arc::new(Field::new_list_field(T::DATA_TYPE, true)), offsets, Arc::new(group_values_array), None, @@ -669,7 +553,7 @@ impl GroupsAccumulator // Calculate percentile for each group let mut evaluate_result_builder = - PrimitiveBuilder::::new().with_data_type(self.data_type.clone()); + PrimitiveBuilder::::with_capacity(emit_group_values.len()); for values in emit_group_values { let value = calculate_percentile::(values, self.percentile); evaluate_result_builder.append_option(value); @@ -685,14 +569,7 @@ impl GroupsAccumulator ) -> Result> { assert_eq!(values.len(), 1, "one argument to merge_batch"); - // Cast to target type if needed (e.g., integer to Float64) - let values_array = if values[0].data_type() != &self.data_type { - arrow::compute::cast(&values[0], &self.data_type)? - } else { - Arc::clone(&values[0]) - }; - - let input_array = values_array.as_primitive::(); + let input_array = values[0].as_primitive::(); // Directly convert the input array to states, each row will be // seen as a respective group. @@ -702,8 +579,7 @@ impl GroupsAccumulator // to null. // Reuse values buffer in `input_array` to build `values` in `ListArray` - let values = PrimitiveArray::::new(input_array.values().clone(), None) - .with_data_type(self.data_type.clone()); + let values = PrimitiveArray::::new(input_array.values().clone(), None); // `offsets` in `ListArray`, each row as a list element let offset_end = i32::try_from(input_array.len()).map_err(|e| { @@ -724,7 +600,7 @@ impl GroupsAccumulator let nulls = filtered_null_mask(opt_filter, input_array); let converted_list_array = ListArray::new( - Arc::new(Field::new_list_field(self.data_type.clone(), true)), + Arc::new(Field::new_list_field(T::DATA_TYPE, true)), offsets, Arc::new(values), nulls, @@ -750,10 +626,18 @@ impl GroupsAccumulator #[derive(Debug)] struct DistinctPercentileContAccumulator { distinct_values: GenericDistinctBuffer, - data_type: DataType, percentile: f64, } +impl DistinctPercentileContAccumulator { + fn new(percentile: f64) -> Self { + Self { + distinct_values: GenericDistinctBuffer::new(T::DATA_TYPE), + percentile, + } + } +} + impl Accumulator for DistinctPercentileContAccumulator { fn state(&mut self) -> Result> { self.distinct_values.state() @@ -773,7 +657,7 @@ impl Accumulator for DistinctPercentileContAccumula .map(|v| v.0) .collect::>(); let value = calculate_percentile::(d, self.percentile); - ScalarValue::new_primitive::(value, &self.data_type) + ScalarValue::new_primitive::(value, &T::DATA_TYPE) } fn size(&self) -> usize { diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 2a4daeb92979d..f6ce68917e03b 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -8241,3 +8241,8 @@ NULL NULL NULL NULL statement ok drop table distinct_avg; + +query R +select percentile_cont(null, 0.5); +---- +NULL