Skip to content
Closed
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
1d31dfe
Merge branch 'main' into udaf-schema-16997
kosiew Aug 8, 2025
66ee0c5
Merge branch 'main' into udaf-schema-16997
kosiew Aug 12, 2025
15e89aa
feat: Implement SchemaBasedAggregateUdf and enhance accumulator argum…
kosiew Aug 7, 2025
fea8c1c
refactor: Rename test function for schema-based aggregate UDF metadata
kosiew Aug 12, 2025
c386ba0
docs(aggregate): clarify AccumulatorArgs schema handling and usage
kosiew Aug 12, 2025
c8678e5
refactor: Extract AccumulatorArgs construction into a separate method…
kosiew Aug 12, 2025
25df8c6
test: Add unit tests for AggregateUDF implementation and args_schema …
kosiew Aug 12, 2025
2991acc
refactor: Consolidate use statements for improved readability in aggr…
kosiew Aug 12, 2025
664367b
refactor(tests): Simplify argument passing in AggregateExprBuilder tests
kosiew Aug 12, 2025
20d7a93
docs: Mark code examples as ignored in AggregateUDF and AccumulatorAr…
kosiew Aug 12, 2025
99824bd
Merge branch 'main' into udaf-schema-16997
kosiew Aug 19, 2025
8692423
Enhance DummyUdf struct by deriving PartialEq, Eq, and Hash traits
kosiew Aug 19, 2025
f2a2d51
Enhance SchemaBasedAggregateUdf struct by deriving PartialEq, Eq, and…
kosiew Aug 19, 2025
f4484be
Merge branch 'main' into udaf-schema-16997
kosiew Aug 21, 2025
35014e0
Merge branch 'main' into udaf-schema-16997
kosiew Aug 27, 2025
9e166ac
feat: refactor DummyUdf initialization and enhance args_schema handling
kosiew Aug 27, 2025
0cd5d33
feat(aggregate): refactor accumulator argument building for clarity
kosiew Aug 27, 2025
03579a1
refactor(tests): reorganize and enhance DummyUdf implementation and t…
kosiew Aug 27, 2025
7175121
Merge branch 'main' into udaf-schema-16997
kosiew Aug 27, 2025
4dd8bb2
Revert to last good point
kosiew Aug 27, 2025
fee3fe9
Merge branch 'main' into udaf-schema-16997
kosiew Aug 31, 2025
b5a931b
Merge branch 'main' into udaf-schema-16997
kosiew Sep 6, 2025
f6bb7ec
docs(udaf): improve documentation for AccumulatorArgs usage in Aggreg…
kosiew Oct 1, 2025
f5a36ee
refactor(tests): move tests to bottom
kosiew Oct 1, 2025
07f6fab
Merge branch 'main' into udaf-schema-16997
kosiew Oct 1, 2025
34b2d83
docs(udaf, accumulator): enhance documentation for AccumulatorArgs an…
kosiew Oct 10, 2025
c36b3aa
Clarify AccumulatorArgs documentation
kosiew Oct 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1010,6 +1010,51 @@ impl Accumulator for MetadataBasedAccumulator {
}
}

/// Test aggregate UDF whose accumulator is configured from the metadata of
/// the first field of the schema passed in `AccumulatorArgs`.
#[derive(Debug, PartialEq, Eq, Hash)]
struct SchemaBasedAggregateUdf {
    /// Signature returned from `AggregateUDFImpl::signature`.
    signature: Signature,
}

impl SchemaBasedAggregateUdf {
    /// Creates the UDF with the signature produced by
    /// `Signature::any(1, Volatility::Immutable)`.
    fn new() -> Self {
        let signature = Signature::any(1, Volatility::Immutable);
        Self { signature }
    }
}

