Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
1d31dfe
Merge branch 'main' into udaf-schema-16997
kosiew Aug 8, 2025
66ee0c5
Merge branch 'main' into udaf-schema-16997
kosiew Aug 12, 2025
15e89aa
feat: Implement SchemaBasedAggregateUdf and enhance accumulator argum…
kosiew Aug 7, 2025
fea8c1c
refactor: Rename test function for schema-based aggregate UDF metadata
kosiew Aug 12, 2025
c386ba0
docs(aggregate): clarify AccumulatorArgs schema handling and usage
kosiew Aug 12, 2025
c8678e5
refactor: Extract AccumulatorArgs construction into a separate method…
kosiew Aug 12, 2025
25df8c6
test: Add unit tests for AggregateUDF implementation and args_schema …
kosiew Aug 12, 2025
2991acc
refactor: Consolidate use statements for improved readability in aggr…
kosiew Aug 12, 2025
664367b
refactor(tests): Simplify argument passing in AggregateExprBuilder tests
kosiew Aug 12, 2025
20d7a93
docs: Mark code examples as ignored in AggregateUDF and AccumulatorAr…
kosiew Aug 12, 2025
99824bd
Merge branch 'main' into udaf-schema-16997
kosiew Aug 19, 2025
8692423
Enhance DummyUdf struct by deriving PartialEq, Eq, and Hash traits
kosiew Aug 19, 2025
f2a2d51
Enhance SchemaBasedAggregateUdf struct by deriving PartialEq, Eq, and…
kosiew Aug 19, 2025
f4484be
Merge branch 'main' into udaf-schema-16997
kosiew Aug 21, 2025
35014e0
Merge branch 'main' into udaf-schema-16997
kosiew Aug 27, 2025
9e166ac
feat: refactor DummyUdf initialization and enhance args_schema handling
kosiew Aug 27, 2025
0cd5d33
feat(aggregate): refactor accumulator argument building for clarity
kosiew Aug 27, 2025
03579a1
refactor(tests): reorganize and enhance DummyUdf implementation and t…
kosiew Aug 27, 2025
7175121
Merge branch 'main' into udaf-schema-16997
kosiew Aug 27, 2025
4dd8bb2
Revert to last good point
kosiew Aug 27, 2025
fee3fe9
Merge branch 'main' into udaf-schema-16997
kosiew Aug 31, 2025
b5a931b
Merge branch 'main' into udaf-schema-16997
kosiew Sep 6, 2025
f6bb7ec
docs(udaf): improve documentation for AccumulatorArgs usage in Aggreg…
kosiew Oct 1, 2025
f5a36ee
refactor(tests): move tests to bottom
kosiew Oct 1, 2025
07f6fab
Merge branch 'main' into udaf-schema-16997
kosiew Oct 1, 2025
34b2d83
docs(udaf, accumulator): enhance documentation for AccumulatorArgs an…
kosiew Oct 10, 2025
c36b3aa
Clarify AccumulatorArgs documentation
kosiew Oct 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1010,6 +1010,51 @@ impl Accumulator for MetadataBasedAccumulator {
}
}

/// Test-only aggregate UDF whose accumulator is configured from field
/// metadata read out of `AccumulatorArgs::schema` (see its
/// `AggregateUDFImpl` implementation).
#[derive(Debug, PartialEq, Eq, Hash)]
struct SchemaBasedAggregateUdf {
/// Signature handed back from `AggregateUDFImpl::signature`.
signature: Signature,
}

impl SchemaBasedAggregateUdf {
    /// Builds the UDF with the signature produced by
    /// `Signature::any(1, Volatility::Immutable)`.
    fn new() -> Self {
        let signature = Signature::any(1, Volatility::Immutable);
        Self { signature }
    }
}

impl AggregateUDFImpl for SchemaBasedAggregateUdf {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Name under which this UDF is exposed.
    fn name(&self) -> &str {
        "schema_based_udf"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Always reports `UInt64`, regardless of the argument types.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::UInt64)
    }

    /// Creates a [`MetadataBasedAccumulator`] configured from the metadata of
    /// the first field of `acc_args.schema`.
    ///
    /// `double_output` is enabled only when that field carries the metadata
    /// entry `"modify_values" = "double_output"`; a missing key or any other
    /// value leaves it disabled.
    ///
    /// # Panics
    ///
    /// `Schema::field(0)` indexes directly, so this panics if
    /// `acc_args.schema` contains no fields.
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        // Borrow the field instead of cloning it: only a single metadata
        // lookup is needed, so copying the name and metadata map is wasted work.
        let double_output = acc_args
            .schema
            .field(0)
            .metadata()
            .get("modify_values")
            .is_some_and(|v| v == "double_output");

        Ok(Box::new(MetadataBasedAccumulator {
            double_output,
            curr_sum: 0,
        }))
    }
}

#[tokio::test]
async fn test_metadata_based_aggregate() -> Result<()> {
let data_array = Arc::new(UInt64Array::from(vec![0, 5, 10, 15, 20])) as ArrayRef;
Expand Down Expand Up @@ -1166,3 +1211,34 @@ async fn test_metadata_based_aggregate_as_window() -> Result<()> {

Ok(())
}

/// Checks that metadata attached to a literal argument is visible to
/// `SchemaBasedAggregateUdf::accumulator` through `AccumulatorArgs::schema`:
/// aggregating the tagged literal `1` is expected to yield `2`.
#[tokio::test]
async fn test_schema_based_aggregate_udf_metadata() -> Result<()> {
    use datafusion_expr::{expr::FieldMetadata, lit_with_metadata};
    use std::collections::BTreeMap;

    // Literal `1` tagged with the metadata entry the UDF's accumulator looks for.
    let mut tags = BTreeMap::new();
    tags.insert("modify_values".to_string(), "double_output".to_string());
    let tagged_one = lit_with_metadata(
        ScalarValue::UInt64(Some(1)),
        Some(FieldMetadata::new(tags)),
    );

    let schema_udaf = AggregateUDF::from(SchemaBasedAggregateUdf::new());
    let aggregate = schema_udaf.call(vec![tagged_one]).alias("res");

    // Aggregate without any GROUP BY expressions.
    let no_group_by: Vec<Expr> = Vec::new();
    let logical_plan = LogicalPlanBuilder::empty(true)
        .aggregate(no_group_by, vec![aggregate])?
        .build()?;

    let session = SessionContext::new();
    let results = DataFrame::new(session.state(), logical_plan)
        .collect()
        .await?;

    let sums = results[0].column(0).as_primitive::<UInt64Type>();
    assert_eq!(sums.value(0), 2);
    Ok(())
}
28 changes: 27 additions & 1 deletion datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,33 @@ pub trait AggregateUDFImpl: Debug + DynEq + DynHash + Send + Sync {
/// group during query execution.
///
/// acc_args: [`AccumulatorArgs`] contains information about how the
/// aggregate function was called.
/// aggregate function was called. `acc_args.schema` exposes the physical
/// input schema (the same one used by the expressions), whereas
/// [`AccumulatorArgs::input_field`] / [`AccumulatorArgs::input_fields`]
/// wrap `acc_args.exprs[i].return_field(&acc_args.schema)?` to yield the
/// effective [`FieldRef`]s for the aggregate arguments.
///
/// In practice this is useful when you need field metadata or the exact
/// return field (name, type, metadata) for the argument. For example:
/// ```ignore
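/// // get Arrow field metadata for the i-th column of the physical input
/// // schema (this is the i-th argument only when the arguments are the
/// // schema's columns, in order)
/// let metadata = acc_args.schema.field(i).metadata();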
/// // obtain the complete FieldRef (name, type, nullability, metadata)
/// // for the i-th argument using the convenience helper.
/// let field = acc_args.input_field(i)?;
/// ```
///
/// Multi-argument functions: `exprs[i]` are evaluated against the physical
/// schema. Mixed inputs (columns and literals): the physical input schema is
/// used when not empty; `acc_args.schema` is only synthesized from literals
/// when the physical schema is empty (so expressions that rely on column
/// lookups continue to work).
///
/// When an aggregate is invoked with literal values only,
/// `acc_args.schema` is synthesized from those literals so that any field
/// metadata (for example Arrow extension types) is available to the
/// accumulator implementation.
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>>;

/// Return the fields used to store the intermediate state of this accumulator.
Expand Down
81 changes: 78 additions & 3 deletions datafusion/functions-aggregate-common/src/accumulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
// under the License.

use arrow::datatypes::{DataType, FieldRef, Schema};
use datafusion_common::Result;
use datafusion_common::{internal_err, Result};
use datafusion_expr_common::accumulator::Accumulator;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
Expand All @@ -30,7 +30,52 @@ pub struct AccumulatorArgs<'a> {
/// The return field of the aggregate function.
pub return_field: FieldRef,

/// The schema of the input arguments
/// The physical schema of the record batches fed into the aggregate.
///
/// This is the same schema that [`Self::exprs`] expect when resolving column
/// references. When the arguments are exactly the schema's columns, in
/// order, the physical schema and the "effective" argument fields match
/// one-to-one, so `acc_args.schema.field(i)` and
/// `acc_args.exprs[i].return_field(&acc_args.schema)?` return equivalent
/// metadata. For expressions that reference multiple columns (e.g.
/// `SUM(a + b)`) the physical schema still contains the full input
/// (`[a, b, …]`), while `return_field` synthesises the single argument
/// field that the accumulator consumes. Both views are therefore exposed:
///
/// * Use [`Self::schema`] to inspect the raw physical fields, including
/// metadata coming from the child plan.
/// * Use [`Self::exprs`] in combination with [`PhysicalExpr::return_field`]
/// to recover the effective [`FieldRef`] for each aggregate argument.
///
/// When an aggregate is invoked with only literal values, the physical
/// schema is empty. In that case DataFusion synthesises a schema from the
/// literal expressions so extension metadata is still available. In mixed
/// column and literal inputs the existing physical schema takes precedence;
/// synthesized metadata is only used when the physical schema is empty.
///
/// For convenience, [`Self::input_field`] and [`Self::input_fields`] wrap
/// the `exprs[i].return_field(&schema)` pattern so UDAF implementations can
/// recover the effective [`FieldRef`]s without interacting with the raw
/// schema directly, matching the ergonomics of other function argument
/// structs.
///
/// ### Relation to other function argument structs
///
/// Scalar and window functions see arguments *after* their `PhysicalExpr`s
/// have been evaluated. As a consequence, [`ScalarFunctionArgs`](datafusion_expr::ScalarFunctionArgs) and
/// [`PartitionEvaluatorArgs`](datafusion_functions_window_common::partition::PartitionEvaluatorArgs)
/// only expose the already-computed [`FieldRef`]s for each argument. In
/// contrast, an accumulator is constructed before any batches have been
/// processed, so the planner cannot simply hand the UDAF pre-evaluated
/// arguments. Instead we provide both the physical schema and the original
/// argument expressions; the accumulator can then call
/// [`PhysicalExpr::return_field`] when it needs the logical argument field.
///
/// This is why expressions such as `SUM(a + b)` require the schema even
/// though scalar functions also support `SIN(a + b)`. The scalar version
/// receives the evaluated `a + b` as a [`ColumnarValue`](datafusion_expr::ColumnarValue), while the
/// aggregate still holds the unevaluated `PhysicalExpr` and must resolve it
/// against the physical schema when computing its metadata.
pub schema: &'a Schema,

/// Whether to ignore nulls.
Expand Down Expand Up @@ -65,7 +110,16 @@ pub struct AccumulatorArgs<'a> {
/// ```
pub is_distinct: bool,

/// The physical expression of arguments the aggregate function takes.
/// The physical expressions for the aggregate function's arguments.
/// Use these expressions together with [`Self::schema`] to obtain the
/// [`FieldRef`] of each input via
/// [`PhysicalExpr::return_field`](`PhysicalExpr::return_field`).
///
/// Example:
/// ```ignore
/// let input_field = exprs[i].return_field(&schema)?;
/// ```
/// Note: physical schema metadata takes precedence in mixed inputs.
pub exprs: &'a [Arc<dyn PhysicalExpr>],
}

Expand All @@ -74,6 +128,27 @@ impl AccumulatorArgs<'_> {
pub fn return_type(&self) -> &DataType {
    // The return type is simply the data type carried by the return field.
    let field = &self.return_field;
    field.data_type()
}

/// Returns the [`FieldRef`] corresponding to the `index`th aggregate
/// argument.
///
/// The field is obtained by calling [`PhysicalExpr::return_field`] on
/// `self.exprs[index]` with `self.schema`.
///
/// # Errors
///
/// Returns an internal error when `index` is out of bounds for
/// `self.exprs`, and propagates any error reported by `return_field`.
pub fn input_field(&self, index: usize) -> Result<FieldRef> {
    match self.exprs.get(index) {
        Some(expr) => expr.return_field(self.schema),
        // `internal_err!` already expands to an `Err(..)` value, so it is
        // returned as-is instead of being passed to `ok_or_else`, which
        // expects a bare error.
        None => internal_err!(
            "input_field index {index} is out of bounds for {} arguments",
            self.exprs.len()
        ),
    }
}

/// Returns the [`FieldRef`] of every aggregate argument, in argument order.
///
/// Each field comes from [`PhysicalExpr::return_field`] evaluated against
/// `self.schema`; the first failure is returned and stops the iteration.
pub fn input_fields(&self) -> Result<Vec<FieldRef>> {
    let mut fields = Vec::with_capacity(self.exprs.len());
    for expr in self.exprs.iter() {
        fields.push(expr.return_field(self.schema)?);
    }
    Ok(fields)
}
}

/// Factory that returns an accumulator for the given aggregate function.
Expand Down
Loading
Loading