From e7282e679dae8714fc4521ee83e0044d5234a968 Mon Sep 17 00:00:00 2001 From: Jack Ye Date: Sat, 7 Feb 2026 21:17:07 -0800 Subject: [PATCH 1/9] feat(lance): add Substrait aggregate support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add support for aggregates via Substrait AggregateRel specification. Key changes: - Add `AggregateSpec` enum with Substrait and Datafusion variants - Add `aggregate_substrait()` and `aggregate_expr()` methods to Scanner - Add `create_aggregate_plan()` to build execution plan with AggregateExec - Add Substrait parsing utilities in lance-datafusion for AggregateRel - Implement type coercion for UserDefined signature functions (e.g., AVG) - Support output column aliases via RelRoot.names Supported: COUNT, SUM, AVG, MIN, MAX with GROUP BY. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- Cargo.lock | 1 + rust/lance-datafusion/src/substrait.rs | 483 ++++++++++- rust/lance/Cargo.toml | 1 + rust/lance/src/dataset/scanner.rs | 227 +++++ .../src/dataset/tests/dataset_aggregate.rs | 776 ++++++++++++++++++ rust/lance/src/dataset/tests/mod.rs | 2 + 6 files changed, 1489 insertions(+), 1 deletion(-) create mode 100644 rust/lance/src/dataset/tests/dataset_aggregate.rs diff --git a/Cargo.lock b/Cargo.lock index cc98925d15e..cfcc4899c96 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4731,6 +4731,7 @@ dependencies = [ "datafusion-functions", "datafusion-physical-expr", "datafusion-physical-plan", + "datafusion-substrait", "deepsize", "dirs 5.0.1", "either", diff --git a/rust/lance-datafusion/src/substrait.rs b/rust/lance-datafusion/src/substrait.rs index 54dc9be8808..09630c396a6 100644 --- a/rust/lance-datafusion/src/substrait.rs +++ b/rust/lance-datafusion/src/substrait.rs @@ -3,6 +3,11 @@ use arrow_schema::Schema as ArrowSchema; use datafusion::{execution::SessionState, logical_expr::Expr}; +use datafusion_common::DFSchema; +use 
datafusion_substrait::extensions::Extensions; +use datafusion_substrait::logical_plan::consumer::{ + from_substrait_agg_func, from_substrait_rex, DefaultSubstraitConsumer, +}; use datafusion_substrait::substrait::proto::{ expression::{ field_reference::{ReferenceType, RootType}, @@ -11,7 +16,8 @@ use datafusion_substrait::substrait::proto::{ expression_reference::ExprType, function_argument::ArgType, r#type::{Kind, Struct}, - Expression, ExpressionReference, ExtendedExpression, NamedStruct, Type, + rel::RelType, + AggregateRel, Expression, ExpressionReference, ExtendedExpression, NamedStruct, Plan, Type, }; use lance_core::{Error, Result}; use prost::Message; @@ -324,6 +330,204 @@ pub async fn parse_substrait( Ok(expr_container.exprs.pop().unwrap().0) } +/// Aggregate specification with group by and aggregate expressions. +#[derive(Debug, Clone)] +pub struct Aggregate { + pub group_by: Vec, + pub aggregates: Vec, + /// Output column names in order: group_by columns first, then aggregates. + pub output_names: Vec, +} + +/// Parse Substrait Plan bytes containing an AggregateRel. 
+pub async fn parse_substrait_aggregate( + bytes: &[u8], + input_schema: Arc, + state: &SessionState, +) -> Result { + let plan = Plan::decode(bytes)?; + let (aggregate_rel, output_names) = extract_aggregate_from_plan(&plan)?; + let extensions = Extensions::try_from(&plan.extensions)?; + + let mut agg = + parse_aggregate_rel_with_extensions(&aggregate_rel, input_schema, state, &extensions) + .await?; + agg.output_names = output_names; + Ok(agg) +} + +fn extract_aggregate_from_plan(plan: &Plan) -> Result<(Box, Vec)> { + if plan.relations.is_empty() { + return Err(Error::invalid_input( + "Substrait Plan has no relations", + location!(), + )); + } + + let plan_rel = &plan.relations[0]; + let (rel, output_names) = match &plan_rel.rel_type { + Some(datafusion_substrait::substrait::proto::plan_rel::RelType::Root(root)) => { + (root.input.as_ref(), root.names.clone()) + } + Some(datafusion_substrait::substrait::proto::plan_rel::RelType::Rel(rel)) => { + (Some(rel), vec![]) + } + None => (None, vec![]), + }; + + let rel = rel.ok_or_else(|| Error::invalid_input("Plan relation has no input", location!()))?; + + match &rel.rel_type { + Some(RelType::Aggregate(agg)) => Ok((agg.clone(), output_names)), + Some(other) => Err(Error::invalid_input( + format!( + "Expected Substrait AggregateRel, got {:?}", + std::mem::discriminant(other) + ), + location!(), + )), + None => Err(Error::invalid_input( + "Substrait Rel has no rel_type", + location!(), + )), + } +} + +/// Parse an AggregateRel proto with provided extensions. 
+pub async fn parse_aggregate_rel_with_extensions( + aggregate_rel: &AggregateRel, + input_schema: Arc, + state: &SessionState, + extensions: &Extensions, +) -> Result { + let df_schema = DFSchema::try_from(input_schema.as_ref().clone())?; + let consumer = DefaultSubstraitConsumer::new(extensions, state); + let group_by = parse_groupings(aggregate_rel, &df_schema, &consumer).await?; + let aggregates = parse_measures(aggregate_rel, &df_schema, &consumer).await?; + + Ok(Aggregate { + group_by, + aggregates, + output_names: vec![], + }) +} + +/// Parse an AggregateRel proto with default extensions. +pub async fn parse_aggregate_rel( + aggregate_rel: &AggregateRel, + input_schema: Arc, + state: &SessionState, +) -> Result { + let extensions = Extensions::default(); + parse_aggregate_rel_with_extensions(aggregate_rel, input_schema, state, &extensions).await +} + +async fn parse_groupings( + agg_rel: &AggregateRel, + schema: &DFSchema, + consumer: &DefaultSubstraitConsumer<'_>, +) -> Result> { + let mut group_exprs = Vec::new(); + + // First, handle the new-style grouping_expressions + expression_references + if !agg_rel.grouping_expressions.is_empty() { + for grouping in &agg_rel.groupings { + for expr_ref in &grouping.expression_references { + let idx = *expr_ref as usize; + if idx >= agg_rel.grouping_expressions.len() { + return Err(Error::invalid_input( + format!( + "Grouping expression reference {} out of bounds (max: {})", + idx, + agg_rel.grouping_expressions.len() + ), + location!(), + )); + } + let expr = &agg_rel.grouping_expressions[idx]; + let df_expr = from_substrait_rex(consumer, expr, schema) + .await + .map_err(|e| { + Error::invalid_input( + format!("Failed to parse grouping expression: {}", e), + location!(), + ) + })?; + group_exprs.push(df_expr); + } + } + } else { + // Fallback to deprecated inline grouping_expressions within each Grouping + #[allow(deprecated)] + for grouping in &agg_rel.groupings { + for expr in &grouping.grouping_expressions { + 
let df_expr = from_substrait_rex(consumer, expr, schema) + .await + .map_err(|e| { + Error::invalid_input( + format!("Failed to parse grouping expression: {}", e), + location!(), + ) + })?; + group_exprs.push(df_expr); + } + } + } + + Ok(group_exprs) +} + +async fn parse_measures( + agg_rel: &AggregateRel, + schema: &DFSchema, + consumer: &DefaultSubstraitConsumer<'_>, +) -> Result> { + let mut aggregates = Vec::new(); + + for measure in &agg_rel.measures { + if let Some(agg_func) = &measure.measure { + // Parse optional filter + let filter = if let Some(filter_expr) = &measure.filter { + let df_filter = from_substrait_rex(consumer, filter_expr, schema) + .await + .map_err(|e| { + Error::invalid_input( + format!("Failed to parse measure filter: {}", e), + location!(), + ) + })?; + Some(Box::new(df_filter)) + } else { + None + }; + + // Parse ordering (for ordered aggregates like ARRAY_AGG) + let order_by = Vec::new(); // TODO: parse agg_func.sorts if needed + + // Check for DISTINCT invocation + let distinct = matches!( + agg_func.invocation, + i if i == datafusion_substrait::substrait::proto::aggregate_function::AggregationInvocation::Distinct as i32 + ); + + // Convert Substrait AggregateFunction to DataFusion Expr + let df_expr = + from_substrait_agg_func(consumer, agg_func, schema, filter, order_by, distinct) + .await + .map_err(|e| { + Error::invalid_input( + format!("Failed to parse aggregate function: {}", e), + location!(), + ) + })?; + + aggregates.push(df_expr.as_ref().clone()); + } + } + + Ok(aggregates) +} + #[cfg(test)] mod tests { use std::sync::Arc; @@ -607,4 +811,281 @@ mod tests { assert_substrait_roundtrip(schema, id_filter("test-id")).await; } + + // ==================== Aggregate parsing tests ==================== + + use datafusion_substrait::substrait::proto::{ + aggregate_function::AggregationInvocation, + aggregate_rel::{Grouping, Measure}, + rel::RelType, + AggregateFunction, AggregateRel, Plan, PlanRel, Rel, RelRoot, + }; + + /// Helper to 
create a field reference expression for a column index + fn agg_field_ref(field_index: i32) -> Expression { + Expression { + rex_type: Some(RexType::Selection(Box::new(FieldReference { + reference_type: Some(ReferenceType::DirectReference(ReferenceSegment { + reference_type: Some(reference_segment::ReferenceType::StructField(Box::new( + StructField { + field: field_index, + child: None, + }, + ))), + })), + root_type: Some(RootType::RootReference(RootReference {})), + }))), + } + } + + /// Create extension declaration for an aggregate function + fn agg_extension(anchor: u32, name: &str) -> SimpleExtensionDeclaration { + SimpleExtensionDeclaration { + mapping_type: Some(MappingType::ExtensionFunction(ExtensionFunction { + #[allow(deprecated)] + extension_uri_reference: 1, + extension_urn_reference: 0, + function_anchor: anchor, + name: name.to_string(), + })), + } + } + + /// Helper to create a Substrait Plan with AggregateRel + fn create_aggregate_plan( + measures: Vec, + grouping_expressions: Vec, + groupings: Vec, + extensions: Vec, + ) -> Vec { + let aggregate_rel = AggregateRel { + common: None, + input: None, // Input is ignored for pushdown + groupings, + measures, + grouping_expressions, + advanced_extension: None, + }; + + let rel = Rel { + rel_type: Some(RelType::Aggregate(Box::new(aggregate_rel))), + }; + + // Wrap in a Plan to include extensions + let plan = Plan { + version: Some(Version { + major_number: 0, + minor_number: 63, + patch_number: 0, + git_hash: String::new(), + producer: "lance-test".to_string(), + }), + #[allow(deprecated)] + extension_uris: vec![SimpleExtensionUri { + extension_uri_anchor: 1, + uri: "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate_generic.yaml".to_string(), + }], + extensions, + relations: vec![PlanRel { + rel_type: Some( + datafusion_substrait::substrait::proto::plan_rel::RelType::Root(RelRoot { + input: Some(rel), + names: vec![], + }), + ), + }], + advanced_extensions: None, + 
expected_type_urls: vec![], + extension_urns: vec![], + parameter_bindings: vec![], + type_aliases: vec![], + }; + + plan.encode_to_vec() + } + + /// Create a COUNT(*) measure + fn count_star_measure(function_ref: u32) -> Measure { + Measure { + measure: Some(AggregateFunction { + function_reference: function_ref, + arguments: vec![], + options: vec![], + output_type: None, + phase: 0, + sorts: vec![], + invocation: AggregationInvocation::All as i32, + #[allow(deprecated)] + args: vec![], + }), + filter: None, + } + } + + /// Create a SUM/AVG/MIN/MAX measure on a column + fn simple_agg_measure(function_ref: u32, column_index: i32) -> Measure { + Measure { + measure: Some(AggregateFunction { + function_reference: function_ref, + arguments: vec![FunctionArgument { + arg_type: Some(ArgType::Value(agg_field_ref(column_index))), + }], + options: vec![], + output_type: None, + phase: 0, + sorts: vec![], + invocation: AggregationInvocation::All as i32, + #[allow(deprecated)] + args: vec![], + }), + filter: None, + } + } + + #[tokio::test] + async fn test_parse_substrait_aggregate_count_star() { + let bytes = create_aggregate_plan( + vec![count_star_measure(0)], + vec![], + vec![], + vec![agg_extension(0, "count")], + ); + + let schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::Int32, true), + Field::new("y", DataType::Int64, true), + ])); + + let result = + crate::substrait::parse_substrait_aggregate(&bytes, schema, &session_state()).await; + + let agg = result.expect("Failed to parse COUNT(*) aggregate"); + assert!(agg.group_by.is_empty(), "COUNT(*) should have no group by"); + assert_eq!(agg.aggregates.len(), 1, "Should have exactly one aggregate"); + + // Verify it's a COUNT aggregate + let agg_expr = &agg.aggregates[0]; + assert!( + agg_expr.schema_name().to_string().contains("count"), + "Expected COUNT aggregate, got: {}", + agg_expr.schema_name() + ); + } + + #[tokio::test] + async fn test_parse_substrait_aggregate_sum() { + let bytes = 
create_aggregate_plan( + vec![simple_agg_measure(0, 1)], // SUM on column index 1 (y) + vec![], + vec![], + vec![agg_extension(0, "sum")], + ); + + let schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::Int32, true), + Field::new("y", DataType::Int64, true), + ])); + + let result = + crate::substrait::parse_substrait_aggregate(&bytes, schema, &session_state()).await; + + let agg = result.expect("Failed to parse SUM aggregate"); + assert!(agg.group_by.is_empty(), "SUM should have no group by"); + assert_eq!(agg.aggregates.len(), 1, "Should have exactly one aggregate"); + + // Verify it's a SUM aggregate + let agg_expr = &agg.aggregates[0]; + assert!( + agg_expr.schema_name().to_string().contains("sum"), + "Expected SUM aggregate, got: {}", + agg_expr.schema_name() + ); + } + + #[tokio::test] + async fn test_parse_substrait_aggregate_sum_with_group_by() { + // SUM(y) GROUP BY x + let bytes = create_aggregate_plan( + vec![simple_agg_measure(0, 1)], // SUM on column index 1 (y) + vec![agg_field_ref(0)], // Group by column index 0 (x) + vec![Grouping { + #[allow(deprecated)] + grouping_expressions: vec![], + expression_references: vec![0], // Reference to first grouping_expression + }], + vec![agg_extension(0, "sum")], + ); + + let schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::Int32, true), + Field::new("y", DataType::Int64, true), + ])); + + let result = + crate::substrait::parse_substrait_aggregate(&bytes, schema, &session_state()).await; + + let agg = result.expect("Failed to parse SUM with GROUP BY"); + assert_eq!( + agg.group_by.len(), + 1, + "Should have exactly one group by expression" + ); + assert_eq!(agg.aggregates.len(), 1, "Should have exactly one aggregate"); + + // Verify group by is column x + let group_expr = &agg.group_by[0]; + assert!( + group_expr.schema_name().to_string().contains('x'), + "Expected group by on column x, got: {}", + group_expr.schema_name() + ); + + // Verify it's a SUM aggregate + let agg_expr = 
&agg.aggregates[0]; + assert!( + agg_expr.schema_name().to_string().contains("sum"), + "Expected SUM aggregate, got: {}", + agg_expr.schema_name() + ); + } + + #[tokio::test] + async fn test_parse_substrait_aggregate_multiple_aggregates() { + // COUNT(*) and SUM(y) + let bytes = create_aggregate_plan( + vec![count_star_measure(0), simple_agg_measure(1, 1)], + vec![], + vec![], + vec![agg_extension(0, "count"), agg_extension(1, "sum")], + ); + + let schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::Int32, true), + Field::new("y", DataType::Int64, true), + ])); + + let result = + crate::substrait::parse_substrait_aggregate(&bytes, schema, &session_state()).await; + + let agg = result.expect("Failed to parse multiple aggregates"); + assert!(agg.group_by.is_empty(), "Should have no group by"); + assert_eq!(agg.aggregates.len(), 2, "Should have two aggregates"); + + // Verify COUNT + assert!( + agg.aggregates[0] + .schema_name() + .to_string() + .contains("count"), + "Expected COUNT aggregate, got: {}", + agg.aggregates[0].schema_name() + ); + + // Verify SUM + assert!( + agg.aggregates[1].schema_name().to_string().contains("sum"), + "Expected SUM aggregate, got: {}", + agg.aggregates[1].schema_name() + ); + } } diff --git a/rust/lance/Cargo.toml b/rust/lance/Cargo.toml index d57187bcacf..427de2aa30e 100644 --- a/rust/lance/Cargo.toml +++ b/rust/lance/Cargo.toml @@ -116,6 +116,7 @@ aws-sdk-s3 = { workspace = true } geoarrow-array = { workspace = true } geoarrow-schema = { workspace = true } geo-types = { workspace = true } +datafusion-substrait = { workspace = true } [features] default = ["aws", "azure", "gcp", "oss", "huggingface", "tencent"] diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index a9265517fe5..b217b968462 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -462,6 +462,65 @@ impl ExprFilter { } } +/// Aggregate specification from Substrait or DataFusion expressions. 
+#[derive(Debug, Clone)] +pub enum AggregateSpec { + #[cfg(feature = "substrait")] + Substrait(Vec), + Datafusion { + group_by: Vec, + aggregates: Vec, + output_names: Vec, + }, +} + +#[cfg(feature = "substrait")] +impl AggregateSpec { + /// Converts to DataFusion expressions. + pub fn to_datafusion( + &self, + schema: Arc, + ) -> Result { + use lance_datafusion::exec::{get_session_context, LanceExecutionOptions}; + use lance_datafusion::substrait::{parse_substrait_aggregate, Aggregate}; + + match self { + Self::Substrait(bytes) => { + let ctx = get_session_context(&LanceExecutionOptions::default()); + parse_substrait_aggregate(bytes, schema, &ctx.state()) + .now_or_never() + .expect("could not parse the Substrait aggregate in a synchronous fashion") + } + Self::Datafusion { + group_by, + aggregates, + output_names, + } => Ok(Aggregate { + group_by: group_by.clone(), + aggregates: aggregates.clone(), + output_names: output_names.clone(), + }), + } + } +} + +#[cfg(not(feature = "substrait"))] +impl AggregateSpec { + /// Converts the aggregate specification to DataFusion expressions + pub fn to_datafusion( + &self, + _schema: Arc, + ) -> Result<(Vec, Vec, Vec)> { + match self { + Self::Datafusion { + group_by, + aggregates, + output_names, + } => Ok((group_by.clone(), aggregates.clone(), output_names.clone())), + } + } +} + /// Dataset Scanner /// /// ```rust,ignore @@ -570,6 +629,8 @@ pub struct Scanner { /// File reader options to use when reading data files. 
file_reader_options: Option, + aggregate: Option, + // Legacy fields to help migrate some old projection behavior to new behavior // // There are two behaviors we are moving away from: @@ -787,6 +848,7 @@ impl Scanner { scan_stats_callback: None, strict_batch_size: false, file_reader_options, + aggregate: None, legacy_with_row_addr: false, legacy_with_row_id: false, explicit_projection: false, @@ -1017,6 +1079,28 @@ impl Scanner { self } + /// Set aggregation using Substrait Plan bytes containing an AggregateRel. + #[cfg(feature = "substrait")] + pub fn aggregate_substrait(&mut self, aggregate_rel: &[u8]) -> Result<&mut Self> { + self.aggregate = Some(AggregateSpec::Substrait(aggregate_rel.to_vec())); + Ok(self) + } + + /// Set aggregation using DataFusion expressions. + pub fn aggregate_expr( + &mut self, + group_by: Vec, + aggregates: Vec, + output_names: Vec, + ) -> Result<&mut Self> { + self.aggregate = Some(AggregateSpec::Datafusion { + group_by, + aggregates, + output_names, + }); + Ok(self) + } + /// Set the batch size. pub fn batch_size(&mut self, batch_size: usize) -> &mut Self { self.batch_size = Some(batch_size); @@ -1767,6 +1851,149 @@ impl Scanner { .boxed() } + /// Create an execution plan with aggregation. + /// + /// Requires `aggregate_substrait()` or `aggregate_expr()` to be called first. 
+ #[cfg(feature = "substrait")] + pub fn create_aggregate_plan(&self) -> BoxFuture<'_, Result>> { + use datafusion_physical_expr::aggregate::AggregateFunctionExpr; + + async move { + let agg_spec = self.aggregate.as_ref().ok_or_else(|| { + Error::invalid_input( + "create_aggregate_plan called but no aggregate was set", + location!(), + ) + })?; + + let plan = self.create_plan().await?; + let schema = plan.schema(); + let agg = agg_spec.to_datafusion(schema.clone())?; + let df_schema = DFSchema::try_from(schema.as_ref().clone())?; + let num_groups = agg.group_by.len(); + + let group_exprs: Vec<(Arc, String)> = agg + .group_by + .iter() + .enumerate() + .map(|(i, expr)| { + let name = if i < agg.output_names.len() { + agg.output_names[i].clone() + } else { + expr.schema_name().to_string() + }; + let physical_expr = + create_physical_expr(expr, &df_schema, &ExecutionProps::default())?; + Ok((physical_expr, name)) + }) + .collect::>()?; + + let aggr_exprs: Vec> = agg + .aggregates + .iter() + .enumerate() + .map(|(i, expr)| { + let output_name_idx = num_groups + i; + let alias = if output_name_idx < agg.output_names.len() { + Some(agg.output_names[output_name_idx].as_str()) + } else { + None + }; + self.build_physical_aggregate_expr_with_alias(expr, &df_schema, &schema, alias) + }) + .collect::>()?; + + let filters: Vec>> = vec![None; aggr_exprs.len()]; + + Ok(Arc::new(AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new_single(group_exprs), + aggr_exprs, + filters, + plan, + schema, + )?) 
as Arc) + } + .boxed() + } + + #[cfg(feature = "substrait")] + fn build_physical_aggregate_expr_with_alias( + &self, + expr: &Expr, + df_schema: &DFSchema, + input_schema: &SchemaRef, + alias: Option<&str>, + ) -> Result> { + use datafusion::physical_planner::create_aggregate_expr_and_maybe_filter; + + let coerced_expr = self.coerce_aggregate_expr(expr, df_schema)?; + let aliased_expr = if let Some(name) = alias { + coerced_expr.alias(name) + } else { + coerced_expr + }; + + let (agg_expr, _filter, _order_by) = create_aggregate_expr_and_maybe_filter( + &aliased_expr, + df_schema, + input_schema.as_ref(), + &ExecutionProps::default(), + )?; + + Ok(agg_expr) + } + + /// Apply type coercion to aggregate arguments for UserDefined signature functions. + #[cfg(feature = "substrait")] + fn coerce_aggregate_expr(&self, expr: &Expr, schema: &DFSchema) -> Result { + use datafusion::logical_expr::{expr::AggregateFunction, Expr, TypeSignature}; + + match expr { + Expr::AggregateFunction(agg_func) => { + let func = &agg_func.func; + let args = &agg_func.params.args; + + if !matches!(func.signature().type_signature, TypeSignature::UserDefined) { + return Ok(expr.clone()); + } + + if args.is_empty() { + return Ok(expr.clone()); + } + + let current_types: Vec = args + .iter() + .map(|e| e.get_type(schema)) + .collect::>()?; + + let coerced_types = func.coerce_types(¤t_types)?; + let coerced_args: Vec = args + .iter() + .zip(coerced_types.iter()) + .map(|(arg, target_type)| { + let arg_type = arg.get_type(schema)?; + if arg_type == *target_type { + Ok(arg.clone()) + } else { + arg.clone().cast_to(target_type, schema) + } + }) + .collect::>()?; + + Ok(Expr::AggregateFunction(AggregateFunction::new_udf( + func.clone(), + coerced_args, + agg_func.params.distinct, + agg_func.params.filter.clone(), + agg_func.params.order_by.clone(), + agg_func.params.null_treatment, + ))) + } + other => Ok(other.clone()), + } + } + // A "narrow" field is a field that is so small that we are better off 
reading the // entire column and filtering in memory rather than "take"ing the column. // diff --git a/rust/lance/src/dataset/tests/dataset_aggregate.rs b/rust/lance/src/dataset/tests/dataset_aggregate.rs new file mode 100644 index 00000000000..4135a3b7ee4 --- /dev/null +++ b/rust/lance/src/dataset/tests/dataset_aggregate.rs @@ -0,0 +1,776 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Tests for aggregate pushdown via Substrait + +use std::sync::Arc; + +use arrow_array::cast::AsArray; +use arrow_array::types::{Float64Type, Int64Type}; +use arrow_array::{RecordBatch, RecordBatchIterator}; +use arrow_schema::{DataType, Field, Schema as ArrowSchema}; +use datafusion_substrait::substrait::proto::{ + aggregate_function::AggregationInvocation, + aggregate_rel::{Grouping, Measure}, + expression::{ + field_reference::{ReferenceType, RootReference, RootType}, + reference_segment::{self, StructField}, + FieldReference, ReferenceSegment, RexType, + }, + extensions::{ + simple_extension_declaration::{ExtensionFunction, MappingType}, + SimpleExtensionDeclaration, SimpleExtensionUri, + }, + function_argument::ArgType, + rel::RelType, + AggregateFunction, AggregateRel, Expression, FunctionArgument, Plan, PlanRel, Rel, RelRoot, + Version, +}; +use futures::TryStreamExt; +use lance_datafusion::exec::{execute_plan, LanceExecutionOptions}; +use lance_datagen::{array, gen_batch}; +use lance_table::format::Fragment; +use prost::Message; +use tempfile::tempdir; + +use crate::utils::test::{DatagenExt, FragmentCount, FragmentRowCount}; +use crate::Dataset; + +/// Helper to create a field reference expression for a column index +fn field_ref(field_index: i32) -> Expression { + Expression { + rex_type: Some(RexType::Selection(Box::new(FieldReference { + reference_type: Some(ReferenceType::DirectReference(ReferenceSegment { + reference_type: Some(reference_segment::ReferenceType::StructField(Box::new( + StructField { + field: 
field_index, + child: None, + }, + ))), + })), + root_type: Some(RootType::RootReference(RootReference {})), + }))), + } +} + +/// Helper to create a Substrait AggregateRel with given measures and groupings +fn create_aggregate_rel( + measures: Vec, + grouping_expressions: Vec, + groupings: Vec, + extensions: Vec, + output_names: Vec, +) -> Vec { + let aggregate_rel = AggregateRel { + common: None, + input: None, // Input is ignored for pushdown + groupings, + measures, + grouping_expressions, + advanced_extension: None, + }; + + let rel = Rel { + rel_type: Some(RelType::Aggregate(Box::new(aggregate_rel))), + }; + + // Wrap in a Plan to include extensions + let plan = Plan { + version: Some(Version { + major_number: 0, + minor_number: 63, + patch_number: 0, + git_hash: String::new(), + producer: "lance-test".to_string(), + }), + #[allow(deprecated)] + extension_uris: vec![ + SimpleExtensionUri { + extension_uri_anchor: 1, + uri: "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate_generic.yaml".to_string(), + }, + SimpleExtensionUri { + extension_uri_anchor: 2, + uri: "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml".to_string(), + }, + ], + extensions, + relations: vec![PlanRel { + rel_type: Some(datafusion_substrait::substrait::proto::plan_rel::RelType::Root( + RelRoot { + input: Some(rel), + names: output_names, + }, + )), + }], + advanced_extensions: None, + expected_type_urls: vec![], + extension_urns: vec![], + parameter_bindings: vec![], + type_aliases: vec![], + }; + + plan.encode_to_vec() +} + +/// Create extension declaration for an aggregate function +fn agg_extension(anchor: u32, name: &str) -> SimpleExtensionDeclaration { + SimpleExtensionDeclaration { + mapping_type: Some(MappingType::ExtensionFunction(ExtensionFunction { + #[allow(deprecated)] + extension_uri_reference: 1, + extension_urn_reference: 0, + function_anchor: anchor, + name: name.to_string(), + })), + } +} + +/// 
Create a COUNT(*) measure +fn count_star_measure(function_ref: u32) -> Measure { + Measure { + measure: Some(AggregateFunction { + function_reference: function_ref, + arguments: vec![], // COUNT(*) has no arguments + options: vec![], + output_type: None, + phase: 0, + sorts: vec![], + invocation: AggregationInvocation::All as i32, + #[allow(deprecated)] + args: vec![], + }), + filter: None, + } +} + +/// Create a SUM/AVG/MIN/MAX measure on a column +fn simple_agg_measure(function_ref: u32, column_index: i32) -> Measure { + Measure { + measure: Some(AggregateFunction { + function_reference: function_ref, + arguments: vec![FunctionArgument { + arg_type: Some(ArgType::Value(field_ref(column_index))), + }], + options: vec![], + output_type: None, + phase: 0, + sorts: vec![], + invocation: AggregationInvocation::All as i32, + #[allow(deprecated)] + args: vec![], + }), + filter: None, + } +} + +/// Execute aggregate plan and collect results +async fn execute_aggregate( + dataset: &Dataset, + aggregate_bytes: &[u8], +) -> crate::Result> { + let mut scanner = dataset.scan(); + scanner.aggregate_substrait(aggregate_bytes)?; + + let plan = scanner.create_aggregate_plan().await?; + let stream = execute_plan(plan, LanceExecutionOptions::default())?; + stream.try_collect().await.map_err(|e| e.into()) +} + +/// Execute aggregate plan on specific fragments +async fn execute_aggregate_on_fragments( + dataset: &Dataset, + aggregate_bytes: &[u8], + fragments: Vec, +) -> crate::Result> { + let mut scanner = dataset.scan(); + scanner.with_fragments(fragments); + scanner.aggregate_substrait(aggregate_bytes)?; + + let plan = scanner.create_aggregate_plan().await?; + let stream = execute_plan(plan, LanceExecutionOptions::default())?; + stream.try_collect().await.map_err(|e| e.into()) +} + +/// Create a test dataset with numeric columns +async fn create_numeric_dataset(uri: &str, num_fragments: u32, rows_per_fragment: u32) -> Dataset { + gen_batch() + .col("x", array::step::()) + 
.col("y", array::step_custom::(0, 2)) + .col("category", array::cycle::(vec![1, 2, 3])) + .into_dataset( + uri, + FragmentCount::from(num_fragments), + FragmentRowCount::from(rows_per_fragment), + ) + .await + .unwrap() +} + +// ============================================================================ +// COUNT(*) Tests +// ============================================================================ + +#[tokio::test] +async fn test_count_star_single_fragment() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + let ds = create_numeric_dataset(uri, 1, 100).await; + + let agg_bytes = create_aggregate_rel( + vec![count_star_measure(1)], + vec![], + vec![], + vec![agg_extension(1, "count")], + vec![], + ); + + let results = execute_aggregate(&ds, &agg_bytes).await.unwrap(); + assert_eq!(results.len(), 1); + let batch = &results[0]; + assert_eq!(batch.num_rows(), 1); + assert_eq!(batch.column(0).as_primitive::().value(0), 100); +} + +#[tokio::test] +async fn test_count_star_multiple_fragments() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + let ds = create_numeric_dataset(uri, 5, 100).await; + + let agg_bytes = create_aggregate_rel( + vec![count_star_measure(1)], + vec![], + vec![], + vec![agg_extension(1, "count")], + vec![], + ); + + let results = execute_aggregate(&ds, &agg_bytes).await.unwrap(); + assert_eq!(results.len(), 1); + let batch = &results[0]; + assert_eq!(batch.num_rows(), 1); + // 5 fragments * 100 rows = 500 total + assert_eq!(batch.column(0).as_primitive::().value(0), 500); +} + +#[tokio::test] +async fn test_count_star_subset_fragments() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + let ds = create_numeric_dataset(uri, 5, 100).await; + + // Get only first 2 fragments + let all_fragments = ds.get_fragments(); + let subset: Vec = all_fragments + .into_iter() + .take(2) + .map(|f| f.metadata) + .collect(); + + let agg_bytes = 
create_aggregate_rel( + vec![count_star_measure(1)], + vec![], + vec![], + vec![agg_extension(1, "count")], + vec![], + ); + + let results = execute_aggregate_on_fragments(&ds, &agg_bytes, subset) + .await + .unwrap(); + assert_eq!(results.len(), 1); + let batch = &results[0]; + assert_eq!(batch.num_rows(), 1); + // 2 fragments * 100 rows = 200 total + assert_eq!(batch.column(0).as_primitive::().value(0), 200); +} + +// ============================================================================ +// SUM Tests +// ============================================================================ + +#[tokio::test] +async fn test_sum_single_fragment() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + let ds = create_numeric_dataset(uri, 1, 100).await; + + // SUM(x) where x = 0..99 + let agg_bytes = create_aggregate_rel( + vec![simple_agg_measure(1, 0)], // column 0 = x + vec![], + vec![], + vec![agg_extension(1, "sum")], + vec![], + ); + + let results = execute_aggregate(&ds, &agg_bytes).await.unwrap(); + assert_eq!(results.len(), 1); + let batch = &results[0]; + assert_eq!(batch.num_rows(), 1); + // SUM(0..99) = 99*100/2 = 4950 + assert_eq!(batch.column(0).as_primitive::().value(0), 4950); +} + +#[tokio::test] +async fn test_sum_multiple_fragments() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + let ds = create_numeric_dataset(uri, 4, 25).await; + + // SUM(x) where x = 0..99 across 4 fragments + let agg_bytes = create_aggregate_rel( + vec![simple_agg_measure(1, 0)], + vec![], + vec![], + vec![agg_extension(1, "sum")], + vec![], + ); + + let results = execute_aggregate(&ds, &agg_bytes).await.unwrap(); + assert_eq!(results.len(), 1); + let batch = &results[0]; + // SUM(0..99) = 4950 + assert_eq!(batch.column(0).as_primitive::().value(0), 4950); +} + +// ============================================================================ +// MIN/MAX Tests +// 
============================================================================ + +#[tokio::test] +async fn test_min_max() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + let ds = create_numeric_dataset(uri, 4, 25).await; + + // MIN(x) and MAX(x) + let agg_bytes = create_aggregate_rel( + vec![ + simple_agg_measure(1, 0), // MIN(x) + simple_agg_measure(2, 0), // MAX(x) + ], + vec![], + vec![], + vec![agg_extension(1, "min"), agg_extension(2, "max")], + vec![], + ); + + let results = execute_aggregate(&ds, &agg_bytes).await.unwrap(); + assert_eq!(results.len(), 1); + let batch = &results[0]; + assert_eq!(batch.num_rows(), 1); + assert_eq!(batch.num_columns(), 2); + // MIN should be 0, MAX should be 99 + assert_eq!(batch.column(0).as_primitive::().value(0), 0); + assert_eq!(batch.column(1).as_primitive::().value(0), 99); +} + +// ============================================================================ +// AVG Tests +// ============================================================================ + +#[tokio::test] +async fn test_avg() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + let ds = create_numeric_dataset(uri, 4, 25).await; + + // AVG(x) where x = 0..99 + let agg_bytes = create_aggregate_rel( + vec![simple_agg_measure(1, 0)], + vec![], + vec![], + vec![agg_extension(1, "avg")], + vec![], + ); + + let results = execute_aggregate(&ds, &agg_bytes).await.unwrap(); + assert_eq!(results.len(), 1); + let batch = &results[0]; + // AVG(0..99) = 49.5 + let avg = batch.column(0).as_primitive::().value(0); + assert!((avg - 49.5).abs() < 0.001); +} + +// ============================================================================ +// Multiple Aggregates Tests +// ============================================================================ + +#[tokio::test] +async fn test_multiple_aggregates() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + let ds = 
create_numeric_dataset(uri, 4, 25).await; + + // COUNT(*), SUM(x), MIN(x), MAX(x), AVG(x) + let agg_bytes = create_aggregate_rel( + vec![ + count_star_measure(1), + simple_agg_measure(2, 0), // SUM(x) + simple_agg_measure(3, 0), // MIN(x) + simple_agg_measure(4, 0), // MAX(x) + simple_agg_measure(5, 0), // AVG(x) + ], + vec![], + vec![], + vec![ + agg_extension(1, "count"), + agg_extension(2, "sum"), + agg_extension(3, "min"), + agg_extension(4, "max"), + agg_extension(5, "avg"), + ], + vec![], + ); + + let results = execute_aggregate(&ds, &agg_bytes).await.unwrap(); + assert_eq!(results.len(), 1); + let batch = &results[0]; + assert_eq!(batch.num_rows(), 1); + assert_eq!(batch.num_columns(), 5); + + // Verify all aggregates + assert_eq!(batch.column(0).as_primitive::().value(0), 100); // COUNT + assert_eq!(batch.column(1).as_primitive::().value(0), 4950); // SUM + assert_eq!(batch.column(2).as_primitive::().value(0), 0); // MIN + assert_eq!(batch.column(3).as_primitive::().value(0), 99); // MAX + let avg = batch.column(4).as_primitive::().value(0); + assert!((avg - 49.5).abs() < 0.001); // AVG +} + +// ============================================================================ +// GROUP BY Tests +// ============================================================================ + +#[tokio::test] +async fn test_group_by_with_count() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + let ds = create_numeric_dataset(uri, 4, 30).await; + + // COUNT(*) GROUP BY category + // category cycles through 1, 2, 3 + let agg_bytes = create_aggregate_rel( + vec![count_star_measure(1)], + vec![field_ref(2)], // category is column index 2 + vec![Grouping { + #[allow(deprecated)] + grouping_expressions: vec![], + expression_references: vec![0], // Reference to first grouping expression + }], + vec![agg_extension(1, "count")], + vec![], + ); + + let results = execute_aggregate(&ds, &agg_bytes).await.unwrap(); + assert!(!results.is_empty()); + + let 
batch = arrow::compute::concat_batches(&results[0].schema(), &results).unwrap(); + assert_eq!(batch.num_rows(), 3); // 3 categories + + // Each category should have 40 rows (120 total / 3 categories) + let counts: Vec = batch + .column(1) // count column + .as_primitive::() + .values() + .to_vec(); + + for count in counts { + assert_eq!(count, 40); + } +} + +#[tokio::test] +async fn test_group_by_with_sum() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + let ds = create_numeric_dataset(uri, 1, 9).await; + + // SUM(x) GROUP BY category + // x = 0..8, category cycles 1,2,3,1,2,3,1,2,3 + // category 1: sum(0,3,6) = 9 + // category 2: sum(1,4,7) = 12 + // category 3: sum(2,5,8) = 15 + let agg_bytes = create_aggregate_rel( + vec![simple_agg_measure(1, 0)], // SUM(x) + vec![field_ref(2)], // GROUP BY category + vec![Grouping { + #[allow(deprecated)] + grouping_expressions: vec![], + expression_references: vec![0], + }], + vec![agg_extension(1, "sum")], + vec![], + ); + + let results = execute_aggregate(&ds, &agg_bytes).await.unwrap(); + assert!(!results.is_empty()); + + let batch = arrow::compute::concat_batches(&results[0].schema(), &results).unwrap(); + assert_eq!(batch.num_rows(), 3); // 3 categories + + // Collect results into a map for verification + let categories: Vec = batch + .column(0) // category column + .as_primitive::() + .values() + .to_vec(); + let sums: Vec = batch + .column(1) // sum column + .as_primitive::() + .values() + .to_vec(); + + let mut results_map = std::collections::HashMap::new(); + for (cat, sum) in categories.iter().zip(sums.iter()) { + results_map.insert(*cat, *sum); + } + + assert_eq!(results_map.get(&1), Some(&9)); + assert_eq!(results_map.get(&2), Some(&12)); + assert_eq!(results_map.get(&3), Some(&15)); +} + +// ============================================================================ +// Fragment Subset with Aggregates Tests +// 
============================================================================ + +#[tokio::test] +async fn test_aggregate_specific_fragments() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + let ds = create_numeric_dataset(uri, 10, 10).await; + + // Get fragments 3, 5, 7 (0-indexed) + let all_fragments = ds.get_fragments(); + let subset: Vec = all_fragments + .into_iter() + .enumerate() + .filter(|(i, _)| *i == 3 || *i == 5 || *i == 7) + .map(|(_, f)| f.metadata) + .collect(); + + let agg_bytes = create_aggregate_rel( + vec![count_star_measure(1)], + vec![], + vec![], + vec![agg_extension(1, "count")], + vec![], + ); + + let results = execute_aggregate_on_fragments(&ds, &agg_bytes, subset) + .await + .unwrap(); + + assert_eq!(results.len(), 1); + let batch = &results[0]; + // 3 fragments * 10 rows = 30 total + assert_eq!(batch.column(0).as_primitive::().value(0), 30); +} + +#[tokio::test] +async fn test_sum_specific_fragments() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + + // Create dataset where each fragment has distinct values + // Fragment 0: x = 0..9 (sum = 45) + // Fragment 1: x = 10..19 (sum = 145) + // Fragment 2: x = 20..29 (sum = 245) + // Fragment 3: x = 30..39 (sum = 345) + let ds = create_numeric_dataset(uri, 4, 10).await; + + // Only scan fragments 1 and 2 + let all_fragments = ds.get_fragments(); + let subset: Vec = all_fragments + .into_iter() + .enumerate() + .filter(|(i, _)| *i == 1 || *i == 2) + .map(|(_, f)| f.metadata) + .collect(); + + let agg_bytes = create_aggregate_rel( + vec![simple_agg_measure(1, 0)], // SUM(x) + vec![], + vec![], + vec![agg_extension(1, "sum")], + vec![], + ); + + let results = execute_aggregate_on_fragments(&ds, &agg_bytes, subset) + .await + .unwrap(); + + assert_eq!(results.len(), 1); + let batch = &results[0]; + // Fragment 1: sum(10..19) = 145 + // Fragment 2: sum(20..29) = 245 + // Total = 390 + 
assert_eq!(batch.column(0).as_primitive::().value(0), 390); +} + +// ============================================================================ +// Edge Cases +// ============================================================================ + +#[tokio::test] +async fn test_aggregate_empty_result() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + let ds = create_numeric_dataset(uri, 1, 100).await; + + // Apply filter that matches no rows, then aggregate + let mut scanner = ds.scan(); + scanner.project::<&str>(&[]).unwrap(); + scanner.with_row_id(); + scanner.filter("x > 1000").unwrap(); // No rows match + + let agg_bytes = create_aggregate_rel( + vec![count_star_measure(1)], + vec![], + vec![], + vec![agg_extension(1, "count")], + vec![], + ); + scanner.aggregate_substrait(&agg_bytes).unwrap(); + + let plan = scanner.create_aggregate_plan().await.unwrap(); + let stream = execute_plan(plan, LanceExecutionOptions::default()).unwrap(); + let results: Vec = stream.try_collect().await.unwrap(); + + assert_eq!(results.len(), 1); + let batch = &results[0]; + assert_eq!(batch.num_rows(), 1); + // COUNT(*) of empty result should be 0 + assert_eq!(batch.column(0).as_primitive::().value(0), 0); +} + +#[tokio::test] +async fn test_aggregate_single_row() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + + // Create dataset with single row using Int64 to avoid type coercion issues + let schema = Arc::new(ArrowSchema::new(vec![Field::new( + "x", + DataType::Int64, + false, + )])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(arrow_array::Int64Array::from(vec![42]))], + ) + .unwrap(); + + let reader = RecordBatchIterator::new(vec![Ok(batch)], schema); + let ds = Dataset::write(reader, uri, None).await.unwrap(); + + let agg_bytes = create_aggregate_rel( + vec![ + count_star_measure(1), + simple_agg_measure(2, 0), // SUM(x) + simple_agg_measure(3, 0), // MIN(x) + simple_agg_measure(4, 
0), // MAX(x) + ], + vec![], + vec![], + vec![ + agg_extension(1, "count"), + agg_extension(2, "sum"), + agg_extension(3, "min"), + agg_extension(4, "max"), + ], + vec![], + ); + + let results = execute_aggregate(&ds, &agg_bytes).await.unwrap(); + assert_eq!(results.len(), 1); + let batch = &results[0]; + + assert_eq!(batch.column(0).as_primitive::().value(0), 1); // COUNT + assert_eq!(batch.column(1).as_primitive::().value(0), 42); // SUM + assert_eq!(batch.column(2).as_primitive::().value(0), 42); // MIN + assert_eq!(batch.column(3).as_primitive::().value(0), 42); // MAX +} + +// ============================================================================ +// Output Schema / Alias Tests +// ============================================================================ + +#[tokio::test] +async fn test_aggregate_with_aliases() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + let ds = create_numeric_dataset(uri, 1, 100).await; + + // COUNT(*), SUM(x), MIN(x) with custom aliases + let agg_bytes = create_aggregate_rel( + vec![ + count_star_measure(1), + simple_agg_measure(2, 0), + simple_agg_measure(3, 0), + ], + vec![], + vec![], + vec![ + agg_extension(1, "count"), + agg_extension(2, "sum"), + agg_extension(3, "min"), + ], + vec![ + "total_count".to_string(), + "sum_of_x".to_string(), + "min_x".to_string(), + ], + ); + + let results = execute_aggregate(&ds, &agg_bytes).await.unwrap(); + assert_eq!(results.len(), 1); + let batch = &results[0]; + + // Verify output schema has the expected aliases + let schema = batch.schema(); + assert_eq!(schema.fields().len(), 3); + assert_eq!(schema.field(0).name(), "total_count"); + assert_eq!(schema.field(1).name(), "sum_of_x"); + assert_eq!(schema.field(2).name(), "min_x"); + + // Verify values are correct + assert_eq!(batch.column(0).as_primitive::().value(0), 100); + assert_eq!(batch.column(1).as_primitive::().value(0), 4950); + assert_eq!(batch.column(2).as_primitive::().value(0), 0); +} + 
+#[tokio::test] +async fn test_group_by_with_aliases() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + let ds = create_numeric_dataset(uri, 1, 9).await; + + // SUM(x) GROUP BY category with aliases + let agg_bytes = create_aggregate_rel( + vec![simple_agg_measure(1, 0)], + vec![field_ref(2)], + vec![Grouping { + #[allow(deprecated)] + grouping_expressions: vec![], + expression_references: vec![0], + }], + vec![agg_extension(1, "sum")], + vec!["group_key".to_string(), "total_sum".to_string()], + ); + + let results = execute_aggregate(&ds, &agg_bytes).await.unwrap(); + assert!(!results.is_empty()); + + let batch = arrow::compute::concat_batches(&results[0].schema(), &results).unwrap(); + + // Verify output schema has the expected aliases + let schema = batch.schema(); + assert_eq!(schema.fields().len(), 2); + assert_eq!(schema.field(0).name(), "group_key"); + assert_eq!(schema.field(1).name(), "total_sum"); +} diff --git a/rust/lance/src/dataset/tests/mod.rs b/rust/lance/src/dataset/tests/mod.rs index a1197dcc198..c0a6002debb 100644 --- a/rust/lance/src/dataset/tests/mod.rs +++ b/rust/lance/src/dataset/tests/mod.rs @@ -1,6 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors +#[cfg(feature = "substrait")] +mod dataset_aggregate; mod dataset_common; mod dataset_concurrency_store; mod dataset_geo; From 8da6281302f62466319b0dea751d84c9edb15521 Mon Sep 17 00:00:00 2001 From: Jack Ye Date: Sun, 8 Feb 2026 23:21:18 -0800 Subject: [PATCH 2/9] refactor(lance): rename AggregateSpec to AggregateExpr and consolidate API MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Rename AggregateSpec to AggregateExpr for consistency - Add helper constructors: substrait() and datafusion() - Combine aggregate_substrait() and aggregate_expr() into single aggregate() method - Update tests to use new API 🤖 Generated with [Claude Code](https://claude.com/claude-code) 
Co-Authored-By: Claude Opus 4.5 --- rust/lance/src/dataset/scanner.rs | 75 ++++++++----------- .../src/dataset/tests/dataset_aggregate.rs | 9 ++- 2 files changed, 35 insertions(+), 49 deletions(-) diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index b217b968462..edb244d1d83 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -462,9 +462,9 @@ impl ExprFilter { } } -/// Aggregate specification from Substrait or DataFusion expressions. +/// Aggregate expression from Substrait or DataFusion. #[derive(Debug, Clone)] -pub enum AggregateSpec { +pub enum AggregateExpr { #[cfg(feature = "substrait")] Substrait(Vec), Datafusion { @@ -474,10 +474,28 @@ pub enum AggregateSpec { }, } -#[cfg(feature = "substrait")] -impl AggregateSpec { - /// Converts to DataFusion expressions. - pub fn to_datafusion( +impl AggregateExpr { + /// Create from Substrait Plan bytes. + #[cfg(feature = "substrait")] + pub fn substrait(bytes: impl Into>) -> Self { + Self::Substrait(bytes.into()) + } + + /// Create from DataFusion expressions. + pub fn datafusion( + group_by: Vec, + aggregates: Vec, + output_names: Vec, + ) -> Self { + Self::Datafusion { + group_by, + aggregates, + output_names, + } + } + + #[cfg(feature = "substrait")] + fn to_aggregate( &self, schema: Arc, ) -> Result { @@ -504,23 +522,6 @@ impl AggregateSpec { } } -#[cfg(not(feature = "substrait"))] -impl AggregateSpec { - /// Converts the aggregate specification to DataFusion expressions - pub fn to_datafusion( - &self, - _schema: Arc, - ) -> Result<(Vec, Vec, Vec)> { - match self { - Self::Datafusion { - group_by, - aggregates, - output_names, - } => Ok((group_by.clone(), aggregates.clone(), output_names.clone())), - } - } -} - /// Dataset Scanner /// /// ```rust,ignore @@ -629,7 +630,7 @@ pub struct Scanner { /// File reader options to use when reading data files. 
file_reader_options: Option, - aggregate: Option, + aggregate: Option, // Legacy fields to help migrate some old projection behavior to new behavior // @@ -1079,26 +1080,10 @@ impl Scanner { self } - /// Set aggregation using Substrait Plan bytes containing an AggregateRel. - #[cfg(feature = "substrait")] - pub fn aggregate_substrait(&mut self, aggregate_rel: &[u8]) -> Result<&mut Self> { - self.aggregate = Some(AggregateSpec::Substrait(aggregate_rel.to_vec())); - Ok(self) - } - - /// Set aggregation using DataFusion expressions. - pub fn aggregate_expr( - &mut self, - group_by: Vec, - aggregates: Vec, - output_names: Vec, - ) -> Result<&mut Self> { - self.aggregate = Some(AggregateSpec::Datafusion { - group_by, - aggregates, - output_names, - }); - Ok(self) + /// Set aggregation. + pub fn aggregate(&mut self, aggregate: AggregateExpr) -> &mut Self { + self.aggregate = Some(aggregate); + self } /// Set the batch size. @@ -1868,7 +1853,7 @@ impl Scanner { let plan = self.create_plan().await?; let schema = plan.schema(); - let agg = agg_spec.to_datafusion(schema.clone())?; + let agg = agg_spec.to_aggregate(schema.clone())?; let df_schema = DFSchema::try_from(schema.as_ref().clone())?; let num_groups = agg.group_by.len(); diff --git a/rust/lance/src/dataset/tests/dataset_aggregate.rs b/rust/lance/src/dataset/tests/dataset_aggregate.rs index 4135a3b7ee4..42ec2ab5eb5 100644 --- a/rust/lance/src/dataset/tests/dataset_aggregate.rs +++ b/rust/lance/src/dataset/tests/dataset_aggregate.rs @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors -//! Tests for aggregate pushdown via Substrait +//! 
Tests for Substrait aggregate use std::sync::Arc; @@ -33,6 +33,7 @@ use lance_table::format::Fragment; use prost::Message; use tempfile::tempdir; +use crate::dataset::scanner::AggregateExpr; use crate::utils::test::{DatagenExt, FragmentCount, FragmentRowCount}; use crate::Dataset; @@ -170,7 +171,7 @@ async fn execute_aggregate( aggregate_bytes: &[u8], ) -> crate::Result> { let mut scanner = dataset.scan(); - scanner.aggregate_substrait(aggregate_bytes)?; + scanner.aggregate(AggregateExpr::substrait(aggregate_bytes)); let plan = scanner.create_aggregate_plan().await?; let stream = execute_plan(plan, LanceExecutionOptions::default())?; @@ -185,7 +186,7 @@ async fn execute_aggregate_on_fragments( ) -> crate::Result> { let mut scanner = dataset.scan(); scanner.with_fragments(fragments); - scanner.aggregate_substrait(aggregate_bytes)?; + scanner.aggregate(AggregateExpr::substrait(aggregate_bytes)); let plan = scanner.create_aggregate_plan().await?; let stream = execute_plan(plan, LanceExecutionOptions::default())?; @@ -634,7 +635,7 @@ async fn test_aggregate_empty_result() { vec![agg_extension(1, "count")], vec![], ); - scanner.aggregate_substrait(&agg_bytes).unwrap(); + scanner.aggregate(AggregateExpr::substrait(agg_bytes)); let plan = scanner.create_aggregate_plan().await.unwrap(); let stream = execute_plan(plan, LanceExecutionOptions::default()).unwrap(); From 74d0aaca586b30b173e813829edfda3adbc5e331 Mon Sep 17 00:00:00 2001 From: Jack Ye Date: Sun, 8 Feb 2026 23:34:43 -0800 Subject: [PATCH 3/9] chore: remove section separator comments from aggregate tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- .../src/dataset/tests/dataset_aggregate.rs | 36 ------------------- 1 file changed, 36 deletions(-) diff --git a/rust/lance/src/dataset/tests/dataset_aggregate.rs b/rust/lance/src/dataset/tests/dataset_aggregate.rs index 
42ec2ab5eb5..f666342b63b 100644 --- a/rust/lance/src/dataset/tests/dataset_aggregate.rs +++ b/rust/lance/src/dataset/tests/dataset_aggregate.rs @@ -208,10 +208,6 @@ async fn create_numeric_dataset(uri: &str, num_fragments: u32, rows_per_fragment .unwrap() } -// ============================================================================ -// COUNT(*) Tests -// ============================================================================ - #[tokio::test] async fn test_count_star_single_fragment() { let tmp_dir = tempdir().unwrap(); @@ -287,10 +283,6 @@ async fn test_count_star_subset_fragments() { assert_eq!(batch.column(0).as_primitive::().value(0), 200); } -// ============================================================================ -// SUM Tests -// ============================================================================ - #[tokio::test] async fn test_sum_single_fragment() { let tmp_dir = tempdir().unwrap(); @@ -336,10 +328,6 @@ async fn test_sum_multiple_fragments() { assert_eq!(batch.column(0).as_primitive::().value(0), 4950); } -// ============================================================================ -// MIN/MAX Tests -// ============================================================================ - #[tokio::test] async fn test_min_max() { let tmp_dir = tempdir().unwrap(); @@ -368,10 +356,6 @@ async fn test_min_max() { assert_eq!(batch.column(1).as_primitive::().value(0), 99); } -// ============================================================================ -// AVG Tests -// ============================================================================ - #[tokio::test] async fn test_avg() { let tmp_dir = tempdir().unwrap(); @@ -395,10 +379,6 @@ async fn test_avg() { assert!((avg - 49.5).abs() < 0.001); } -// ============================================================================ -// Multiple Aggregates Tests -// ============================================================================ - #[tokio::test] async fn test_multiple_aggregates() { 
let tmp_dir = tempdir().unwrap(); @@ -441,10 +421,6 @@ async fn test_multiple_aggregates() { assert!((avg - 49.5).abs() < 0.001); // AVG } -// ============================================================================ -// GROUP BY Tests -// ============================================================================ - #[tokio::test] async fn test_group_by_with_count() { let tmp_dir = tempdir().unwrap(); @@ -534,10 +510,6 @@ async fn test_group_by_with_sum() { assert_eq!(results_map.get(&3), Some(&15)); } -// ============================================================================ -// Fragment Subset with Aggregates Tests -// ============================================================================ - #[tokio::test] async fn test_aggregate_specific_fragments() { let tmp_dir = tempdir().unwrap(); @@ -612,10 +584,6 @@ async fn test_sum_specific_fragments() { assert_eq!(batch.column(0).as_primitive::().value(0), 390); } -// ============================================================================ -// Edge Cases -// ============================================================================ - #[tokio::test] async fn test_aggregate_empty_result() { let tmp_dir = tempdir().unwrap(); @@ -697,10 +665,6 @@ async fn test_aggregate_single_row() { assert_eq!(batch.column(3).as_primitive::().value(0), 42); // MAX } -// ============================================================================ -// Output Schema / Alias Tests -// ============================================================================ - #[tokio::test] async fn test_aggregate_with_aliases() { let tmp_dir = tempdir().unwrap(); From 81a0841b02ebcf7a08eaea9445aea1079fbd7133 Mon Sep 17 00:00:00 2001 From: Jack Ye Date: Sun, 8 Feb 2026 23:35:54 -0800 Subject: [PATCH 4/9] test(lance): add test for aggregate with filter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 
--- .../src/dataset/tests/dataset_aggregate.rs | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/rust/lance/src/dataset/tests/dataset_aggregate.rs b/rust/lance/src/dataset/tests/dataset_aggregate.rs index f666342b63b..e7647c573e6 100644 --- a/rust/lance/src/dataset/tests/dataset_aggregate.rs +++ b/rust/lance/src/dataset/tests/dataset_aggregate.rs @@ -584,6 +584,50 @@ async fn test_sum_specific_fragments() { assert_eq!(batch.column(0).as_primitive::().value(0), 390); } +#[tokio::test] +async fn test_aggregate_with_filter() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + let ds = create_numeric_dataset(uri, 1, 100).await; + + let mut scanner = ds.scan(); + scanner.filter("x >= 50").unwrap(); + + let agg_bytes = create_aggregate_rel( + vec![ + count_star_measure(1), + simple_agg_measure(2, 0), // SUM(x) + simple_agg_measure(3, 0), // MIN(x) + simple_agg_measure(4, 0), // MAX(x) + ], + vec![], + vec![], + vec![ + agg_extension(1, "count"), + agg_extension(2, "sum"), + agg_extension(3, "min"), + agg_extension(4, "max"), + ], + vec![], + ); + scanner.aggregate(AggregateExpr::substrait(agg_bytes)); + + let plan = scanner.create_aggregate_plan().await.unwrap(); + let stream = execute_plan(plan, LanceExecutionOptions::default()).unwrap(); + let results: Vec = stream.try_collect().await.unwrap(); + + assert_eq!(results.len(), 1); + let batch = &results[0]; + assert_eq!(batch.num_rows(), 1); + + // Filter x >= 50 matches rows 50..99 (50 rows) + assert_eq!(batch.column(0).as_primitive::().value(0), 50); // COUNT + // SUM(50..99) = (50+99)*50/2 = 3725 + assert_eq!(batch.column(1).as_primitive::().value(0), 3725); // SUM + assert_eq!(batch.column(2).as_primitive::().value(0), 50); // MIN + assert_eq!(batch.column(3).as_primitive::().value(0), 99); // MAX +} + #[tokio::test] async fn test_aggregate_empty_result() { let tmp_dir = tempdir().unwrap(); From fa6572591958caffe1b8987f310c128e5dcb1c1f Mon Sep 17 00:00:00 2001 
From: Jack Ye Date: Sun, 8 Feb 2026 23:45:26 -0800 Subject: [PATCH 5/9] refactor: move Aggregate struct to non-feature-gated module MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Create lance-datafusion/src/aggregate.rs for Aggregate struct - Remove #[cfg(feature = "substrait")] from create_aggregate_plan - create_aggregate_plan now works without substrait feature 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- rust/lance-datafusion/src/aggregate.rs | 15 +++++++++++++++ rust/lance-datafusion/src/lib.rs | 1 + rust/lance-datafusion/src/substrait.rs | 11 ++--------- rust/lance/src/dataset/scanner.rs | 18 ++++++++---------- .../src/dataset/tests/dataset_aggregate.rs | 2 +- 5 files changed, 27 insertions(+), 20 deletions(-) create mode 100644 rust/lance-datafusion/src/aggregate.rs diff --git a/rust/lance-datafusion/src/aggregate.rs b/rust/lance-datafusion/src/aggregate.rs new file mode 100644 index 00000000000..4c718d31ec7 --- /dev/null +++ b/rust/lance-datafusion/src/aggregate.rs @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Aggregate specification for DataFusion aggregates. + +use datafusion::logical_expr::Expr; + +/// Aggregate specification with group by and aggregate expressions. +#[derive(Debug, Clone)] +pub struct Aggregate { + pub group_by: Vec, + pub aggregates: Vec, + /// Output column names in order: group_by columns first, then aggregates. 
+ pub output_names: Vec, +} diff --git a/rust/lance-datafusion/src/lib.rs b/rust/lance-datafusion/src/lib.rs index fa65a918191..0ef51216ae8 100644 --- a/rust/lance-datafusion/src/lib.rs +++ b/rust/lance-datafusion/src/lib.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors +pub mod aggregate; pub mod chunker; pub mod dataframe; pub mod datagen; diff --git a/rust/lance-datafusion/src/substrait.rs b/rust/lance-datafusion/src/substrait.rs index 09630c396a6..d45df030a90 100644 --- a/rust/lance-datafusion/src/substrait.rs +++ b/rust/lance-datafusion/src/substrait.rs @@ -3,6 +3,8 @@ use arrow_schema::Schema as ArrowSchema; use datafusion::{execution::SessionState, logical_expr::Expr}; + +use crate::aggregate::Aggregate; use datafusion_common::DFSchema; use datafusion_substrait::extensions::Extensions; use datafusion_substrait::logical_plan::consumer::{ @@ -330,15 +332,6 @@ pub async fn parse_substrait( Ok(expr_container.exprs.pop().unwrap().0) } -/// Aggregate specification with group by and aggregate expressions. -#[derive(Debug, Clone)] -pub struct Aggregate { - pub group_by: Vec, - pub aggregates: Vec, - /// Output column names in order: group_by columns first, then aggregates. - pub output_names: Vec, -} - /// Parse Substrait Plan bytes containing an AggregateRel. 
pub async fn parse_substrait_aggregate( bytes: &[u8], diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index edb244d1d83..bd5f328d57f 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -53,6 +53,7 @@ use lance_core::utils::address::RowAddress; use lance_core::utils::mask::{RowAddrMask, RowAddrTreeMap}; use lance_core::utils::tokio::get_num_compute_intensive_cpus; use lance_core::{ROW_ADDR, ROW_ID, ROW_OFFSET}; +use lance_datafusion::aggregate::Aggregate; use lance_datafusion::exec::{ analyze_plan, execute_plan, LanceExecutionOptions, OneShotExec, StrictBatchSizeExec, }; @@ -494,16 +495,16 @@ impl AggregateExpr { } } - #[cfg(feature = "substrait")] fn to_aggregate( &self, - schema: Arc, - ) -> Result { - use lance_datafusion::exec::{get_session_context, LanceExecutionOptions}; - use lance_datafusion::substrait::{parse_substrait_aggregate, Aggregate}; - + #[allow(unused_variables)] schema: Arc, + ) -> Result { match self { + #[cfg(feature = "substrait")] Self::Substrait(bytes) => { + use lance_datafusion::exec::{get_session_context, LanceExecutionOptions}; + use lance_datafusion::substrait::parse_substrait_aggregate; + let ctx = get_session_context(&LanceExecutionOptions::default()); parse_substrait_aggregate(bytes, schema, &ctx.state()) .now_or_never() @@ -1838,8 +1839,7 @@ impl Scanner { /// Create an execution plan with aggregation. /// - /// Requires `aggregate_substrait()` or `aggregate_expr()` to be called first. - #[cfg(feature = "substrait")] + /// Requires `aggregate()` to be called first. pub fn create_aggregate_plan(&self) -> BoxFuture<'_, Result>> { use datafusion_physical_expr::aggregate::AggregateFunctionExpr; @@ -1902,7 +1902,6 @@ impl Scanner { .boxed() } - #[cfg(feature = "substrait")] fn build_physical_aggregate_expr_with_alias( &self, expr: &Expr, @@ -1930,7 +1929,6 @@ impl Scanner { } /// Apply type coercion to aggregate arguments for UserDefined signature functions. 
- #[cfg(feature = "substrait")] fn coerce_aggregate_expr(&self, expr: &Expr, schema: &DFSchema) -> Result { use datafusion::logical_expr::{expr::AggregateFunction, Expr, TypeSignature}; diff --git a/rust/lance/src/dataset/tests/dataset_aggregate.rs b/rust/lance/src/dataset/tests/dataset_aggregate.rs index e7647c573e6..16fa7c126c5 100644 --- a/rust/lance/src/dataset/tests/dataset_aggregate.rs +++ b/rust/lance/src/dataset/tests/dataset_aggregate.rs @@ -622,7 +622,7 @@ async fn test_aggregate_with_filter() { // Filter x >= 50 matches rows 50..99 (50 rows) assert_eq!(batch.column(0).as_primitive::().value(0), 50); // COUNT - // SUM(50..99) = (50+99)*50/2 = 3725 + // SUM(50..99) = (50+99)*50/2 = 3725 assert_eq!(batch.column(1).as_primitive::().value(0), 3725); // SUM assert_eq!(batch.column(2).as_primitive::().value(0), 50); // MIN assert_eq!(batch.column(3).as_primitive::().value(0), 99); // MAX From f720e9917dfb0c885755d617905eb3f2a0f2edd2 Mon Sep 17 00:00:00 2001 From: Jack Ye Date: Sun, 8 Feb 2026 23:48:40 -0800 Subject: [PATCH 6/9] feat: add sorts support for aggregate functions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Parse agg_func.sorts for ordered aggregates like ARRAY_AGG. 
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- rust/lance-datafusion/src/substrait.rs | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/rust/lance-datafusion/src/substrait.rs b/rust/lance-datafusion/src/substrait.rs index d45df030a90..8c172ba4e59 100644 --- a/rust/lance-datafusion/src/substrait.rs +++ b/rust/lance-datafusion/src/substrait.rs @@ -8,7 +8,7 @@ use crate::aggregate::Aggregate; use datafusion_common::DFSchema; use datafusion_substrait::extensions::Extensions; use datafusion_substrait::logical_plan::consumer::{ - from_substrait_agg_func, from_substrait_rex, DefaultSubstraitConsumer, + from_substrait_agg_func, from_substrait_rex, from_substrait_sorts, DefaultSubstraitConsumer, }; use datafusion_substrait::substrait::proto::{ expression::{ @@ -495,7 +495,14 @@ async fn parse_measures( }; // Parse ordering (for ordered aggregates like ARRAY_AGG) - let order_by = Vec::new(); // TODO: parse agg_func.sorts if needed + let order_by = from_substrait_sorts(consumer, &agg_func.sorts, schema) + .await + .map_err(|e| { + Error::invalid_input( + format!("Failed to parse aggregate sorts: {}", e), + location!(), + ) + })?; // Check for DISTINCT invocation let distinct = matches!( From c64a1b65f84386997651ec1cee63539846c608f2 Mon Sep 17 00:00:00 2001 From: Jack Ye Date: Sun, 8 Feb 2026 23:51:42 -0800 Subject: [PATCH 7/9] test(lance): add tests for ordered aggregates with sorts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add tests for FIRST_VALUE with ORDER BY ASC and DESC to verify the sorts parsing works correctly. 
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- .../src/dataset/tests/dataset_aggregate.rs | 138 +++++++++++++++++- 1 file changed, 137 insertions(+), 1 deletion(-) diff --git a/rust/lance/src/dataset/tests/dataset_aggregate.rs b/rust/lance/src/dataset/tests/dataset_aggregate.rs index 16fa7c126c5..03afd44c94f 100644 --- a/rust/lance/src/dataset/tests/dataset_aggregate.rs +++ b/rust/lance/src/dataset/tests/dataset_aggregate.rs @@ -23,8 +23,9 @@ use datafusion_substrait::substrait::proto::{ }, function_argument::ArgType, rel::RelType, + sort_field::SortKind, AggregateFunction, AggregateRel, Expression, FunctionArgument, Plan, PlanRel, Rel, RelRoot, - Version, + SortField, Version, }; use futures::TryStreamExt; use lance_datafusion::exec::{execute_plan, LanceExecutionOptions}; @@ -165,6 +166,42 @@ fn simple_agg_measure(function_ref: u32, column_index: i32) -> Measure { } } +/// Create an ordered aggregate measure (e.g., FIRST_VALUE with ORDER BY) +fn ordered_agg_measure( + function_ref: u32, + column_index: i32, + sort_column_index: i32, + ascending: bool, +) -> Measure { + use datafusion_substrait::substrait::proto::sort_field::SortDirection; + + let sort_direction = if ascending { + SortDirection::AscNullsLast + } else { + SortDirection::DescNullsLast + }; + + Measure { + measure: Some(AggregateFunction { + function_reference: function_ref, + arguments: vec![FunctionArgument { + arg_type: Some(ArgType::Value(field_ref(column_index))), + }], + options: vec![], + output_type: None, + phase: 0, + sorts: vec![SortField { + expr: Some(field_ref(sort_column_index)), + sort_kind: Some(SortKind::Direction(sort_direction as i32)), + }], + invocation: AggregationInvocation::All as i32, + #[allow(deprecated)] + args: vec![], + }), + filter: None, + } +} + /// Execute aggregate plan and collect results async fn execute_aggregate( dataset: &Dataset, @@ -783,3 +820,102 @@ async fn test_group_by_with_aliases() { 
assert_eq!(schema.field(0).name(), "group_key"); assert_eq!(schema.field(1).name(), "total_sum"); } + +#[tokio::test] +async fn test_first_value_with_order_by() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + let ds = create_numeric_dataset(uri, 1, 9).await; + + // FIRST_VALUE(x) ORDER BY x ASC GROUP BY category + // x = 0..8, category cycles 1,2,3,1,2,3,1,2,3 + // category 1 has x values: 0, 3, 6 -> first_value(ORDER BY x ASC) = 0 + // category 2 has x values: 1, 4, 7 -> first_value(ORDER BY x ASC) = 1 + // category 3 has x values: 2, 5, 8 -> first_value(ORDER BY x ASC) = 2 + let agg_bytes = create_aggregate_rel( + vec![ordered_agg_measure(1, 0, 0, true)], // FIRST_VALUE(x) ORDER BY x ASC + vec![field_ref(2)], // GROUP BY category + vec![Grouping { + #[allow(deprecated)] + grouping_expressions: vec![], + expression_references: vec![0], + }], + vec![agg_extension(1, "first_value")], + vec![], + ); + + let results = execute_aggregate(&ds, &agg_bytes).await.unwrap(); + assert!(!results.is_empty()); + + let batch = arrow::compute::concat_batches(&results[0].schema(), &results).unwrap(); + assert_eq!(batch.num_rows(), 3); + + let categories: Vec = batch + .column(0) + .as_primitive::() + .values() + .to_vec(); + let first_values: Vec = batch + .column(1) + .as_primitive::() + .values() + .to_vec(); + + let mut results_map = std::collections::HashMap::new(); + for (cat, val) in categories.iter().zip(first_values.iter()) { + results_map.insert(*cat, *val); + } + + assert_eq!(results_map.get(&1), Some(&0)); + assert_eq!(results_map.get(&2), Some(&1)); + assert_eq!(results_map.get(&3), Some(&2)); +} + +#[tokio::test] +async fn test_first_value_with_order_by_desc() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + let ds = create_numeric_dataset(uri, 1, 9).await; + + // FIRST_VALUE(x) ORDER BY x DESC GROUP BY category + // category 1 has x values: 0, 3, 6 -> first_value(ORDER BY x DESC) = 6 + // 
category 2 has x values: 1, 4, 7 -> first_value(ORDER BY x DESC) = 7 + // category 3 has x values: 2, 5, 8 -> first_value(ORDER BY x DESC) = 8 + let agg_bytes = create_aggregate_rel( + vec![ordered_agg_measure(1, 0, 0, false)], // FIRST_VALUE(x) ORDER BY x DESC + vec![field_ref(2)], // GROUP BY category + vec![Grouping { + #[allow(deprecated)] + grouping_expressions: vec![], + expression_references: vec![0], + }], + vec![agg_extension(1, "first_value")], + vec![], + ); + + let results = execute_aggregate(&ds, &agg_bytes).await.unwrap(); + assert!(!results.is_empty()); + + let batch = arrow::compute::concat_batches(&results[0].schema(), &results).unwrap(); + assert_eq!(batch.num_rows(), 3); + + let categories: Vec = batch + .column(0) + .as_primitive::() + .values() + .to_vec(); + let first_values: Vec = batch + .column(1) + .as_primitive::() + .values() + .to_vec(); + + let mut results_map = std::collections::HashMap::new(); + for (cat, val) in categories.iter().zip(first_values.iter()) { + results_map.insert(*cat, *val); + } + + assert_eq!(results_map.get(&1), Some(&6)); + assert_eq!(results_map.get(&2), Some(&7)); + assert_eq!(results_map.get(&3), Some(&8)); +} From f9c5a1546428fe252d725e8f607185e06fb1c74a Mon Sep 17 00:00:00 2001 From: Jack Ye Date: Tue, 10 Feb 2026 23:15:05 -0800 Subject: [PATCH 8/9] address comments --- rust/lance-datafusion/src/aggregate.rs | 5 +- rust/lance-datafusion/src/substrait.rs | 19 +- rust/lance/src/dataset/scanner.rs | 352 ++++++++++++++---- .../src/dataset/tests/dataset_aggregate.rs | 8 +- 4 files changed, 301 insertions(+), 83 deletions(-) diff --git a/rust/lance-datafusion/src/aggregate.rs b/rust/lance-datafusion/src/aggregate.rs index 4c718d31ec7..5528104c044 100644 --- a/rust/lance-datafusion/src/aggregate.rs +++ b/rust/lance-datafusion/src/aggregate.rs @@ -8,8 +8,9 @@ use datafusion::logical_expr::Expr; /// Aggregate specification with group by and aggregate expressions. 
#[derive(Debug, Clone)] pub struct Aggregate { + /// Expressions to group by (e.g., column references). pub group_by: Vec, + /// Aggregate function expressions (e.g., SUM, COUNT, AVG). + /// Use `.alias()` on the expression to set output column names. pub aggregates: Vec, - /// Output column names in order: group_by columns first, then aggregates. - pub output_names: Vec, } diff --git a/rust/lance-datafusion/src/substrait.rs b/rust/lance-datafusion/src/substrait.rs index 8c172ba4e59..2f84a266f65 100644 --- a/rust/lance-datafusion/src/substrait.rs +++ b/rust/lance-datafusion/src/substrait.rs @@ -345,7 +345,23 @@ pub async fn parse_substrait_aggregate( let mut agg = parse_aggregate_rel_with_extensions(&aggregate_rel, input_schema, state, &extensions) .await?; - agg.output_names = output_names; + + // Apply aliases from RelRoot.names to expressions + if !output_names.is_empty() { + let num_groups = agg.group_by.len(); + for (i, expr) in agg.group_by.iter_mut().enumerate() { + if i < output_names.len() { + *expr = expr.clone().alias(&output_names[i]); + } + } + for (i, expr) in agg.aggregates.iter_mut().enumerate() { + let name_idx = num_groups + i; + if name_idx < output_names.len() { + *expr = expr.clone().alias(&output_names[name_idx]); + } + } + } + Ok(agg) } @@ -401,7 +417,6 @@ pub async fn parse_aggregate_rel_with_extensions( Ok(Aggregate { group_by, aggregates, - output_names: vec![], }) } diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index bd5f328d57f..e4cddd542e4 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -34,7 +34,7 @@ use datafusion::scalar::ScalarValue; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::ExprSchemable; use datafusion_functions::core::getfield::GetFieldFunc; -use datafusion_physical_expr::{aggregate::AggregateExprBuilder, expressions::Column}; +use datafusion_physical_expr::expressions::Column; use 
datafusion_physical_expr::{create_physical_expr, LexOrdering, Partitioning, PhysicalExpr}; use datafusion_physical_plan::joins::PartitionMode; use datafusion_physical_plan::projection::ProjectionExec; @@ -471,11 +471,26 @@ pub enum AggregateExpr { Datafusion { group_by: Vec, aggregates: Vec, - output_names: Vec, }, } impl AggregateExpr { + /// Create a new builder for aggregate expressions. + /// + /// # Example + /// ```ignore + /// let agg = AggregateExpr::builder() + /// .group_by("category") + /// .count_star().alias("total_count") + /// .sum("amount").alias("total_amount") + /// .avg("price") + /// .build(); + /// scanner.aggregate(agg); + /// ``` + pub fn builder() -> AggregateExprBuilder { + AggregateExprBuilder::new() + } + /// Create from Substrait Plan bytes. #[cfg(feature = "substrait")] pub fn substrait(bytes: impl Into>) -> Self { @@ -483,15 +498,11 @@ impl AggregateExpr { } /// Create from DataFusion expressions. - pub fn datafusion( - group_by: Vec, - aggregates: Vec, - output_names: Vec, - ) -> Self { + /// Use `.alias()` on expressions to set output column names. + pub fn datafusion(group_by: Vec, aggregates: Vec) -> Self { Self::Datafusion { group_by, aggregates, - output_names, } } @@ -513,16 +524,167 @@ impl AggregateExpr { Self::Datafusion { group_by, aggregates, - output_names, } => Ok(Aggregate { group_by: group_by.clone(), aggregates: aggregates.clone(), - output_names: output_names.clone(), }), } } } +/// Builder for creating aggregate expressions without using DataFusion or Substrait directly. +#[derive(Debug, Clone, Default)] +pub struct AggregateExprBuilder { + group_by: Vec, + aggregates: Vec, +} + +impl AggregateExprBuilder { + /// Create a new builder. + pub fn new() -> Self { + Self::default() + } + + /// Add a column to group by. + pub fn group_by(mut self, column: impl Into) -> Self { + self.group_by.push(col(column.into())); + self + } + + /// Add multiple columns to group by. 
+ pub fn group_by_columns( + mut self, + columns: impl IntoIterator>, + ) -> Self { + for column in columns { + self.group_by.push(col(column.into())); + } + self + } + + /// Add COUNT(*) aggregate. + pub fn count_star(self) -> AggregateExprBuilderWithPendingAggregate { + AggregateExprBuilderWithPendingAggregate { + builder: self, + pending: functions_aggregate::count::count(lit(1)), + } + } + + /// Add COUNT(column) aggregate. + pub fn count(self, column: impl Into) -> AggregateExprBuilderWithPendingAggregate { + AggregateExprBuilderWithPendingAggregate { + builder: self, + pending: functions_aggregate::count::count(col(column.into())), + } + } + + /// Add SUM(column) aggregate. + pub fn sum(self, column: impl Into) -> AggregateExprBuilderWithPendingAggregate { + AggregateExprBuilderWithPendingAggregate { + builder: self, + pending: functions_aggregate::sum::sum(col(column.into())), + } + } + + /// Add AVG(column) aggregate. + pub fn avg(self, column: impl Into) -> AggregateExprBuilderWithPendingAggregate { + AggregateExprBuilderWithPendingAggregate { + builder: self, + pending: functions_aggregate::average::avg(col(column.into())), + } + } + + /// Add MIN(column) aggregate. + pub fn min(self, column: impl Into) -> AggregateExprBuilderWithPendingAggregate { + AggregateExprBuilderWithPendingAggregate { + builder: self, + pending: functions_aggregate::min_max::min(col(column.into())), + } + } + + /// Add MAX(column) aggregate. + pub fn max(self, column: impl Into) -> AggregateExprBuilderWithPendingAggregate { + AggregateExprBuilderWithPendingAggregate { + builder: self, + pending: functions_aggregate::min_max::max(col(column.into())), + } + } + + /// Build the aggregate expression. + pub fn build(self) -> AggregateExpr { + AggregateExpr::Datafusion { + group_by: self.group_by, + aggregates: self.aggregates, + } + } +} + +/// Builder state with a pending aggregate that can be aliased. 
+#[derive(Debug, Clone)] +pub struct AggregateExprBuilderWithPendingAggregate { + builder: AggregateExprBuilder, + pending: Expr, +} + +impl AggregateExprBuilderWithPendingAggregate { + /// Set an alias for the pending aggregate. + pub fn alias(mut self, name: impl Into) -> AggregateExprBuilder { + self.builder + .aggregates + .push(self.pending.alias(name.into())); + self.builder + } + + /// Add another group by column. + pub fn group_by(mut self, column: impl Into) -> AggregateExprBuilder { + self.builder.aggregates.push(self.pending); + self.builder.group_by.push(col(column.into())); + self.builder + } + + /// Add COUNT(*) aggregate. + pub fn count_star(mut self) -> Self { + self.builder.aggregates.push(self.pending); + self.builder.count_star() + } + + /// Add COUNT(column) aggregate. + pub fn count(mut self, column: impl Into) -> Self { + self.builder.aggregates.push(self.pending); + self.builder.count(column) + } + + /// Add SUM(column) aggregate. + pub fn sum(mut self, column: impl Into) -> Self { + self.builder.aggregates.push(self.pending); + self.builder.sum(column) + } + + /// Add AVG(column) aggregate. + pub fn avg(mut self, column: impl Into) -> Self { + self.builder.aggregates.push(self.pending); + self.builder.avg(column) + } + + /// Add MIN(column) aggregate. + pub fn min(mut self, column: impl Into) -> Self { + self.builder.aggregates.push(self.pending); + self.builder.min(column) + } + + /// Add MAX(column) aggregate. + pub fn max(mut self, column: impl Into) -> Self { + self.builder.aggregates.push(self.pending); + self.builder.max(column) + } + + /// Build the aggregate expression. 
+ pub fn build(mut self) -> AggregateExpr { + self.builder.aggregates.push(self.pending); + self.builder.build() + } +} + /// Dataset Scanner /// /// ```rust,ignore @@ -1788,7 +1950,10 @@ impl Scanner { let input_phy_exprs: &[Arc] = &[one]; let schema = plan.schema(); - let mut builder = AggregateExprBuilder::new(count_udaf(), input_phy_exprs.to_vec()); + let mut builder = datafusion_physical_expr::aggregate::AggregateExprBuilder::new( + count_udaf(), + input_phy_exprs.to_vec(), + ); builder = builder.schema(schema); builder = builder.alias("count_rows".to_string()); @@ -1840,96 +2005,99 @@ impl Scanner { /// Create an execution plan with aggregation. /// /// Requires `aggregate()` to be called first. + #[deprecated(note = "Use create_plan() instead, which now applies aggregate automatically")] pub fn create_aggregate_plan(&self) -> BoxFuture<'_, Result>> { - use datafusion_physical_expr::aggregate::AggregateFunctionExpr; - async move { - let agg_spec = self.aggregate.as_ref().ok_or_else(|| { - Error::invalid_input( + if self.aggregate.is_none() { + return Err(Error::invalid_input( "create_aggregate_plan called but no aggregate was set", location!(), - ) - })?; + )); + } + // create_plan() now applies aggregate automatically when set + self.create_plan().await + } + .boxed() + } - let plan = self.create_plan().await?; - let schema = plan.schema(); - let agg = agg_spec.to_aggregate(schema.clone())?; - let df_schema = DFSchema::try_from(schema.as_ref().clone())?; - let num_groups = agg.group_by.len(); + async fn apply_aggregate( + &self, + plan: Arc, + agg_spec: &AggregateExpr, + ) -> Result> { + use datafusion_physical_expr::aggregate::AggregateFunctionExpr; - let group_exprs: Vec<(Arc, String)> = agg - .group_by - .iter() - .enumerate() - .map(|(i, expr)| { - let name = if i < agg.output_names.len() { - agg.output_names[i].clone() - } else { - expr.schema_name().to_string() - }; - let physical_expr = - create_physical_expr(expr, &df_schema, 
&ExecutionProps::default())?; - Ok((physical_expr, name)) - }) - .collect::>()?; + let schema = plan.schema(); + let agg = agg_spec.to_aggregate(schema.clone())?; + let df_schema = DFSchema::try_from(schema.as_ref().clone())?; - let aggr_exprs: Vec> = agg - .aggregates - .iter() - .enumerate() - .map(|(i, expr)| { - let output_name_idx = num_groups + i; - let alias = if output_name_idx < agg.output_names.len() { - Some(agg.output_names[output_name_idx].as_str()) - } else { - None - }; - self.build_physical_aggregate_expr_with_alias(expr, &df_schema, &schema, alias) - }) - .collect::>()?; + let group_exprs: Vec<(Arc, String)> = agg + .group_by + .iter() + .map(|expr| { + let name = expr.schema_name().to_string(); + let physical_expr = + create_physical_expr(expr, &df_schema, &ExecutionProps::default())?; + Ok((physical_expr, name)) + }) + .collect::>()?; - let filters: Vec>> = vec![None; aggr_exprs.len()]; + #[allow(clippy::type_complexity)] + let aggr_results: Vec<(Arc, Option>)> = agg + .aggregates + .iter() + .map(|expr| self.build_physical_aggregate_expr(expr, &df_schema, &schema)) + .collect::>()?; - Ok(Arc::new(AggregateExec::try_new( - AggregateMode::Single, - PhysicalGroupBy::new_single(group_exprs), - aggr_exprs, - filters, - plan, - schema, - )?) as Arc) - } - .boxed() + let (aggr_exprs, filters): (Vec<_>, Vec<_>) = aggr_results.into_iter().unzip(); + + Ok(Arc::new(AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new_single(group_exprs), + aggr_exprs, + filters, + plan, + schema, + )?) 
as Arc) } - fn build_physical_aggregate_expr_with_alias( + #[allow(clippy::type_complexity)] + fn build_physical_aggregate_expr( &self, expr: &Expr, df_schema: &DFSchema, input_schema: &SchemaRef, - alias: Option<&str>, - ) -> Result> { + ) -> Result<( + Arc, + Option>, + )> { use datafusion::physical_planner::create_aggregate_expr_and_maybe_filter; let coerced_expr = self.coerce_aggregate_expr(expr, df_schema)?; - let aliased_expr = if let Some(name) = alias { - coerced_expr.alias(name) - } else { - coerced_expr - }; - let (agg_expr, _filter, _order_by) = create_aggregate_expr_and_maybe_filter( - &aliased_expr, + // Note: order_by is already embedded in the AggregateFunctionExpr for ordered aggregates + let (agg_expr, filter, _order_by) = create_aggregate_expr_and_maybe_filter( + &coerced_expr, df_schema, input_schema.as_ref(), &ExecutionProps::default(), )?; - Ok(agg_expr) + Ok((agg_expr, filter)) } /// Apply type coercion to aggregate arguments for UserDefined signature functions. + /// + /// Most aggregate functions (SUM, COUNT, MIN, MAX) have explicit type signatures that + /// DataFusion handles automatically. However, some functions like AVG use UserDefined + /// type signatures in the Substrait consumer, which means DataFusion doesn't know the + /// expected input types and won't perform automatic coercion. We must explicitly coerce + /// arguments to the types returned by `func.coerce_types()`. 
fn coerce_aggregate_expr(&self, expr: &Expr, schema: &DFSchema) -> Result { + Self::coerce_aggregate_expr_impl(expr, schema) + } + + fn coerce_aggregate_expr_impl(expr: &Expr, schema: &DFSchema) -> Result { use datafusion::logical_expr::{expr::AggregateFunction, Expr, TypeSignature}; match expr { @@ -1937,6 +2105,7 @@ impl Scanner { let func = &agg_func.func; let args = &agg_func.params.args; + // Only UserDefined signature functions need explicit coercion if !matches!(func.signature().type_signature, TypeSignature::UserDefined) { return Ok(expr.clone()); } @@ -1973,7 +2142,18 @@ impl Scanner { agg_func.params.null_treatment, ))) } - other => Ok(other.clone()), + Expr::Alias(alias) => { + // Recursively coerce the inner expression and preserve the alias + let coerced_inner = Self::coerce_aggregate_expr_impl(&alias.expr, schema)?; + Ok(coerced_inner.alias(&alias.name)) + } + other => Err(Error::invalid_input( + format!( + "Expected aggregate function expression, got {:?}", + other.variant_name() + ), + location!(), + )), } } @@ -2062,6 +2242,21 @@ impl Scanner { }); } + if self.aggregate.is_some() { + if self.limit.is_some() || self.offset.is_some() { + return Err(Error::InvalidInput { + source: "Cannot use limit/offset with aggregate. Apply limit after aggregation instead.".into(), + location: location!(), + }); + } + if self.ordering.is_some() { + return Err(Error::InvalidInput { + source: "Cannot use order_by with aggregate. 
Apply ordering after aggregation instead.".into(), + location: location!(), + }); + } + } + Ok(()) } @@ -2306,6 +2501,13 @@ impl Scanner { // Stage 2: filter plan = filter_plan.refine_filter(plan, self).await?; + // Stage 2.5: aggregate (if set, applies aggregate and returns early) + if let Some(agg_spec) = &self.aggregate { + // Take columns needed for aggregation + plan = self.take(plan, self.projection_plan.physical_projection.clone())?; + return self.apply_aggregate(plan, agg_spec).await; + } + // Stage 3: sort if let Some(ordering) = &self.ordering { let ordering_columns = ordering.iter().map(|col| &col.column_name); @@ -2957,7 +3159,7 @@ impl Scanner { AggregateMode::Single, PhysicalGroupBy::new_single(group_expr), vec![Arc::new( - AggregateExprBuilder::new( + datafusion_physical_expr::aggregate::AggregateExprBuilder::new( functions_aggregate::min_max::max_udaf(), vec![expressions::col(SCORE_COL, &schema)?], ) diff --git a/rust/lance/src/dataset/tests/dataset_aggregate.rs b/rust/lance/src/dataset/tests/dataset_aggregate.rs index 03afd44c94f..e0b7d512e88 100644 --- a/rust/lance/src/dataset/tests/dataset_aggregate.rs +++ b/rust/lance/src/dataset/tests/dataset_aggregate.rs @@ -210,7 +210,7 @@ async fn execute_aggregate( let mut scanner = dataset.scan(); scanner.aggregate(AggregateExpr::substrait(aggregate_bytes)); - let plan = scanner.create_aggregate_plan().await?; + let plan = scanner.create_plan().await?; let stream = execute_plan(plan, LanceExecutionOptions::default())?; stream.try_collect().await.map_err(|e| e.into()) } @@ -225,7 +225,7 @@ async fn execute_aggregate_on_fragments( scanner.with_fragments(fragments); scanner.aggregate(AggregateExpr::substrait(aggregate_bytes)); - let plan = scanner.create_aggregate_plan().await?; + let plan = scanner.create_plan().await?; let stream = execute_plan(plan, LanceExecutionOptions::default())?; stream.try_collect().await.map_err(|e| e.into()) } @@ -649,7 +649,7 @@ async fn test_aggregate_with_filter() { ); 
scanner.aggregate(AggregateExpr::substrait(agg_bytes)); - let plan = scanner.create_aggregate_plan().await.unwrap(); + let plan = scanner.create_plan().await.unwrap(); let stream = execute_plan(plan, LanceExecutionOptions::default()).unwrap(); let results: Vec = stream.try_collect().await.unwrap(); @@ -686,7 +686,7 @@ async fn test_aggregate_empty_result() { ); scanner.aggregate(AggregateExpr::substrait(agg_bytes)); - let plan = scanner.create_aggregate_plan().await.unwrap(); + let plan = scanner.create_plan().await.unwrap(); let stream = execute_plan(plan, LanceExecutionOptions::default()).unwrap(); let results: Vec = stream.try_collect().await.unwrap(); From 576cdc0674d6d6c09e041f19a3d1b38f4b348b0c Mon Sep 17 00:00:00 2001 From: Jack Ye Date: Thu, 12 Feb 2026 18:19:34 -0800 Subject: [PATCH 9/9] address more comments --- rust/lance/src/dataset/scanner.rs | 222 +++++++++-------- .../src/dataset/tests/dataset_aggregate.rs | 225 +++++++++++++++++- 2 files changed, 334 insertions(+), 113 deletions(-) diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index e4cddd542e4..8923bc03e99 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -487,7 +487,7 @@ impl AggregateExpr { /// .build(); /// scanner.aggregate(agg); /// ``` - pub fn builder() -> AggregateExprBuilder { + pub fn builder() -> AggregateExprBuilder { AggregateExprBuilder::new() } @@ -533,155 +533,150 @@ impl AggregateExpr { } /// Builder for creating aggregate expressions without using DataFusion or Substrait directly. -#[derive(Debug, Clone, Default)] -pub struct AggregateExprBuilder { +/// +/// The const generic `HAS_PENDING` tracks whether there's a pending aggregate that can be aliased. +/// When `HAS_PENDING` is `true`, the last item in `aggregates` is the pending aggregate. 
+#[derive(Debug, Clone)] +pub struct AggregateExprBuilder { group_by: Vec, aggregates: Vec, } -impl AggregateExprBuilder { +impl Default for AggregateExprBuilder { + fn default() -> Self { + Self { + group_by: Vec::new(), + aggregates: Vec::new(), + } + } +} + +impl AggregateExprBuilder { /// Create a new builder. pub fn new() -> Self { Self::default() } + /// Build the aggregate expression. + pub fn build(self) -> AggregateExpr { + AggregateExpr::Datafusion { + group_by: self.group_by, + aggregates: self.aggregates, + } + } +} + +impl AggregateExprBuilder { /// Add a column to group by. - pub fn group_by(mut self, column: impl Into) -> Self { + /// + /// Multiple invocations will add to the list (not replace it). + /// E.g. `.group_by("x").group_by("y")` will group by both `x` and `y`. + pub fn group_by(mut self, column: impl Into) -> AggregateExprBuilder { self.group_by.push(col(column.into())); - self + AggregateExprBuilder { + group_by: self.group_by, + aggregates: self.aggregates, + } } /// Add multiple columns to group by. + /// + /// Multiple invocations will add to the list (not replace it). + /// E.g. `.group_by("x").group_by_columns(["y", "z"])` will group by `x`, `y`, and `z`. pub fn group_by_columns( mut self, columns: impl IntoIterator>, - ) -> Self { + ) -> AggregateExprBuilder { for column in columns { self.group_by.push(col(column.into())); } - self + AggregateExprBuilder { + group_by: self.group_by, + aggregates: self.aggregates, + } } - /// Add COUNT(*) aggregate. - pub fn count_star(self) -> AggregateExprBuilderWithPendingAggregate { - AggregateExprBuilderWithPendingAggregate { - builder: self, - pending: functions_aggregate::count::count(lit(1)), + /// Add COUNT(*) aggregate that counts all rows. + pub fn count_star(mut self) -> AggregateExprBuilder { + self.aggregates + .push(functions_aggregate::count::count(lit(1))); + AggregateExprBuilder { + group_by: self.group_by, + aggregates: self.aggregates, } } /// Add COUNT(column) aggregate. 
- pub fn count(self, column: impl Into) -> AggregateExprBuilderWithPendingAggregate { - AggregateExprBuilderWithPendingAggregate { - builder: self, - pending: functions_aggregate::count::count(col(column.into())), + /// + /// Unlike `count_star`, this will only count the number of rows where `column` + /// is not NULL. + pub fn count(mut self, column: impl Into) -> AggregateExprBuilder { + self.aggregates + .push(functions_aggregate::count::count(col(column.into()))); + AggregateExprBuilder { + group_by: self.group_by, + aggregates: self.aggregates, } } /// Add SUM(column) aggregate. - pub fn sum(self, column: impl Into) -> AggregateExprBuilderWithPendingAggregate { - AggregateExprBuilderWithPendingAggregate { - builder: self, - pending: functions_aggregate::sum::sum(col(column.into())), + pub fn sum(mut self, column: impl Into) -> AggregateExprBuilder { + self.aggregates + .push(functions_aggregate::sum::sum(col(column.into()))); + AggregateExprBuilder { + group_by: self.group_by, + aggregates: self.aggregates, } } /// Add AVG(column) aggregate. - pub fn avg(self, column: impl Into) -> AggregateExprBuilderWithPendingAggregate { - AggregateExprBuilderWithPendingAggregate { - builder: self, - pending: functions_aggregate::average::avg(col(column.into())), + pub fn avg(mut self, column: impl Into) -> AggregateExprBuilder { + self.aggregates + .push(functions_aggregate::average::avg(col(column.into()))); + AggregateExprBuilder { + group_by: self.group_by, + aggregates: self.aggregates, } } /// Add MIN(column) aggregate. 
- pub fn min(self, column: impl Into) -> AggregateExprBuilderWithPendingAggregate { - AggregateExprBuilderWithPendingAggregate { - builder: self, - pending: functions_aggregate::min_max::min(col(column.into())), + pub fn min(mut self, column: impl Into) -> AggregateExprBuilder { + self.aggregates + .push(functions_aggregate::min_max::min(col(column.into()))); + AggregateExprBuilder { + group_by: self.group_by, + aggregates: self.aggregates, } } /// Add MAX(column) aggregate. - pub fn max(self, column: impl Into) -> AggregateExprBuilderWithPendingAggregate { - AggregateExprBuilderWithPendingAggregate { - builder: self, - pending: functions_aggregate::min_max::max(col(column.into())), - } - } - - /// Build the aggregate expression. - pub fn build(self) -> AggregateExpr { - AggregateExpr::Datafusion { + pub fn max(mut self, column: impl Into) -> AggregateExprBuilder { + self.aggregates + .push(functions_aggregate::min_max::max(col(column.into()))); + AggregateExprBuilder { group_by: self.group_by, aggregates: self.aggregates, } } } -/// Builder state with a pending aggregate that can be aliased. -#[derive(Debug, Clone)] -pub struct AggregateExprBuilderWithPendingAggregate { - builder: AggregateExprBuilder, - pending: Expr, -} - -impl AggregateExprBuilderWithPendingAggregate { - /// Set an alias for the pending aggregate. - pub fn alias(mut self, name: impl Into) -> AggregateExprBuilder { - self.builder - .aggregates - .push(self.pending.alias(name.into())); - self.builder - } - - /// Add another group by column. - pub fn group_by(mut self, column: impl Into) -> AggregateExprBuilder { - self.builder.aggregates.push(self.pending); - self.builder.group_by.push(col(column.into())); - self.builder - } - - /// Add COUNT(*) aggregate. - pub fn count_star(mut self) -> Self { - self.builder.aggregates.push(self.pending); - self.builder.count_star() - } - - /// Add COUNT(column) aggregate. 
- pub fn count(mut self, column: impl Into) -> Self { - self.builder.aggregates.push(self.pending); - self.builder.count(column) - } - - /// Add SUM(column) aggregate. - pub fn sum(mut self, column: impl Into) -> Self { - self.builder.aggregates.push(self.pending); - self.builder.sum(column) - } - - /// Add AVG(column) aggregate. - pub fn avg(mut self, column: impl Into) -> Self { - self.builder.aggregates.push(self.pending); - self.builder.avg(column) - } - - /// Add MIN(column) aggregate. - pub fn min(mut self, column: impl Into) -> Self { - self.builder.aggregates.push(self.pending); - self.builder.min(column) - } - - /// Add MAX(column) aggregate. - pub fn max(mut self, column: impl Into) -> Self { - self.builder.aggregates.push(self.pending); - self.builder.max(column) +impl AggregateExprBuilder { + /// Set an alias for the pending aggregate (the last added aggregate). + pub fn alias(mut self, name: impl Into) -> AggregateExprBuilder { + let pending = self.aggregates.pop().expect("pending aggregate must exist"); + self.aggregates.push(pending.alias(name.into())); + AggregateExprBuilder { + group_by: self.group_by, + aggregates: self.aggregates, + } } /// Build the aggregate expression. - pub fn build(mut self) -> AggregateExpr { - self.builder.aggregates.push(self.pending); - self.builder.build() + pub fn build(self) -> AggregateExpr { + AggregateExpr::Datafusion { + group_by: self.group_by, + aggregates: self.aggregates, + } } } @@ -2245,13 +2240,17 @@ impl Scanner { if self.aggregate.is_some() { if self.limit.is_some() || self.offset.is_some() { return Err(Error::InvalidInput { - source: "Cannot use limit/offset with aggregate. Apply limit after aggregation instead.".into(), + source: + "Cannot use limit/offset with aggregate. Apply limit to the result instead." + .into(), location: location!(), }); } if self.ordering.is_some() { return Err(Error::InvalidInput { - source: "Cannot use order_by with aggregate. 
Apply ordering after aggregation instead.".into(), + source: + "Cannot use order_by with aggregate. Apply ordering to the result instead." + .into(), location: location!(), }); } @@ -2413,7 +2412,7 @@ impl Scanner { let mut filter_plan = self.create_filter_plan(use_scalar_index).await?; let mut use_limit_node = true; - // Stage 1: source (either an (K|A)NN search, full text search or or a (full|indexed) scan) + // Source: either a (K|A)NN search, full text search, or a (full|indexed) scan let mut plan: Arc = match (&self.nearest, &self.full_text_query) { (Some(_), None) => self.vector_search_source(&mut filter_plan).await?, (None, Some(query)) => self.fts_search_source(&mut filter_plan, query).await?, @@ -2471,8 +2470,7 @@ impl Scanner { } }; - // Stage 1.5 load columns needed for stages 2 & 3 - // Calculate the schema needed for the filter and ordering. + // Load columns needed for filter and ordering let mut pre_filter_projection = self.dataset.empty_projection(); // We may need to take filter columns if we are going to refine @@ -2498,17 +2496,17 @@ impl Scanner { plan = self.take(plan, pre_filter_projection)?; - // Stage 2: filter + // Filter plan = filter_plan.refine_filter(plan, self).await?; - // Stage 2.5: aggregate (if set, applies aggregate and returns early) + // Aggregate (if set, applies aggregate and returns early) if let Some(agg_spec) = &self.aggregate { // Take columns needed for aggregation plan = self.take(plan, self.projection_plan.physical_projection.clone())?; return self.apply_aggregate(plan, agg_spec).await; } - // Stage 3: sort + // Sort if let Some(ordering) = &self.ordering { let ordering_columns = ordering.iter().map(|col| &col.column_name); let projection_with_ordering = self @@ -2540,25 +2538,25 @@ impl Scanner { )); } - // Stage 4: limit / offset + // Limit / offset if use_limit_node && (self.limit.unwrap_or(0) > 0 || self.offset.is_some()) { plan = self.limit_node(plan); } - // Stage 5: take remaining columns required for projection 
+ // Take remaining columns required for projection plan = self.take(plan, self.projection_plan.physical_projection.clone())?; - // Stage 6: Add system columns, if requested + // Add system columns, if requested if self.projection_plan.must_add_row_offset { plan = Arc::new(AddRowOffsetExec::try_new(plan, self.dataset.clone()).await?); } - // Stage 7: final projection + // Final projection let final_projection = self.calculate_final_projection(plan.schema().as_ref())?; plan = Arc::new(DFProjectionExec::try_new(final_projection, plan)?); - // Stage 8: If requested, apply a strict batch size to the final output + // If requested, apply a strict batch size to the final output if self.strict_batch_size { plan = Arc::new(StrictBatchSizeExec::new(plan, self.get_batch_size())); } diff --git a/rust/lance/src/dataset/tests/dataset_aggregate.rs b/rust/lance/src/dataset/tests/dataset_aggregate.rs index e0b7d512e88..e75595a78ce 100644 --- a/rust/lance/src/dataset/tests/dataset_aggregate.rs +++ b/rust/lance/src/dataset/tests/dataset_aggregate.rs @@ -7,7 +7,9 @@ use std::sync::Arc; use arrow_array::cast::AsArray; use arrow_array::types::{Float64Type, Int64Type}; -use arrow_array::{RecordBatch, RecordBatchIterator}; +use arrow_array::{ + FixedSizeListArray, Float32Array, Int64Array, RecordBatch, RecordBatchIterator, StringArray, +}; use arrow_schema::{DataType, Field, Schema as ArrowSchema}; use datafusion_substrait::substrait::proto::{ aggregate_function::AggregationInvocation, @@ -35,8 +37,14 @@ use prost::Message; use tempfile::tempdir; use crate::dataset::scanner::AggregateExpr; +use crate::index::vector::VectorIndexParams; use crate::utils::test::{DatagenExt, FragmentCount, FragmentRowCount}; use crate::Dataset; +use lance_arrow::FixedSizeListArrayExt; +use lance_index::scalar::inverted::InvertedIndexParams; +use lance_index::scalar::FullTextSearchQuery; +use lance_index::{DatasetIndexExt, IndexType}; +use lance_linalg::distance::MetricType; /// Helper to create a field 
reference expression for a column index
 fn field_ref(field_index: i32) -> Expression {
@@ -919,3 +927,223 @@ async fn test_first_value_with_order_by_desc() {
     assert_eq!(results_map.get(&2), Some(&7));
     assert_eq!(results_map.get(&3), Some(&8));
 }
+
+/// Create a dataset with vectors, text, and category for vector search and FTS aggregate tests.
+///
+/// Schema: id (i64), vec (fixed_size_list<f32>[4]), text (utf8), category (utf8).
+/// Row `i` holds `vec = [i; 4]`, `text = "document {i}"`, and a category chosen by `i % 3`.
+async fn create_vector_text_dataset(uri: &str, num_rows: i64) -> Dataset {
+    let schema = Arc::new(ArrowSchema::new(vec![
+        Field::new("id", DataType::Int64, false),
+        Field::new(
+            "vec",
+            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4),
+            true,
+        ),
+        Field::new("text", DataType::Utf8, false),
+        Field::new("category", DataType::Utf8, false),
+    ]));
+
+    let ids: Vec<i64> = (0..num_rows).collect();
+    let vectors: Vec<f32> = (0..num_rows).flat_map(|i| vec![i as f32; 4]).collect();
+    let texts: Vec<String> = (0..num_rows).map(|i| format!("document {}", i)).collect();
+    let categories: Vec<String> = (0..num_rows)
+        .map(|i| match i % 3 {
+            0 => "category_a".to_string(),
+            1 => "category_b".to_string(),
+            _ => "category_c".to_string(),
+        })
+        .collect();
+
+    let batch = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Int64Array::from(ids)),
+            Arc::new(
+                FixedSizeListArray::try_new_from_values(Float32Array::from(vectors), 4).unwrap(),
+            ),
+            Arc::new(StringArray::from(texts)),
+            Arc::new(StringArray::from(categories)),
+        ],
+    )
+    .unwrap();
+
+    let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
+    Dataset::write(reader, uri, None).await.unwrap()
+}
+
+/// COUNT(*) grouped by category over the rows returned by a top-30 vector search.
+#[tokio::test]
+async fn test_vector_search_with_aggregate() {
+    let tmp_dir = tempdir().unwrap();
+    let uri = tmp_dir.path().to_str().unwrap();
+    let mut dataset = create_vector_text_dataset(uri, 100).await;
+
+    // Create vector index
+    let params = VectorIndexParams::ivf_flat(2, MetricType::L2);
+    dataset
+        .create_index(&["vec"], IndexType::Vector, None, &params, true)
+        .await
+        .unwrap();
+
+    // Vector search for top 30 results, then aggregate by category with COUNT(*)
+    // Query vector close to id=50 (vec=[50,50,50,50])
+    let query_vector = Float32Array::from(vec![50.0f32, 50.0, 50.0, 50.0]);
+
+    // COUNT(*) GROUP BY category (column index 3)
+    let agg_bytes = create_aggregate_rel(
+        vec![count_star_measure(1)],
+        vec![field_ref(3)], // GROUP BY category
+        vec![Grouping {
+            #[allow(deprecated)]
+            grouping_expressions: vec![],
+            expression_references: vec![0],
+        }],
+        vec![agg_extension(1, "count")],
+        vec!["category".to_string(), "count".to_string()],
+    );
+
+    let mut scanner = dataset.scan();
+    scanner
+        .nearest("vec", &query_vector, 30)
+        .unwrap()
+        .project(&["id", "category"])
+        .unwrap()
+        .aggregate(AggregateExpr::substrait(agg_bytes));
+
+    let results = scanner.try_into_batch().await.unwrap();
+
+    // Should have 3 categories (or fewer if search results don't cover all)
+    assert!(
+        (1..=3).contains(&results.num_rows()),
+        "Expected 1-3 rows but got {}",
+        results.num_rows()
+    );
+
+    // Total count should be 30 (top K results)
+    let counts: Vec<i64> = results
+        .column(1)
+        .as_primitive::<Int64Type>()
+        .values()
+        .to_vec();
+    let total: i64 = counts.iter().sum();
+    assert_eq!(total, 30);
+}
+
+/// COUNT(*) grouped by category over the rows matched by a full-text search.
+#[tokio::test]
+async fn test_fts_with_aggregate() {
+    let tmp_dir = tempdir().unwrap();
+    let uri = tmp_dir.path().to_str().unwrap();
+    let mut dataset = create_vector_text_dataset(uri, 100).await;
+
+    // Create FTS index on text column
+    dataset
+        .create_index(
+            &["text"],
+            IndexType::Inverted,
+            None,
+            &InvertedIndexParams::default(),
+            true,
+        )
+        .await
+        .unwrap();
+
+    // FTS search for "document", then aggregate by category with COUNT(*)
+    // All documents match "document" so we should get all 100 rows
+    // COUNT(*) GROUP BY category (column index 3)
+    let agg_bytes = create_aggregate_rel(
+        vec![count_star_measure(1)],
+        vec![field_ref(3)], // GROUP BY category
+        vec![Grouping {
+            #[allow(deprecated)]
+            grouping_expressions: vec![],
+            expression_references: vec![0],
+        }],
+        vec![agg_extension(1, "count")],
+        vec!["category".to_string(), "count".to_string()],
+    );
+
+    let mut scanner = dataset.scan();
+    scanner
+        .full_text_search(FullTextSearchQuery::new("document".to_string()))
+        .unwrap()
+        .project(&["id", "category"])
+        .unwrap()
+        .aggregate(AggregateExpr::substrait(agg_bytes));
+
+    let results = scanner.try_into_batch().await.unwrap();
+
+    // Should have 3 categories
+    assert_eq!(
+        results.num_rows(),
+        3,
+        "Expected 3 rows but got {}",
+        results.num_rows()
+    );
+
+    // Total count should be 100 (all documents match "document")
+    let counts: Vec<i64> = results
+        .column(1)
+        .as_primitive::<Int64Type>()
+        .values()
+        .to_vec();
+    let total: i64 = counts.iter().sum();
+    assert_eq!(total, 100);
+
+    // Each category should have ~33 rows (100/3)
+    for count in &counts {
+        assert!((33..=34).contains(count));
+    }
+}
+
+/// SUM(id) grouped by category over the rows returned by a top-10 vector search.
+#[tokio::test]
+async fn test_vector_search_with_sum_aggregate() {
+    let tmp_dir = tempdir().unwrap();
+    let uri = tmp_dir.path().to_str().unwrap();
+    let mut dataset = create_vector_text_dataset(uri, 100).await;
+
+    // Create vector index
+    let params = VectorIndexParams::ivf_flat(2, MetricType::L2);
+    dataset
+        .create_index(&["vec"], IndexType::Vector, None, &params, true)
+        .await
+        .unwrap();
+
+    // Vector search for top 10 results, then SUM(id) GROUP BY category
+    let query_vector = Float32Array::from(vec![50.0f32, 50.0, 50.0, 50.0]);
+
+    // SUM(id) GROUP BY category
+    let agg_bytes = create_aggregate_rel(
+        vec![simple_agg_measure(1, 0)], // SUM(id) - column 0
+        vec![field_ref(3)],             // GROUP BY category
+        vec![Grouping {
+            #[allow(deprecated)]
+            grouping_expressions: vec![],
+            expression_references: vec![0],
+        }],
+        vec![agg_extension(1, "sum")],
+        vec!["category".to_string(), "sum_id".to_string()],
+    );
+
+    let mut scanner = dataset.scan();
+    scanner
+        .nearest("vec", &query_vector, 10)
+        .unwrap()
+        .project(&["id", "category"])
+        .unwrap()
+        .aggregate(AggregateExpr::substrait(agg_bytes));
+
+    let results = scanner.try_into_batch().await.unwrap();
+
+    // Should have results grouped by category (1-3 depending on which categories are in top K)
+    assert!(
+        (1..=3).contains(&results.num_rows()),
+        "Expected 1-3 rows but got {}",
+        results.num_rows()
+    );
+
+    // Verify we have 2 columns: category and sum_id
+    assert_eq!(results.num_columns(), 2);
+}