impl AggregateUDFImpl for SchemaBasedAggregateUdf {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "schema_based_udf"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Always reports `UInt64`, regardless of the argument types.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::UInt64)
    }

    /// Builds a `MetadataBasedAccumulator` whose `double_output` flag is set
    /// when the first field of `acc_args.schema` carries the metadata entry
    /// `modify_values = double_output`.
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        // Only the metadata is inspected, so borrow the field in place rather
        // than cloning the whole `Field`.
        let double_output = acc_args
            .schema
            .field(0)
            .metadata()
            .get("modify_values")
            .is_some_and(|value| value == "double_output");

        Ok(Box::new(MetadataBasedAccumulator {
            double_output,
            curr_sum: 0,
        }))
    }
}

#[tokio::test]
async fn test_metadata_based_aggregate() -> Result<()> {
let data_array = Arc::new(UInt64Array::from(vec![0, 5, 10, 15, 20])) as ArrayRef;
Expand Down Expand Up @@ -1166,3 +1211,34 @@ async fn test_metadata_based_aggregate_as_window() -> Result<()> {

Ok(())
}

/// Runs `SchemaBasedAggregateUdf` over a single literal argument that carries
/// the `modify_values = double_output` metadata and asserts that the
/// aggregated result is 2.
#[tokio::test]
async fn test_schema_based_aggregate_udf_metadata() -> Result<()> {
    use datafusion_expr::{expr::FieldMetadata, lit_with_metadata};
    use std::collections::BTreeMap;

    // Literal argument tagged with the metadata entry the accumulator looks up.
    let mut entries = BTreeMap::new();
    entries.insert("modify_values".to_string(), "double_output".to_string());
    let tagged_literal = lit_with_metadata(
        ScalarValue::UInt64(Some(1)),
        Some(FieldMetadata::new(entries)),
    );

    let aggregate = AggregateUDF::from(SchemaBasedAggregateUdf::new())
        .call(vec![tagged_literal])
        .alias("res");

    // No group-by expressions: a single aggregate over the builder's empty relation.
    let plan = LogicalPlanBuilder::empty(true)
        .aggregate(Vec::<Expr>::new(), vec![aggregate])?
        .build()?;

    let ctx = SessionContext::new();
    let batches = DataFrame::new(ctx.state(), plan).collect().await?;
    let result = batches[0].column(0).as_primitive::<UInt64Type>();
    assert_eq!(result.value(0), 2);
    Ok(())
}
19 changes: 18 additions & 1 deletion datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,24 @@ pub trait AggregateUDFImpl: Debug + DynEq + DynHash + Send + Sync {
/// group during query execution.
///
/// acc_args: [`AccumulatorArgs`] contains information about how the
/// aggregate function was called.
/// aggregate function was called. Use `acc_args.exprs` together with
/// `acc_args.schema` to inspect the [`FieldRef`] of each input.
///
/// Example: retrieving metadata and return field for input `i`:
/// ```ignore
/// let metadata = acc_args.schema.field(i).metadata();
/// let field = acc_args.exprs[i].return_field(&acc_args.schema)?;
/// ```

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm having some trouble understanding this example; I can understand the part for getting the metadata of a field given the context of the PR, but why do we also include an example for getting the return field?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The snippet is meant to illustrate the sentence immediately above it: you pair acc_args.exprs with acc_args.schema to recover the full FieldRef for argument i.

Pulling the metadata out of schema.field(i) is one common use case, and the follow-up line shows how you would then obtain the complete FieldRef (name, type, metadata) via:

exprs[i].return_field(&acc_args.schema)

...using the same pairing.

I'll tweak the wording.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The snippet is meant to illustrate the sentence immediately above it: you pair acc_args.exprs with acc_args.schema to recover the full FieldRef for argument i.

This may be a silly question, but what's the difference between acc_args.exprs[i].return_field(&acc_args.schema) and acc_args.schema.field(i)?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not at all 😄

  • acc_args.schema.field(i) — returns the raw Arrow Field from the (physical) input schema at position i (name, type, nullability, metadata exactly as in that schema).

  • acc_args.exprs[i].return_field(&acc_args.schema)? — asks the expression for the effective FieldRef for argument i given the full schema. It incorporates expression semantics (casts, literals, computed types, extension metadata, nullability changes, etc.) and returns an owned/Arc FieldRef (and can fail), not just a borrowed &Field.

/// Multi-argument functions: `exprs[i]` corresponds to `schema.field(i)`.
/// Mixed inputs (columns and literals): the physical input schema is used
/// when not empty; `acc_args.schema` is only synthesized from literals when
/// the physical schema is empty.
///
/// When an aggregate is invoked with literal values only,
/// `acc_args.schema` is synthesized from those literals so that any
/// field metadata (for example Arrow extension types) is available to
/// the accumulator implementation.
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>>;

/// Return the fields used to store the intermediate state of this accumulator.
Expand Down
19 changes: 17 additions & 2 deletions datafusion/functions-aggregate-common/src/accumulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,14 @@ pub struct AccumulatorArgs<'a> {
/// The return field of the aggregate function.
pub return_field: FieldRef,

/// The schema of the input arguments
/// The schema of the input arguments.
///
/// This schema contains the fields corresponding to the function’s input
/// expressions (`exprs`). When an aggregate is invoked with only literal
/// values, this schema is synthesized from those literals to preserve
/// field-level metadata (such as Arrow extension types). In mixed column
/// and literal inputs, metadata from the physical schema takes precedence;
/// synthesized metadata is only used when the physical schema is empty.
pub schema: &'a Schema,

/// Whether to ignore nulls.
Expand Down Expand Up @@ -65,7 +72,15 @@ pub struct AccumulatorArgs<'a> {
/// ```
pub is_distinct: bool,

/// The physical expression of arguments the aggregate function takes.
/// The physical expressions for the aggregate function's arguments.
/// Use these expressions together with [`Self::schema`] to obtain the
/// [`FieldRef`] of each input via `expr.return_field(schema)`.
///
/// Example:
/// ```ignore
/// let input_field = exprs[i].return_field(&schema)?;
/// ```
/// Note: physical schema metadata takes precedence in mixed inputs.
pub exprs: &'a [Arc<dyn PhysicalExpr>],
}

Expand Down
198 changes: 145 additions & 53 deletions datafusion/physical-expr/src/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,92 @@ pub(crate) mod groups_accumulator {
accumulate::NullState, GroupsAccumulatorAdapter,
};
}

#[cfg(test)]
Comment thread
Jefffrey marked this conversation as resolved.
Outdated
mod tests {
Comment thread
Jefffrey marked this conversation as resolved.
Outdated
use super::*;
use crate::expressions::Literal;
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::ScalarValue;
use datafusion_expr::{AggregateUDF, AggregateUDFImpl, Signature, Volatility};
use std::{any::Any, sync::Arc};
/// Minimal aggregate UDF used to exercise `AggregateFunctionExpr` in tests;
/// its accumulator-creating methods are left unimplemented.
#[derive(Debug, PartialEq, Eq, Hash)]
struct DummyUdf {
    /// Signature returned from `AggregateUDFImpl::signature`.
    signature: Signature,
}

impl DummyUdf {
    /// Creates the UDF with the signature produced by
    /// `Signature::any(1, Volatility::Immutable)`.
    fn new() -> Self {
        let signature = Signature::any(1, Volatility::Immutable);
        Self { signature }
    }
}

impl AggregateUDFImpl for DummyUdf {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "dummy"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Always reports `UInt64`, regardless of the argument types.
    fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
        Ok(DataType::UInt64)
    }

    /// Not implemented; calling it panics.
    fn accumulator(&self, _args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        unimplemented!()
    }

    /// Not implemented; calling it panics.
    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
        unimplemented!()
    }

    /// Unconditionally claims groups-accumulator support.
    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
        true
    }

    /// Not implemented; calling it panics.
    fn create_groups_accumulator(
        &self,
        _args: AccumulatorArgs,
    ) -> Result<Box<dyn GroupsAccumulator>> {
        unimplemented!()
    }
}

/// Checks both branches of `args_schema` (owned schema for an empty physical
/// schema, borrowed schema otherwise) and that `groups_accumulator_supported`
/// can be queried in each case.
#[test]
fn test_args_schema_and_groups_path() {
    let udf = Arc::new(AggregateUDF::from(DummyUdf::new()));

    // Case 1: literal-only argument with an empty physical schema; the
    // argument schema is expected to be an owned schema.
    let literal: Arc<dyn PhysicalExpr> =
        Arc::new(Literal::new(ScalarValue::UInt32(Some(1))));
    let literal_agg = AggregateExprBuilder::new(Arc::clone(&udf), vec![literal])
        .alias("x")
        .schema(Arc::new(Schema::empty()))
        .build()
        .unwrap();
    let Cow::Owned(synthesized) = literal_agg.args_schema() else {
        panic!("expected owned schema");
    };
    assert_eq!(synthesized.field(0).name(), "lit");
    assert!(literal_agg.groups_accumulator_supported());

    // Case 2: column argument with a non-empty physical schema; the
    // argument schema is expected to be borrowed from the expression.
    let physical = Schema::new(vec![Field::new("b", DataType::Int32, false)]);
    let column: Arc<dyn PhysicalExpr> = Arc::new(Column::new("b", 0));
    let column_agg = AggregateExprBuilder::new(udf, vec![column])
        .alias("x")
        .schema(Arc::new(physical))
        .build()
        .unwrap();
    let Cow::Borrowed(borrowed) = column_agg.args_schema() else {
        panic!("expected borrowed schema");
    };
    assert_eq!(borrowed.field(0).name(), "b");
    assert!(column_agg.groups_accumulator_supported());
}
}
pub(crate) mod stats {
pub use datafusion_functions_aggregate_common::stats::StatsType;
}
Expand All @@ -33,25 +119,25 @@ pub mod utils {
DecimalAverager, Hashable,
};
}

use std::fmt::Debug;
use std::sync::Arc;

use crate::expressions::Column;

use arrow::compute::SortOptions;
use arrow::datatypes::{DataType, FieldRef, Schema, SchemaRef};
use arrow::{
compute::SortOptions,
datatypes::{DataType, FieldRef, Schema, SchemaRef},
};
use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue};
use datafusion_expr::{AggregateUDF, ReversedUDAF, SetMonotonicity};
use datafusion_expr_common::accumulator::Accumulator;
use datafusion_expr_common::groups_accumulator::GroupsAccumulator;
use datafusion_expr_common::type_coercion::aggregates::check_arg_count;
use datafusion_functions_aggregate_common::accumulator::{
AccumulatorArgs, StateFieldsArgs,
use datafusion_expr_common::{
accumulator::Accumulator, groups_accumulator::GroupsAccumulator,
type_coercion::aggregates::check_arg_count,
};
use datafusion_functions_aggregate_common::{
accumulator::{AccumulatorArgs, StateFieldsArgs},
order::AggregateOrderSensitivity,
};
use datafusion_physical_expr_common::{
physical_expr::PhysicalExpr, sort_expr::PhysicalSortExpr,
};
use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
use std::{borrow::Cow, fmt::Debug, sync::Arc};

/// Builder for physical [`AggregateFunctionExpr`]
///
Expand Down Expand Up @@ -376,21 +462,52 @@ impl AggregateFunctionExpr {
.into()
}

/// the accumulator used to accumulate values from the expressions.
/// the accumulator expects the same number of arguments as `expressions` and must
/// return states with the same description as `state_fields`
pub fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
let acc_args = AccumulatorArgs {
/// Returns the schema to expose to the aggregate function for its input
/// arguments.
///
/// * If the physical input schema has at least one field, it is returned
///   as a borrow and nothing is allocated. For mixed column and literal
///   inputs this means the physical schema's fields win.
/// * If the physical input schema is empty (literal-only inputs), an owned
///   schema is built from `input_fields`, in the same order as the input
///   expressions, so that field-level metadata (such as Arrow extension
///   types) is preserved.
///
/// The owned schema is rebuilt on every call; callers that need it more
/// than once should keep the returned [`Cow`] around.
fn args_schema(&self) -> Cow<'_, Schema> {
    if self.schema.fields().is_empty() {
        // `Schema::new` accepts anything convertible into `Fields`, so pass a
        // clone of the field list itself instead of deep-copying every
        // `Field` (name, data type, metadata) one by one.
        // NOTE(review): assumes `input_fields` is a `Vec<FieldRef>` (or another
        // type with an `Into<Fields>` conversion) — confirm against the struct.
        Cow::Owned(Schema::new(self.input_fields.clone()))
    } else {
        Cow::Borrowed(&self.schema)
    }
}
/// Construct [`AccumulatorArgs`] for this aggregate using the given argument schema.
fn make_acc_args<'a>(&'a self, schema: &'a Schema) -> AccumulatorArgs<'a> {
AccumulatorArgs {
return_field: Arc::clone(&self.return_field),
schema: &self.schema,
schema,
ignore_nulls: self.ignore_nulls,
order_bys: self.order_bys.as_ref(),
is_distinct: self.is_distinct,
name: &self.name,
is_reversed: self.is_reversed,
exprs: &self.args,
};
}
}

/// Creates the [`Accumulator`] used to accumulate values from the argument
/// expressions.
///
/// The accumulator expects the same number of arguments as `expressions`
/// and must return states with the same description as `state_fields`.
pub fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
    // Keep the (possibly owned) schema alive for as long as the args borrow it.
    let args_schema = self.args_schema();
    self.fun
        .accumulator(self.make_acc_args(args_schema.as_ref()))
}

Expand Down Expand Up @@ -464,17 +581,8 @@ impl AggregateFunctionExpr {

/// Creates accumulator implementation that supports retract
pub fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
let args = AccumulatorArgs {
return_field: Arc::clone(&self.return_field),
schema: &self.schema,
ignore_nulls: self.ignore_nulls,
order_bys: self.order_bys.as_ref(),
is_distinct: self.is_distinct,
name: &self.name,
is_reversed: self.is_reversed,
exprs: &self.args,
};

let schema = self.args_schema();
let args = self.make_acc_args(schema.as_ref());
let accumulator = self.fun.create_sliding_accumulator(args)?;

// Accumulators that have window frame startings different
Expand Down Expand Up @@ -533,16 +641,8 @@ impl AggregateFunctionExpr {
/// [`GroupsAccumulator`] implementation. If this returns true,
/// [`Self::create_groups_accumulator`] will be called.
pub fn groups_accumulator_supported(&self) -> bool {
let args = AccumulatorArgs {
return_field: Arc::clone(&self.return_field),
schema: &self.schema,
ignore_nulls: self.ignore_nulls,
order_bys: self.order_bys.as_ref(),
is_distinct: self.is_distinct,
name: &self.name,
is_reversed: self.is_reversed,
exprs: &self.args,
};
let schema = self.args_schema();
let args = self.make_acc_args(schema.as_ref());
self.fun.groups_accumulator_supported(args)
}

Expand All @@ -552,16 +652,8 @@ impl AggregateFunctionExpr {
/// For maximum performance, a [`GroupsAccumulator`] should be
/// implemented in addition to [`Accumulator`].
pub fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
let args = AccumulatorArgs {
return_field: Arc::clone(&self.return_field),
schema: &self.schema,
ignore_nulls: self.ignore_nulls,
order_bys: self.order_bys.as_ref(),
is_distinct: self.is_distinct,
name: &self.name,
is_reversed: self.is_reversed,
exprs: &self.args,
};
let schema = self.args_schema();
let args = self.make_acc_args(schema.as_ref());
self.fun.create_groups_accumulator(args)
}

Expand Down