diff --git a/Cargo.lock b/Cargo.lock
index cc98925d15e..cfcc4899c96 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4731,6 +4731,7 @@ dependencies = [
  "datafusion-functions",
  "datafusion-physical-expr",
  "datafusion-physical-plan",
+ "datafusion-substrait",
  "deepsize",
  "dirs 5.0.1",
  "either",
diff --git a/rust/lance-datafusion/src/aggregate.rs b/rust/lance-datafusion/src/aggregate.rs
new file mode 100644
index 00000000000..5528104c044
--- /dev/null
+++ b/rust/lance-datafusion/src/aggregate.rs
@@ -0,0 +1,16 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The Lance Authors
+
+//! Aggregate specification (group-by and aggregate expressions) for DataFusion.
+
+use datafusion::logical_expr::Expr;
+
+/// Aggregate specification with group by and aggregate expressions.
+#[derive(Debug, Clone)]
+pub struct Aggregate {
+    /// Expressions to group by (e.g., column references).
+    pub group_by: Vec<Expr>,
+    /// Aggregate function expressions (e.g., SUM, COUNT, AVG).
+    /// Use `.alias()` on the expression to set output column names.
+    pub aggregates: Vec<Expr>,
+}
diff --git a/rust/lance-datafusion/src/lib.rs b/rust/lance-datafusion/src/lib.rs
index fa65a918191..0ef51216ae8 100644
--- a/rust/lance-datafusion/src/lib.rs
+++ b/rust/lance-datafusion/src/lib.rs
@@ -1,6 +1,7 @@
 // SPDX-License-Identifier: Apache-2.0
 // SPDX-FileCopyrightText: Copyright The Lance Authors
 
+pub mod aggregate;
 pub mod chunker;
 pub mod dataframe;
 pub mod datagen;
diff --git a/rust/lance-datafusion/src/substrait.rs b/rust/lance-datafusion/src/substrait.rs
index 54dc9be8808..2f84a266f65 100644
--- a/rust/lance-datafusion/src/substrait.rs
+++ b/rust/lance-datafusion/src/substrait.rs
@@ -3,6 +3,13 @@
 use arrow_schema::Schema as ArrowSchema;
 use datafusion::{execution::SessionState, logical_expr::Expr};
+
+use crate::aggregate::Aggregate;
+use datafusion_common::DFSchema;
+use datafusion_substrait::extensions::Extensions;
+use datafusion_substrait::logical_plan::consumer::{
+    from_substrait_agg_func, from_substrait_rex, from_substrait_sorts, DefaultSubstraitConsumer,
+};
 use datafusion_substrait::substrait::proto::{
     expression::{
         field_reference::{ReferenceType, RootType},
@@ -11,7 +18,8 @@ use datafusion_substrait::substrait::proto::{
     expression_reference::ExprType,
     function_argument::ArgType,
     r#type::{Kind, Struct},
-    Expression, ExpressionReference, ExtendedExpression, NamedStruct, Type,
+    rel::RelType,
+    AggregateRel, Expression, ExpressionReference, ExtendedExpression, NamedStruct, Plan, Type,
 };
 use lance_core::{Error, Result};
 use prost::Message;
@@ -324,6 +332,217 @@ pub async fn parse_substrait(
     Ok(expr_container.exprs.pop().unwrap().0)
 }
 
+/// Parse Substrait Plan bytes containing an AggregateRel.
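+///
+/// Expects the first relation of the plan to be (or to wrap, via `RelRoot`) an
+/// `AggregateRel`; `RelRoot.names`, when present, are applied as output aliases.
+///
+/// # Example
+///
+/// A minimal usage sketch; `plan_bytes` is a hypothetical buffer holding an
+/// encoded Substrait `Plan` whose root relation is an `AggregateRel`:
+///
+/// ```ignore
+/// let agg = parse_substrait_aggregate(&plan_bytes, input_schema, &session_state).await?;
+/// println!("{} group keys, {} measures", agg.group_by.len(), agg.aggregates.len());
+/// ```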
+pub async fn parse_substrait_aggregate(
+    bytes: &[u8],
+    input_schema: Arc<ArrowSchema>,
+    state: &SessionState,
+) -> Result<Aggregate> {
+    let plan = Plan::decode(bytes)?;
+    let (aggregate_rel, output_names) = extract_aggregate_from_plan(&plan)?;
+    let extensions = Extensions::try_from(&plan.extensions)?;
+
+    let mut agg =
+        parse_aggregate_rel_with_extensions(&aggregate_rel, input_schema, state, &extensions)
+            .await?;
+
+    // Apply aliases from RelRoot.names to expressions
+    if !output_names.is_empty() {
+        let num_groups = agg.group_by.len();
+        for (i, expr) in agg.group_by.iter_mut().enumerate() {
+            if i < output_names.len() {
+                *expr = expr.clone().alias(&output_names[i]);
+            }
+        }
+        for (i, expr) in agg.aggregates.iter_mut().enumerate() {
+            let name_idx = num_groups + i;
+            if name_idx < output_names.len() {
+                *expr = expr.clone().alias(&output_names[name_idx]);
+            }
+        }
+    }
+
+    Ok(agg)
+}
+
+fn extract_aggregate_from_plan(plan: &Plan) -> Result<(Box<AggregateRel>, Vec<String>)> {
+    if plan.relations.is_empty() {
+        return Err(Error::invalid_input(
+            "Substrait Plan has no relations",
+            location!(),
+        ));
+    }
+
+    let plan_rel = &plan.relations[0];
+    let (rel, output_names) = match &plan_rel.rel_type {
+        Some(datafusion_substrait::substrait::proto::plan_rel::RelType::Root(root)) => {
+            (root.input.as_ref(), root.names.clone())
+        }
+        Some(datafusion_substrait::substrait::proto::plan_rel::RelType::Rel(rel)) => {
+            (Some(rel), vec![])
+        }
+        None => (None, vec![]),
+    };
+
+    let rel = rel.ok_or_else(|| Error::invalid_input("Plan relation has no input", location!()))?;
+
+    match &rel.rel_type {
+        Some(RelType::Aggregate(agg)) => Ok((agg.clone(), output_names)),
+        Some(other) => Err(Error::invalid_input(
+            format!(
+                "Expected Substrait AggregateRel, got {:?}",
+                std::mem::discriminant(other)
+            ),
+            location!(),
+        )),
+        None => Err(Error::invalid_input(
+            "Substrait Rel has no rel_type",
+            location!(),
+        )),
+    }
+}
+
+/// Parse an AggregateRel proto with provided extensions.
+pub async fn parse_aggregate_rel_with_extensions(
+    aggregate_rel: &AggregateRel,
+    input_schema: Arc<ArrowSchema>,
+    state: &SessionState,
+    extensions: &Extensions,
+) -> Result<Aggregate> {
+    let df_schema = DFSchema::try_from(input_schema.as_ref().clone())?;
+    let consumer = DefaultSubstraitConsumer::new(extensions, state);
+    let group_by = parse_groupings(aggregate_rel, &df_schema, &consumer).await?;
+    let aggregates = parse_measures(aggregate_rel, &df_schema, &consumer).await?;
+
+    Ok(Aggregate {
+        group_by,
+        aggregates,
+    })
+}
+
+/// Parse an AggregateRel proto with default extensions.
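+///
+/// A minimal sketch (`agg_rel` is a hypothetical `AggregateRel` proto); this is
+/// equivalent to `parse_aggregate_rel_with_extensions` with `Extensions::default()`:
+///
+/// ```ignore
+/// let agg = parse_aggregate_rel(&agg_rel, input_schema, &session_state).await?;
+/// ```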
+pub async fn parse_aggregate_rel(
+    aggregate_rel: &AggregateRel,
+    input_schema: Arc<ArrowSchema>,
+    state: &SessionState,
+) -> Result<Aggregate> {
+    let extensions = Extensions::default();
+    parse_aggregate_rel_with_extensions(aggregate_rel, input_schema, state, &extensions).await
+}
+
+async fn parse_groupings(
+    agg_rel: &AggregateRel,
+    schema: &DFSchema,
+    consumer: &DefaultSubstraitConsumer<'_>,
+) -> Result<Vec<Expr>> {
+    let mut group_exprs = Vec::new();
+
+    // First, handle the new-style grouping_expressions + expression_references
+    if !agg_rel.grouping_expressions.is_empty() {
+        for grouping in &agg_rel.groupings {
+            for expr_ref in &grouping.expression_references {
+                let idx = *expr_ref as usize;
+                if idx >= agg_rel.grouping_expressions.len() {
+                    return Err(Error::invalid_input(
+                        format!(
+                            "Grouping expression reference {} out of bounds (max: {})",
+                            idx,
+                            agg_rel.grouping_expressions.len()
+                        ),
+                        location!(),
+                    ));
+                }
+                let expr = &agg_rel.grouping_expressions[idx];
+                let df_expr = from_substrait_rex(consumer, expr, schema)
+                    .await
+                    .map_err(|e| {
+                        Error::invalid_input(
+                            format!("Failed to parse grouping expression: {}", e),
+                            location!(),
+                        )
+                    })?;
+                group_exprs.push(df_expr);
+            }
+        }
+    } else {
+        // Fallback to deprecated inline grouping_expressions within each Grouping
+        #[allow(deprecated)]
+        for grouping in &agg_rel.groupings {
+            for expr in &grouping.grouping_expressions {
+                let df_expr = from_substrait_rex(consumer, expr, schema)
+                    .await
+                    .map_err(|e| {
+                        Error::invalid_input(
+                            format!("Failed to parse grouping expression: {}", e),
+                            location!(),
+                        )
+                    })?;
+                group_exprs.push(df_expr);
+            }
+        }
+    }
+
+    Ok(group_exprs)
+}
+
+async fn parse_measures(
+    agg_rel: &AggregateRel,
+    schema: &DFSchema,
+    consumer: &DefaultSubstraitConsumer<'_>,
+) -> Result<Vec<Expr>> {
+    let mut aggregates = Vec::new();
+
+    for measure in &agg_rel.measures {
+        if let Some(agg_func) = &measure.measure {
+            // Parse optional filter
+            let filter = if let Some(filter_expr) = &measure.filter {
+                let df_filter = from_substrait_rex(consumer, filter_expr, schema)
+                    .await
+                    .map_err(|e| {
+                        Error::invalid_input(
+                            format!("Failed to parse measure filter: {}", e),
+                            location!(),
+                        )
+                    })?;
+                Some(Box::new(df_filter))
+            } else {
+                None
+            };
+
+            // Parse ordering (for ordered aggregates like ARRAY_AGG)
+            let order_by = from_substrait_sorts(consumer, &agg_func.sorts, schema)
+                .await
+                .map_err(|e| {
+                    Error::invalid_input(
+                        format!("Failed to parse aggregate sorts: {}", e),
+                        location!(),
+                    )
+                })?;
+
+            // Check for DISTINCT invocation
+            let distinct = agg_func.invocation
+                == datafusion_substrait::substrait::proto::aggregate_function::AggregationInvocation::Distinct as i32;
+
+            // Convert Substrait AggregateFunction to DataFusion Expr
+            let df_expr =
+                from_substrait_agg_func(consumer, agg_func, schema, filter, order_by, distinct)
+                    .await
+                    .map_err(|e| {
+                        Error::invalid_input(
+                            format!("Failed to parse aggregate function: {}", e),
+                            location!(),
+                        )
+                    })?;
+
+            aggregates.push(df_expr.as_ref().clone());
+        }
+    }
+
+    Ok(aggregates)
+}
+
 #[cfg(test)]
 mod tests {
     use std::sync::Arc;
@@ -607,4 +826,281 @@ mod tests {
         assert_substrait_roundtrip(schema, id_filter("test-id")).await;
     }
+
+    // ==================== Aggregate parsing tests ====================
+
+    use datafusion_substrait::substrait::proto::{
+        aggregate_function::AggregationInvocation,
+        aggregate_rel::{Grouping, Measure},
+        rel::RelType,
+        AggregateFunction, AggregateRel, Plan, PlanRel, Rel, RelRoot,
+    };
+
+    /// Helper to create a field reference expression for a column index
+    fn agg_field_ref(field_index: i32) -> Expression {
+        Expression {
+            rex_type: Some(RexType::Selection(Box::new(FieldReference {
+                reference_type: Some(ReferenceType::DirectReference(ReferenceSegment {
+                    reference_type: Some(reference_segment::ReferenceType::StructField(Box::new(
+                        StructField {
+                            field: field_index,
+                            child: None,
+                        },
+                    ))),
+                })),
+                root_type: Some(RootType::RootReference(RootReference {})),
+            }))),
+        }
+    }
+
+    /// Create extension declaration for an aggregate function
+    fn agg_extension(anchor: u32, name: &str) -> SimpleExtensionDeclaration {
+        SimpleExtensionDeclaration {
+            mapping_type: Some(MappingType::ExtensionFunction(ExtensionFunction {
+                #[allow(deprecated)]
+                extension_uri_reference: 1,
+                extension_urn_reference: 0,
+                function_anchor: anchor,
+                name: name.to_string(),
+            })),
+        }
+    }
+
+    /// Helper to create a Substrait Plan with AggregateRel
+    fn create_aggregate_plan(
+        measures: Vec<Measure>,
+        grouping_expressions: Vec<Expression>,
+        groupings: Vec<Grouping>,
+        extensions: Vec<SimpleExtensionDeclaration>,
+    ) -> Vec<u8> {
+        let aggregate_rel = AggregateRel {
+            common: None,
+            input: None, // Input is ignored for pushdown
+            groupings,
+            measures,
+            grouping_expressions,
+            advanced_extension: None,
+        };
+
+        let rel = Rel {
+            rel_type: Some(RelType::Aggregate(Box::new(aggregate_rel))),
+        };
+
+        // Wrap in a Plan to include extensions
+        let plan = Plan {
+            version: Some(Version {
+                major_number: 0,
+                minor_number: 63,
+                patch_number: 0,
+                git_hash: String::new(),
+                producer: "lance-test".to_string(),
+            }),
+            #[allow(deprecated)]
+            extension_uris: vec![SimpleExtensionUri {
+                extension_uri_anchor: 1,
+                uri: "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate_generic.yaml".to_string(),
+            }],
+            extensions,
+            relations: vec![PlanRel {
+                rel_type: Some(
+                    datafusion_substrait::substrait::proto::plan_rel::RelType::Root(RelRoot {
+                        input: Some(rel),
+                        names: vec![],
+                    }),
+                ),
+            }],
+            advanced_extensions: None,
+            expected_type_urls: vec![],
+            extension_urns: vec![],
+            parameter_bindings: vec![],
+            type_aliases: vec![],
+        };
+
+        plan.encode_to_vec()
+    }
+
+    /// Create a COUNT(*) measure
+    fn count_star_measure(function_ref: u32) -> Measure {
+        Measure {
+            measure: Some(AggregateFunction {
+                function_reference: function_ref,
+                arguments: vec![],
+                options: vec![],
+                output_type: None,
+                phase: 0,
+                sorts: vec![],
+                invocation: AggregationInvocation::All as i32,
+                #[allow(deprecated)]
+                args: vec![],
+            }),
+            filter: None,
+        }
+    }
+
+    /// Create a SUM/AVG/MIN/MAX measure on a column
+    fn simple_agg_measure(function_ref: u32, column_index: i32) -> Measure {
+        Measure {
+            measure: Some(AggregateFunction {
+                function_reference: function_ref,
+                arguments: vec![FunctionArgument {
+                    arg_type: Some(ArgType::Value(agg_field_ref(column_index))),
+                }],
+                options: vec![],
+                output_type: None,
+                phase: 0,
+                sorts: vec![],
+                invocation: AggregationInvocation::All as i32,
+                #[allow(deprecated)]
+                args: vec![],
+            }),
+            filter: None,
+        }
+    }
+
+    #[tokio::test]
+    async fn test_parse_substrait_aggregate_count_star() {
+        let bytes = create_aggregate_plan(
+            vec![count_star_measure(0)],
+            vec![],
+            vec![],
+            vec![agg_extension(0, "count")],
+        );
+
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("x", DataType::Int32, true),
+            Field::new("y", DataType::Int64, true),
+        ]));
+
+        let result =
+            crate::substrait::parse_substrait_aggregate(&bytes, schema, &session_state()).await;
+
+        let agg = result.expect("Failed to parse COUNT(*) aggregate");
+        assert!(agg.group_by.is_empty(), "COUNT(*) should have no group by");
should have no group by"); + assert_eq!(agg.aggregates.len(), 1, "Should have exactly one aggregate"); + + // Verify it's a COUNT aggregate + let agg_expr = &agg.aggregates[0]; + assert!( + agg_expr.schema_name().to_string().contains("count"), + "Expected COUNT aggregate, got: {}", + agg_expr.schema_name() + ); + } + + #[tokio::test] + async fn test_parse_substrait_aggregate_sum() { + let bytes = create_aggregate_plan( + vec![simple_agg_measure(0, 1)], // SUM on column index 1 (y) + vec![], + vec![], + vec![agg_extension(0, "sum")], + ); + + let schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::Int32, true), + Field::new("y", DataType::Int64, true), + ])); + + let result = + crate::substrait::parse_substrait_aggregate(&bytes, schema, &session_state()).await; + + let agg = result.expect("Failed to parse SUM aggregate"); + assert!(agg.group_by.is_empty(), "SUM should have no group by"); + assert_eq!(agg.aggregates.len(), 1, "Should have exactly one aggregate"); + + // Verify it's a SUM aggregate + let agg_expr = &agg.aggregates[0]; + assert!( + agg_expr.schema_name().to_string().contains("sum"), + "Expected SUM aggregate, got: {}", + agg_expr.schema_name() + ); + } + + #[tokio::test] + async fn test_parse_substrait_aggregate_sum_with_group_by() { + // SUM(y) GROUP BY x + let bytes = create_aggregate_plan( + vec![simple_agg_measure(0, 1)], // SUM on column index 1 (y) + vec![agg_field_ref(0)], // Group by column index 0 (x) + vec![Grouping { + #[allow(deprecated)] + grouping_expressions: vec![], + expression_references: vec![0], // Reference to first grouping_expression + }], + vec![agg_extension(0, "sum")], + ); + + let schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::Int32, true), + Field::new("y", DataType::Int64, true), + ])); + + let result = + crate::substrait::parse_substrait_aggregate(&bytes, schema, &session_state()).await; + + let agg = result.expect("Failed to parse SUM with GROUP BY"); + assert_eq!( + agg.group_by.len(), + 1, + "Should have exactly one group by expression" + ); + assert_eq!(agg.aggregates.len(), 1, "Should have exactly one aggregate"); + + // Verify group by is column x + let group_expr = &agg.group_by[0]; + assert!( + group_expr.schema_name().to_string().contains('x'), + "Expected group by on column x, got: {}", + group_expr.schema_name() + ); + + // Verify it's a SUM aggregate + let agg_expr = &agg.aggregates[0]; + assert!( + agg_expr.schema_name().to_string().contains("sum"), + "Expected SUM aggregate, got: {}", + agg_expr.schema_name() + ); + } + + #[tokio::test] + async fn test_parse_substrait_aggregate_multiple_aggregates() { + // COUNT(*) and SUM(y) + let bytes = create_aggregate_plan( + vec![count_star_measure(0), simple_agg_measure(1, 1)], + vec![], + vec![], + vec![agg_extension(0, "count"), agg_extension(1, "sum")], + ); + + let schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::Int32, true), + Field::new("y", DataType::Int64, true), + ])); + + let result = + crate::substrait::parse_substrait_aggregate(&bytes, schema, &session_state()).await; + + let agg = result.expect("Failed to parse multiple aggregates"); + assert!(agg.group_by.is_empty(), "Should have no group by"); + assert_eq!(agg.aggregates.len(), 2, "Should have two aggregates"); + + // Verify COUNT + assert!( + agg.aggregates[0] + .schema_name() + .to_string() + .contains("count"), + "Expected COUNT aggregate, got: {}", + agg.aggregates[0].schema_name() + ); + + // Verify SUM + assert!( + agg.aggregates[1].schema_name().to_string().contains("sum"), + 
"Expected SUM aggregate, got: {}", + agg.aggregates[1].schema_name() + ); + } } diff --git a/rust/lance/Cargo.toml b/rust/lance/Cargo.toml index d57187bcacf..427de2aa30e 100644 --- a/rust/lance/Cargo.toml +++ b/rust/lance/Cargo.toml @@ -116,6 +116,7 @@ aws-sdk-s3 = { workspace = true } geoarrow-array = { workspace = true } geoarrow-schema = { workspace = true } geo-types = { workspace = true } +datafusion-substrait = { workspace = true } [features] default = ["aws", "azure", "gcp", "oss", "huggingface", "tencent"] diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index a9265517fe5..8923bc03e99 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -34,7 +34,7 @@ use datafusion::scalar::ScalarValue; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::ExprSchemable; use datafusion_functions::core::getfield::GetFieldFunc; -use datafusion_physical_expr::{aggregate::AggregateExprBuilder, expressions::Column}; +use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::{create_physical_expr, LexOrdering, Partitioning, PhysicalExpr}; use datafusion_physical_plan::joins::PartitionMode; use datafusion_physical_plan::projection::ProjectionExec; @@ -53,6 +53,7 @@ use lance_core::utils::address::RowAddress; use lance_core::utils::mask::{RowAddrMask, RowAddrTreeMap}; use lance_core::utils::tokio::get_num_compute_intensive_cpus; use lance_core::{ROW_ADDR, ROW_ID, ROW_OFFSET}; +use lance_datafusion::aggregate::Aggregate; use lance_datafusion::exec::{ analyze_plan, execute_plan, LanceExecutionOptions, OneShotExec, StrictBatchSizeExec, }; @@ -462,6 +463,223 @@ impl ExprFilter { } } +/// Aggregate expression from Substrait or DataFusion. +#[derive(Debug, Clone)] +pub enum AggregateExpr { + #[cfg(feature = "substrait")] + Substrait(Vec), + Datafusion { + group_by: Vec, + aggregates: Vec, + }, +} + +impl AggregateExpr { + /// Create a new builder for aggregate expressions. + /// + /// # Example + /// ```ignore + /// let agg = AggregateExpr::builder() + /// .group_by("category") + /// .count_star().alias("total_count") + /// .sum("amount").alias("total_amount") + /// .avg("price") + /// .build(); + /// scanner.aggregate(agg); + /// ``` + pub fn builder() -> AggregateExprBuilder { + AggregateExprBuilder::new() + } + + /// Create from Substrait Plan bytes. + #[cfg(feature = "substrait")] + pub fn substrait(bytes: impl Into>) -> Self { + Self::Substrait(bytes.into()) + } + + /// Create from DataFusion expressions. + /// Use `.alias()` on expressions to set output column names. + pub fn datafusion(group_by: Vec, aggregates: Vec) -> Self { + Self::Datafusion { + group_by, + aggregates, + } + } + + fn to_aggregate( + &self, + #[allow(unused_variables)] schema: Arc, + ) -> Result { + match self { + #[cfg(feature = "substrait")] + Self::Substrait(bytes) => { + use lance_datafusion::exec::{get_session_context, LanceExecutionOptions}; + use lance_datafusion::substrait::parse_substrait_aggregate; + + let ctx = get_session_context(&LanceExecutionOptions::default()); + parse_substrait_aggregate(bytes, schema, &ctx.state()) + .now_or_never() + .expect("could not parse the Substrait aggregate in a synchronous fashion") + } + Self::Datafusion { + group_by, + aggregates, + } => Ok(Aggregate { + group_by: group_by.clone(), + aggregates: aggregates.clone(), + }), + } + } +} + +/// Builder for creating aggregate expressions without using DataFusion or Substrait directly. 
+///
+/// The const generic `HAS_PENDING` tracks whether there's a pending aggregate that can be aliased.
+/// When `HAS_PENDING` is `true`, the last item in `aggregates` is the pending aggregate.
+#[derive(Debug, Clone)]
+pub struct AggregateExprBuilder<const HAS_PENDING: bool> {
+    group_by: Vec<Expr>,
+    aggregates: Vec<Expr>,
+}
+
+impl Default for AggregateExprBuilder<false> {
+    fn default() -> Self {
+        Self {
+            group_by: Vec::new(),
+            aggregates: Vec::new(),
+        }
+    }
+}
+
+impl AggregateExprBuilder<false> {
+    /// Create a new builder.
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Build the aggregate expression.
+    pub fn build(self) -> AggregateExpr {
+        AggregateExpr::Datafusion {
+            group_by: self.group_by,
+            aggregates: self.aggregates,
+        }
+    }
+}
+
+impl<const HAS_PENDING: bool> AggregateExprBuilder<HAS_PENDING> {
+    /// Add a column to group by.
+    ///
+    /// Multiple invocations will add to the list (not replace it).
+    /// E.g. `.group_by("x").group_by("y")` will group by both `x` and `y`.
+    pub fn group_by(mut self, column: impl Into<String>) -> AggregateExprBuilder<HAS_PENDING> {
+        self.group_by.push(col(column.into()));
+        AggregateExprBuilder {
+            group_by: self.group_by,
+            aggregates: self.aggregates,
+        }
+    }
+
+    /// Add multiple columns to group by.
+    ///
+    /// Multiple invocations will add to the list (not replace it).
+    /// E.g. `.group_by("x").group_by_columns(["y", "z"])` will group by `x`, `y`, and `z`.
+    pub fn group_by_columns(
+        mut self,
+        columns: impl IntoIterator<Item = impl Into<String>>,
+    ) -> AggregateExprBuilder<HAS_PENDING> {
+        for column in columns {
+            self.group_by.push(col(column.into()));
+        }
+        AggregateExprBuilder {
+            group_by: self.group_by,
+            aggregates: self.aggregates,
+        }
+    }
+
+    /// Add COUNT(*) aggregate that counts all rows.
+    pub fn count_star(mut self) -> AggregateExprBuilder<true> {
+        self.aggregates
+            .push(functions_aggregate::count::count(lit(1)));
+        AggregateExprBuilder {
+            group_by: self.group_by,
+            aggregates: self.aggregates,
+        }
+    }
+
+    /// Add COUNT(column) aggregate.
+    ///
+    /// Unlike `count_star`, this will only count the number of rows where `column`
+    /// is not NULL.
+    pub fn count(mut self, column: impl Into<String>) -> AggregateExprBuilder<true> {
+        self.aggregates
+            .push(functions_aggregate::count::count(col(column.into())));
+        AggregateExprBuilder {
+            group_by: self.group_by,
+            aggregates: self.aggregates,
+        }
+    }
+
+    /// Add SUM(column) aggregate.
+    pub fn sum(mut self, column: impl Into<String>) -> AggregateExprBuilder<true> {
+        self.aggregates
+            .push(functions_aggregate::sum::sum(col(column.into())));
+        AggregateExprBuilder {
+            group_by: self.group_by,
+            aggregates: self.aggregates,
+        }
+    }
+
+    /// Add AVG(column) aggregate.
+    pub fn avg(mut self, column: impl Into<String>) -> AggregateExprBuilder<true> {
+        self.aggregates
+            .push(functions_aggregate::average::avg(col(column.into())));
+        AggregateExprBuilder {
+            group_by: self.group_by,
+            aggregates: self.aggregates,
+        }
+    }
+
+    /// Add MIN(column) aggregate.
+    pub fn min(mut self, column: impl Into<String>) -> AggregateExprBuilder<true> {
+        self.aggregates
+            .push(functions_aggregate::min_max::min(col(column.into())));
+        AggregateExprBuilder {
+            group_by: self.group_by,
+            aggregates: self.aggregates,
+        }
+    }
+
+    /// Add MAX(column) aggregate.
+    pub fn max(mut self, column: impl Into<String>) -> AggregateExprBuilder<true> {
+        self.aggregates
+            .push(functions_aggregate::min_max::max(col(column.into())));
+        AggregateExprBuilder {
+            group_by: self.group_by,
+            aggregates: self.aggregates,
+        }
+    }
+}
+
+impl AggregateExprBuilder<true> {
+    /// Set an alias for the pending aggregate (the last added aggregate).
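+    ///
+    /// E.g. `.sum("amount").alias("total_amount")`. The typestate guarantees at
+    /// least one aggregate has been added, so the `expect` below cannot fail.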
+    pub fn alias(mut self, name: impl Into<String>) -> AggregateExprBuilder<true> {
+        let pending = self.aggregates.pop().expect("pending aggregate must exist");
+        self.aggregates.push(pending.alias(name.into()));
+        AggregateExprBuilder {
+            group_by: self.group_by,
+            aggregates: self.aggregates,
+        }
+    }
+
+    /// Build the aggregate expression.
+    pub fn build(self) -> AggregateExpr {
+        AggregateExpr::Datafusion {
+            group_by: self.group_by,
+            aggregates: self.aggregates,
+        }
+    }
+}
+
 /// Dataset Scanner
 ///
 /// ```rust,ignore
@@ -570,6 +788,8 @@ pub struct Scanner {
     /// File reader options to use when reading data files.
     file_reader_options: Option<FileReaderOptions>,
 
+    aggregate: Option<AggregateExpr>,
+
     // Legacy fields to help migrate some old projection behavior to new behavior
     //
     // There are two behaviors we are moving away from:
@@ -787,6 +1007,7 @@ impl Scanner {
             scan_stats_callback: None,
             strict_batch_size: false,
             file_reader_options,
+            aggregate: None,
             legacy_with_row_addr: false,
             legacy_with_row_id: false,
             explicit_projection: false,
@@ -1017,6 +1238,12 @@ impl Scanner {
         self
     }
 
+    /// Set an aggregation to apply to the scan output.
+    pub fn aggregate(&mut self, aggregate: AggregateExpr) -> &mut Self {
+        self.aggregate = Some(aggregate);
+        self
+    }
+
     /// Set the batch size.
     pub fn batch_size(&mut self, batch_size: usize) -> &mut Self {
         self.batch_size = Some(batch_size);
@@ -1718,7 +1945,10 @@ impl Scanner {
         let input_phy_exprs: &[Arc<dyn PhysicalExpr>] = &[one];
         let schema = plan.schema();
 
-        let mut builder = AggregateExprBuilder::new(count_udaf(), input_phy_exprs.to_vec());
+        let mut builder = datafusion_physical_expr::aggregate::AggregateExprBuilder::new(
+            count_udaf(),
+            input_phy_exprs.to_vec(),
+        );
         builder = builder.schema(schema);
         builder = builder.alias("count_rows".to_string());
@@ -1767,6 +1997,161 @@ impl Scanner {
             .boxed()
     }
 
+    /// Create an execution plan with aggregation.
+    ///
+    /// Requires `aggregate()` to be called first.
+    #[deprecated(note = "Use create_plan() instead, which now applies aggregate automatically")]
+    pub fn create_aggregate_plan(&self) -> BoxFuture<'_, Result<Arc<dyn ExecutionPlan>>> {
+        async move {
+            if self.aggregate.is_none() {
+                return Err(Error::invalid_input(
+                    "create_aggregate_plan called but no aggregate was set",
+                    location!(),
+                ));
+            }
+            // create_plan() now applies aggregate automatically when set
+            self.create_plan().await
+        }
+        .boxed()
+    }
+
+    async fn apply_aggregate(
+        &self,
+        plan: Arc<dyn ExecutionPlan>,
+        agg_spec: &AggregateExpr,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
+
+        let schema = plan.schema();
+        let agg = agg_spec.to_aggregate(schema.clone())?;
+        let df_schema = DFSchema::try_from(schema.as_ref().clone())?;
+
+        let group_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> = agg
+            .group_by
+            .iter()
+            .map(|expr| {
+                let name = expr.schema_name().to_string();
+                let physical_expr =
+                    create_physical_expr(expr, &df_schema, &ExecutionProps::default())?;
+                Ok((physical_expr, name))
+            })
+            .collect::<Result<Vec<_>>>()?;
+
+        #[allow(clippy::type_complexity)]
+        let aggr_results: Vec<(Arc<AggregateFunctionExpr>, Option<Arc<dyn PhysicalExpr>>)> = agg
+            .aggregates
+            .iter()
+            .map(|expr| self.build_physical_aggregate_expr(expr, &df_schema, &schema))
+            .collect::<Result<Vec<_>>>()?;
+
+        let (aggr_exprs, filters): (Vec<_>, Vec<_>) = aggr_results.into_iter().unzip();
+
+        Ok(Arc::new(AggregateExec::try_new(
+            AggregateMode::Single,
+            PhysicalGroupBy::new_single(group_exprs),
+            aggr_exprs,
+            filters,
+            plan,
+            schema,
+        )?) as Arc<dyn ExecutionPlan>)
+    }
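+
+    // Note on AggregateMode::Single: the aggregation runs as one complete pass
+    // in a single operator rather than as a Partial/Final pair; see DataFusion's
+    // AggregateMode documentation for the tradeoffs.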
+
+    #[allow(clippy::type_complexity)]
+    fn build_physical_aggregate_expr(
+        &self,
+        expr: &Expr,
+        df_schema: &DFSchema,
+        input_schema: &SchemaRef,
+    ) -> Result<(
+        Arc<AggregateFunctionExpr>,
+        Option<Arc<dyn PhysicalExpr>>,
+    )> {
+        use datafusion::physical_planner::create_aggregate_expr_and_maybe_filter;
+
+        let coerced_expr = self.coerce_aggregate_expr(expr, df_schema)?;
+
+        // Note: order_by is already embedded in the AggregateFunctionExpr for ordered aggregates
+        let (agg_expr, filter, _order_by) = create_aggregate_expr_and_maybe_filter(
+            &coerced_expr,
+            df_schema,
+            input_schema.as_ref(),
+            &ExecutionProps::default(),
+        )?;
+
+        Ok((agg_expr, filter))
+    }
+
+    /// Apply type coercion to aggregate arguments for UserDefined signature functions.
+    ///
+    /// Most aggregate functions (SUM, COUNT, MIN, MAX) have explicit type signatures that
+    /// DataFusion handles automatically. However, some functions like AVG use UserDefined
+    /// type signatures in the Substrait consumer, which means DataFusion doesn't know the
+    /// expected input types and won't perform automatic coercion. We must explicitly coerce
+    /// arguments to the types returned by `func.coerce_types()`.
+    fn coerce_aggregate_expr(&self, expr: &Expr, schema: &DFSchema) -> Result<Expr> {
+        Self::coerce_aggregate_expr_impl(expr, schema)
+    }
+
+    fn coerce_aggregate_expr_impl(expr: &Expr, schema: &DFSchema) -> Result<Expr> {
+        use datafusion::logical_expr::{expr::AggregateFunction, Expr, TypeSignature};
+
+        match expr {
+            Expr::AggregateFunction(agg_func) => {
+                let func = &agg_func.func;
+                let args = &agg_func.params.args;
+
+                // Only UserDefined signature functions need explicit coercion
+                if !matches!(func.signature().type_signature, TypeSignature::UserDefined) {
+                    return Ok(expr.clone());
+                }
+
+                if args.is_empty() {
+                    return Ok(expr.clone());
+                }
+
+                let current_types: Vec<DataType> = args
+                    .iter()
+                    .map(|e| e.get_type(schema))
+                    .collect::<datafusion_common::Result<Vec<_>>>()?;
+
+                let coerced_types = func.coerce_types(&current_types)?;
+                let coerced_args: Vec<Expr> = args
+                    .iter()
+                    .zip(coerced_types.iter())
+                    .map(|(arg, target_type)| {
+                        let arg_type = arg.get_type(schema)?;
+                        if arg_type == *target_type {
+                            Ok(arg.clone())
+                        } else {
+                            arg.clone().cast_to(target_type, schema)
+                        }
+                    })
+                    .collect::<datafusion_common::Result<Vec<_>>>()?;
+
+                Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
+                    func.clone(),
+                    coerced_args,
+                    agg_func.params.distinct,
+                    agg_func.params.filter.clone(),
+                    agg_func.params.order_by.clone(),
+                    agg_func.params.null_treatment,
+                )))
+            }
+            Expr::Alias(alias) => {
+                // Recursively coerce the inner expression and preserve the alias
+                let coerced_inner = Self::coerce_aggregate_expr_impl(&alias.expr, schema)?;
+                Ok(coerced_inner.alias(&alias.name))
+            }
+            other => Err(Error::invalid_input(
+                format!(
+                    "Expected aggregate function expression, got {:?}",
+                    other.variant_name()
+                ),
+                location!(),
+            )),
+        }
+    }
+
     // A "narrow" field is a field that is so small that we are better off reading the
     // entire column and filtering in memory rather than "take"ing the column.
     //
@@ -1852,6 +2237,25 @@ impl Scanner {
             });
         }
 
+        if self.aggregate.is_some() {
+            if self.limit.is_some() || self.offset.is_some() {
+                return Err(Error::InvalidInput {
+                    source:
+                        "Cannot use limit/offset with aggregate. Apply limit to the result instead."
+                            .into(),
+                    location: location!(),
+                });
+            }
+            if self.ordering.is_some() {
+                return Err(Error::InvalidInput {
+                    source:
+                        "Cannot use order_by with aggregate. Apply ordering to the result instead."
+                            .into(),
+                    location: location!(),
+                });
+            }
+        }
+
         Ok(())
     }
 
@@ -2008,7 +2412,7 @@ impl Scanner {
         let mut filter_plan = self.create_filter_plan(use_scalar_index).await?;
         let mut use_limit_node = true;
 
-        // Stage 1: source (either an (K|A)NN search, full text search or or a (full|indexed) scan)
+        // Source: either a (K|A)NN search, full text search, or a (full|indexed) scan
         let mut plan: Arc<dyn ExecutionPlan> = match (&self.nearest, &self.full_text_query) {
             (Some(_), None) => self.vector_search_source(&mut filter_plan).await?,
             (None, Some(query)) => self.fts_search_source(&mut filter_plan, query).await?,
@@ -2066,8 +2470,7 @@
             }
         };
 
-        // Stage 1.5 load columns needed for stages 2 & 3
-        // Calculate the schema needed for the filter and ordering.
+        // Load columns needed for filter and ordering
         let mut pre_filter_projection = self.dataset.empty_projection();
 
         // We may need to take filter columns if we are going to refine
@@ -2093,10 +2496,17 @@
 
         plan = self.take(plan, pre_filter_projection)?;
 
-        // Stage 2: filter
+        // Filter
         plan = filter_plan.refine_filter(plan, self).await?;
 
-        // Stage 3: sort
+        // Aggregate (if set, applies aggregate and returns early)
+        if let Some(agg_spec) = &self.aggregate {
+            // Take columns needed for aggregation
+            plan = self.take(plan, self.projection_plan.physical_projection.clone())?;
+            return self.apply_aggregate(plan, agg_spec).await;
+        }
+
+        // Sort
         if let Some(ordering) = &self.ordering {
             let ordering_columns = ordering.iter().map(|col| &col.column_name);
             let projection_with_ordering = self
@@ -2128,25 +2538,25 @@
             ));
         }
 
-        // Stage 4: limit / offset
+        // Limit / offset
         if use_limit_node && (self.limit.unwrap_or(0) > 0 || self.offset.is_some()) {
             plan = self.limit_node(plan);
         }
 
-        // Stage 5: take remaining columns required for projection
+        // Take remaining columns required for projection
         plan = self.take(plan, self.projection_plan.physical_projection.clone())?;
 
-        // Stage 6: Add system columns, if requested
+        // Add system columns, if requested
         if self.projection_plan.must_add_row_offset {
             plan = Arc::new(AddRowOffsetExec::try_new(plan, self.dataset.clone()).await?);
        }
 
-        // Stage 7: final projection
+        // Final projection
         let final_projection = self.calculate_final_projection(plan.schema().as_ref())?;
         plan = Arc::new(DFProjectionExec::try_new(final_projection, plan)?);
 
-        // Stage 8: If requested, apply a strict batch size to the final output
+        // If requested, apply a strict batch size to the final output
         if self.strict_batch_size {
             plan = Arc::new(StrictBatchSizeExec::new(plan, self.get_batch_size()));
         }
@@ -2747,7 +3157,7 @@ impl Scanner {
             AggregateMode::Single,
             PhysicalGroupBy::new_single(group_expr),
             vec![Arc::new(
-                AggregateExprBuilder::new(
+                datafusion_physical_expr::aggregate::AggregateExprBuilder::new(
                     functions_aggregate::min_max::max_udaf(),
                     vec![expressions::col(SCORE_COL, &schema)?],
                 )
diff --git a/rust/lance/src/dataset/tests/dataset_aggregate.rs b/rust/lance/src/dataset/tests/dataset_aggregate.rs
new file mode 100644
index 00000000000..e75595a78ce
--- /dev/null
+++ b/rust/lance/src/dataset/tests/dataset_aggregate.rs
@@ -0,0 +1,1144 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The Lance Authors
+
+//! Tests for Substrait aggregate pushdown.
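+//!
+//! Each test hand-builds a Substrait `Plan` proto (measures plus matching
+//! extension declarations), pushes it through `Scanner::aggregate`, and checks
+//! the collected batches.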
+
+use std::sync::Arc;
+
+use arrow_array::cast::AsArray;
+use arrow_array::types::{Float64Type, Int64Type};
+use arrow_array::{
+    FixedSizeListArray, Float32Array, Int64Array, RecordBatch, RecordBatchIterator, StringArray,
+};
+use arrow_schema::{DataType, Field, Schema as ArrowSchema};
+use datafusion_substrait::substrait::proto::{
+    aggregate_function::AggregationInvocation,
+    aggregate_rel::{Grouping, Measure},
+    expression::{
+        field_reference::{ReferenceType, RootReference, RootType},
+        reference_segment::{self, StructField},
+        FieldReference, ReferenceSegment, RexType,
+    },
+    extensions::{
+        simple_extension_declaration::{ExtensionFunction, MappingType},
+        SimpleExtensionDeclaration, SimpleExtensionUri,
+    },
+    function_argument::ArgType,
+    rel::RelType,
+    sort_field::SortKind,
+    AggregateFunction, AggregateRel, Expression, FunctionArgument, Plan, PlanRel, Rel, RelRoot,
+    SortField, Version,
+};
+use futures::TryStreamExt;
+use lance_datafusion::exec::{execute_plan, LanceExecutionOptions};
+use lance_datagen::{array, gen_batch};
+use lance_table::format::Fragment;
+use prost::Message;
+use tempfile::tempdir;
+
+use crate::dataset::scanner::AggregateExpr;
+use crate::index::vector::VectorIndexParams;
+use crate::utils::test::{DatagenExt, FragmentCount, FragmentRowCount};
+use crate::Dataset;
+use lance_arrow::FixedSizeListArrayExt;
+use lance_index::scalar::inverted::InvertedIndexParams;
+use lance_index::scalar::FullTextSearchQuery;
+use lance_index::{DatasetIndexExt, IndexType};
+use lance_linalg::distance::MetricType;
+
+/// Helper to create a field reference expression for a column index
+fn field_ref(field_index: i32) -> Expression {
+    Expression {
+        rex_type: Some(RexType::Selection(Box::new(FieldReference {
+            reference_type: Some(ReferenceType::DirectReference(ReferenceSegment {
+                reference_type: Some(reference_segment::ReferenceType::StructField(Box::new(
+                    StructField {
+                        field: field_index,
+                        child: None,
+                    },
+                ))),
+            })),
+            root_type: Some(RootType::RootReference(RootReference {})),
+        }))),
+    }
+}
+
+/// Helper to create a Substrait AggregateRel with given measures and groupings
+fn create_aggregate_rel(
+    measures: Vec<Measure>,
+    grouping_expressions: Vec<Expression>,
+    groupings: Vec<Grouping>,
+    extensions: Vec<SimpleExtensionDeclaration>,
+    output_names: Vec<String>,
+) -> Vec<u8> {
+    let aggregate_rel = AggregateRel {
+        common: None,
+        input: None, // Input is ignored for pushdown
+        groupings,
+        measures,
+        grouping_expressions,
+        advanced_extension: None,
+    };
+
+    let rel = Rel {
+        rel_type: Some(RelType::Aggregate(Box::new(aggregate_rel))),
+    };
+
+    // Wrap in a Plan to include extensions
+    let plan = Plan {
+        version: Some(Version {
+            major_number: 0,
+            minor_number: 63,
+            patch_number: 0,
+            git_hash: String::new(),
+            producer: "lance-test".to_string(),
+        }),
+        #[allow(deprecated)]
+        extension_uris: vec![
+            SimpleExtensionUri {
+                extension_uri_anchor: 1,
+                uri: "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate_generic.yaml".to_string(),
+            },
+            SimpleExtensionUri {
+                extension_uri_anchor: 2,
+                uri: "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml".to_string(),
+            },
+        ],
+        extensions,
+        relations: vec![PlanRel {
+            rel_type: Some(datafusion_substrait::substrait::proto::plan_rel::RelType::Root(
+                RelRoot {
+                    input: Some(rel),
+                    names: output_names,
+                },
+            )),
+        }],
+        advanced_extensions: None,
+        expected_type_urls: vec![],
+        extension_urns: vec![],
+        parameter_bindings: vec![],
+        type_aliases: vec![],
+    };
+
+    plan.encode_to_vec()
+}
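+
+// For orientation: a COUNT(*) plan over any schema can be assembled as
+//   create_aggregate_rel(vec![count_star_measure(1)], vec![], vec![], vec![agg_extension(1, "count")], vec![])
+// where the measure's function_reference must match the extension's function_anchor.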
+
+/// Create extension declaration for an aggregate function
+fn agg_extension(anchor: u32, name: &str) -> SimpleExtensionDeclaration {
+    SimpleExtensionDeclaration {
+        mapping_type: Some(MappingType::ExtensionFunction(ExtensionFunction {
+            #[allow(deprecated)]
+            extension_uri_reference: 1,
+            extension_urn_reference: 0,
+            function_anchor: anchor,
+            name: name.to_string(),
+        })),
+    }
+}
+
+/// Create a COUNT(*) measure
+fn count_star_measure(function_ref: u32) -> Measure {
+    Measure {
+        measure: Some(AggregateFunction {
+            function_reference: function_ref,
+            arguments: vec![], // COUNT(*) has no arguments
+            options: vec![],
+            output_type: None,
+            phase: 0,
+            sorts: vec![],
+            invocation: AggregationInvocation::All as i32,
+            #[allow(deprecated)]
+            args: vec![],
+        }),
+        filter: None,
+    }
+}
+
+/// Create a SUM/AVG/MIN/MAX measure on a column
+fn simple_agg_measure(function_ref: u32, column_index: i32) -> Measure {
+    Measure {
+        measure: Some(AggregateFunction {
+            function_reference: function_ref,
+            arguments: vec![FunctionArgument {
+                arg_type: Some(ArgType::Value(field_ref(column_index))),
+            }],
+            options: vec![],
+            output_type: None,
+            phase: 0,
+            sorts: vec![],
+            invocation: AggregationInvocation::All as i32,
+            #[allow(deprecated)]
+            args: vec![],
+        }),
+        filter: None,
+    }
+}
+
+/// Create an ordered aggregate measure (e.g., FIRST_VALUE with ORDER BY)
+fn ordered_agg_measure(
+    function_ref: u32,
+    column_index: i32,
+    sort_column_index: i32,
+    ascending: bool,
+) -> Measure {
+    use datafusion_substrait::substrait::proto::sort_field::SortDirection;
+
+    let sort_direction = if ascending {
+        SortDirection::AscNullsLast
+    } else {
+        SortDirection::DescNullsLast
+    };
+
+    Measure {
+        measure: Some(AggregateFunction {
+            function_reference: function_ref,
+            arguments: vec![FunctionArgument {
+                arg_type: Some(ArgType::Value(field_ref(column_index))),
+            }],
+            options: vec![],
+            output_type: None,
+            phase: 0,
+            sorts: vec![SortField {
+                expr: Some(field_ref(sort_column_index)),
+                sort_kind: Some(SortKind::Direction(sort_direction as i32)),
+            }],
+            invocation: AggregationInvocation::All as i32,
+            #[allow(deprecated)]
+            args: vec![],
+        }),
+        filter: None,
+    }
+}
+
+/// Execute aggregate plan and collect results
+async fn execute_aggregate(
+    dataset: &Dataset,
+    aggregate_bytes: &[u8],
+) -> crate::Result<Vec<RecordBatch>> {
+    let mut scanner = dataset.scan();
+    scanner.aggregate(AggregateExpr::substrait(aggregate_bytes));
+
+    let plan = scanner.create_plan().await?;
+    let stream = execute_plan(plan, LanceExecutionOptions::default())?;
+    stream.try_collect().await.map_err(|e| e.into())
+}
+
+/// Execute aggregate plan on specific fragments
+async fn execute_aggregate_on_fragments(
+    dataset: &Dataset,
+    aggregate_bytes: &[u8],
+    fragments: Vec<Fragment>,
+) -> crate::Result<Vec<RecordBatch>> {
+    let mut scanner = dataset.scan();
+    scanner.with_fragments(fragments);
+    scanner.aggregate(AggregateExpr::substrait(aggregate_bytes));
+
+    let plan = scanner.create_plan().await?;
+    let stream = execute_plan(plan, LanceExecutionOptions::default())?;
+    stream.try_collect().await.map_err(|e| e.into())
+}
+
+/// Create a test dataset with numeric columns
+async fn create_numeric_dataset(uri: &str, num_fragments: u32, rows_per_fragment: u32) -> Dataset {
+    gen_batch()
+        .col("x", array::step::<Int64Type>())
+        .col("y", array::step_custom::<Int64Type>(0, 2))
+        .col("category", array::cycle::<Int64Type>(vec![1, 2, 3]))
+        .into_dataset(
+            uri,
+            FragmentCount::from(num_fragments),
+            FragmentRowCount::from(rows_per_fragment),
+        )
+        .await
+        .unwrap()
+}
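+
+// Data shape produced by `create_numeric_dataset` (relied on by the assertions
+// below): x = 0, 1, 2, ...; y = 0, 2, 4, ...; category cycles 1, 2, 3, ...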
+
+#[tokio::test]
+async fn test_count_star_single_fragment() {
+    let tmp_dir = tempdir().unwrap();
+    let uri = tmp_dir.path().to_str().unwrap();
+    let ds = create_numeric_dataset(uri, 1, 100).await;
+
+    let agg_bytes = create_aggregate_rel(
+        vec![count_star_measure(1)],
+        vec![],
+        vec![],
+        vec![agg_extension(1, "count")],
+        vec![],
+    );
+
+    let results = execute_aggregate(&ds, &agg_bytes).await.unwrap();
+    assert_eq!(results.len(), 1);
+    let batch = &results[0];
+    assert_eq!(batch.num_rows(), 1);
+    assert_eq!(batch.column(0).as_primitive::<Int64Type>().value(0), 100);
+}
+
+#[tokio::test]
+async fn test_count_star_multiple_fragments() {
+    let tmp_dir = tempdir().unwrap();
+    let uri = tmp_dir.path().to_str().unwrap();
+    let ds = create_numeric_dataset(uri, 5, 100).await;
+
+    let agg_bytes = create_aggregate_rel(
+        vec![count_star_measure(1)],
+        vec![],
+        vec![],
+        vec![agg_extension(1, "count")],
+        vec![],
+    );
+
+    let results = execute_aggregate(&ds, &agg_bytes).await.unwrap();
+    assert_eq!(results.len(), 1);
+    let batch = &results[0];
+    assert_eq!(batch.num_rows(), 1);
+    // 5 fragments * 100 rows = 500 total
+    assert_eq!(batch.column(0).as_primitive::<Int64Type>().value(0), 500);
+}
+
+#[tokio::test]
+async fn test_count_star_subset_fragments() {
+    let tmp_dir = tempdir().unwrap();
+    let uri = tmp_dir.path().to_str().unwrap();
+    let ds = create_numeric_dataset(uri, 5, 100).await;
+
+    // Get only first 2 fragments
+    let all_fragments = ds.get_fragments();
+    let subset: Vec<Fragment> = all_fragments
+        .into_iter()
+        .take(2)
+        .map(|f| f.metadata)
+        .collect();
+
+    let agg_bytes = create_aggregate_rel(
+        vec![count_star_measure(1)],
+        vec![],
+        vec![],
+        vec![agg_extension(1, "count")],
+        vec![],
+    );
+
+    let results = execute_aggregate_on_fragments(&ds, &agg_bytes, subset)
+        .await
+        .unwrap();
+    assert_eq!(results.len(), 1);
+    let batch = &results[0];
+    assert_eq!(batch.num_rows(), 1);
+    // 2 fragments * 100 rows = 200 total
+    assert_eq!(batch.column(0).as_primitive::<Int64Type>().value(0), 200);
+}
+
+#[tokio::test]
+async fn test_sum_single_fragment() {
+    let tmp_dir = tempdir().unwrap();
+    let uri = tmp_dir.path().to_str().unwrap();
+    let ds = create_numeric_dataset(uri, 1, 100).await;
+
+    // SUM(x) where x = 0..99
+    let agg_bytes = create_aggregate_rel(
+        vec![simple_agg_measure(1, 0)], // column 0 = x
+        vec![],
+        vec![],
+        vec![agg_extension(1, "sum")],
+        vec![],
+    );
+
+    let results = execute_aggregate(&ds, &agg_bytes).await.unwrap();
+    assert_eq!(results.len(), 1);
+    let batch = &results[0];
+    assert_eq!(batch.num_rows(), 1);
+    // SUM(0..99) = 99*100/2 = 4950
+    assert_eq!(batch.column(0).as_primitive::<Int64Type>().value(0), 4950);
+}
+
+#[tokio::test]
+async fn test_sum_multiple_fragments() {
+    let tmp_dir = tempdir().unwrap();
+    let uri = tmp_dir.path().to_str().unwrap();
+    let ds = create_numeric_dataset(uri, 4, 25).await;
+
+    // SUM(x) where x = 0..99 across 4 fragments
+    let agg_bytes = create_aggregate_rel(
+        vec![simple_agg_measure(1, 0)],
+        vec![],
+        vec![],
+        vec![agg_extension(1, "sum")],
+        vec![],
+    );
+
+    let results = execute_aggregate(&ds, &agg_bytes).await.unwrap();
+    assert_eq!(results.len(), 1);
+    let batch = &results[0];
+    // SUM(0..99) = 4950
+    assert_eq!(batch.column(0).as_primitive::<Int64Type>().value(0), 4950);
+}
+
+#[tokio::test]
+async fn test_min_max() {
+    let tmp_dir = tempdir().unwrap();
+    let uri = tmp_dir.path().to_str().unwrap();
+    let ds = create_numeric_dataset(uri, 4, 25).await;
+
+    // MIN(x) and MAX(x)
+    let agg_bytes = create_aggregate_rel(
+        vec![
+            simple_agg_measure(1, 0), // MIN(x)
+            simple_agg_measure(2, 0), // MAX(x)
+        ],
+        vec![],
+        vec![],
+        vec![agg_extension(1, "min"), agg_extension(2, "max")],
+        vec![],
+    );
+
+    let results = execute_aggregate(&ds, &agg_bytes).await.unwrap();
+    assert_eq!(results.len(), 1);
+    let batch = &results[0];
+    assert_eq!(batch.num_rows(), 1);
+    assert_eq!(batch.num_columns(), 2);
+    // MIN should be 0, MAX should be 99
+    assert_eq!(batch.column(0).as_primitive::<Int64Type>().value(0), 0);
+    assert_eq!(batch.column(1).as_primitive::<Int64Type>().value(0), 99);
+}
+
+#[tokio::test]
+async fn test_avg() {
+    let tmp_dir = tempdir().unwrap();
+    let uri = tmp_dir.path().to_str().unwrap();
+    let ds = create_numeric_dataset(uri, 4, 25).await;
+
+    // AVG(x) where x = 0..99
+    let agg_bytes = create_aggregate_rel(
+        vec![simple_agg_measure(1, 0)],
+        vec![],
+        vec![],
+        vec![agg_extension(1, "avg")],
+        vec![],
+    );
+
+    let results = execute_aggregate(&ds, &agg_bytes).await.unwrap();
+    assert_eq!(results.len(), 1);
+    let batch = &results[0];
+    // AVG(0..99) = 49.5
+    let avg = batch.column(0).as_primitive::<Float64Type>().value(0);
+    assert!((avg - 49.5).abs() < 0.001);
+}
+
+#[tokio::test]
+async fn test_multiple_aggregates() {
+    let tmp_dir = tempdir().unwrap();
+    let uri = tmp_dir.path().to_str().unwrap();
+    let ds = create_numeric_dataset(uri, 4, 25).await;
+
+    // COUNT(*), SUM(x), MIN(x), MAX(x), AVG(x)
+    let agg_bytes = create_aggregate_rel(
+        vec![
+            count_star_measure(1),
+            simple_agg_measure(2, 0), // SUM(x)
+            simple_agg_measure(3, 0), // MIN(x)
+            simple_agg_measure(4, 0), // MAX(x)
+            simple_agg_measure(5, 0), // AVG(x)
+        ],
+        vec![],
+        vec![],
+        vec![
+            agg_extension(1, "count"),
+            agg_extension(2, "sum"),
+            agg_extension(3, "min"),
+            agg_extension(4, "max"),
+            agg_extension(5, "avg"),
+        ],
+        vec![],
+    );
+
+    let results = execute_aggregate(&ds, &agg_bytes).await.unwrap();
+    assert_eq!(results.len(), 1);
+    let batch = &results[0];
+    assert_eq!(batch.num_rows(), 1);
+    assert_eq!(batch.num_columns(), 5);
+
+    // Verify all aggregates
+    assert_eq!(batch.column(0).as_primitive::<Int64Type>().value(0), 100); // COUNT
+    assert_eq!(batch.column(1).as_primitive::<Int64Type>().value(0), 4950); // SUM
+    assert_eq!(batch.column(2).as_primitive::<Int64Type>().value(0), 0); // MIN
+    assert_eq!(batch.column(3).as_primitive::<Int64Type>().value(0), 99); // MAX
+    let avg = batch.column(4).as_primitive::<Float64Type>().value(0);
+    assert!((avg - 49.5).abs() < 0.001); // AVG
+}
+
+#[tokio::test]
+async fn test_group_by_with_count() {
+    let tmp_dir = tempdir().unwrap();
+    let uri = tmp_dir.path().to_str().unwrap();
+    let ds = create_numeric_dataset(uri, 4, 30).await;
+
+    // COUNT(*) GROUP BY category
+    // category cycles through 1, 2, 3
+    let agg_bytes = create_aggregate_rel(
+        vec![count_star_measure(1)],
+        vec![field_ref(2)], // category is column index 2
+        vec![Grouping {
+            #[allow(deprecated)]
+            grouping_expressions: vec![],
+            expression_references: vec![0], // Reference to first grouping expression
+        }],
+        vec![agg_extension(1, "count")],
+        vec![],
+    );
+
+    let results = execute_aggregate(&ds, &agg_bytes).await.unwrap();
+    assert!(!results.is_empty());
+
+    let batch = arrow::compute::concat_batches(&results[0].schema(), &results).unwrap();
+    assert_eq!(batch.num_rows(), 3); // 3 categories
+
+    // Each category should have 40 rows (120 total / 3 categories)
+    let counts: Vec<i64> = batch
+        .column(1) // count column
+        .as_primitive::<Int64Type>()
+        .values()
+        .to_vec();
+
+    for count in counts {
+        assert_eq!(count, 40);
+    }
+}
+
+#[tokio::test]
+async fn test_group_by_with_sum() {
+    let tmp_dir = tempdir().unwrap();
+    let uri = tmp_dir.path().to_str().unwrap();
+    let ds = create_numeric_dataset(uri, 1, 9).await;
+
+    // SUM(x) GROUP BY category
+    // x = 0..8, category cycles 1,2,3,1,2,3,1,2,3
+    // category 1: sum(0,3,6) = 9
+    // category 2: sum(1,4,7) = 12
+    // category 3: sum(2,5,8) = 15
+    let agg_bytes = create_aggregate_rel(
+        vec![simple_agg_measure(1, 0)], // SUM(x)
+        vec![field_ref(2)],             // GROUP BY category
+        vec![Grouping {
+            #[allow(deprecated)]
+            grouping_expressions: vec![],
+            expression_references: vec![0],
+        }],
+        vec![agg_extension(1, "sum")],
+        vec![],
+    );
+
+    let results = execute_aggregate(&ds, &agg_bytes).await.unwrap();
+    assert!(!results.is_empty());
+
+    let batch = arrow::compute::concat_batches(&results[0].schema(), &results).unwrap();
+    assert_eq!(batch.num_rows(), 3); // 3 categories
+
+    // Collect results into a map for verification
+    let categories: Vec<i64> = batch
+        .column(0) // category column
+        .as_primitive::<Int64Type>()
+        .values()
+        .to_vec();
+    let sums: Vec<i64> = batch
+        .column(1) // sum column
+        .as_primitive::<Int64Type>()
+        .values()
+        .to_vec();
+
+    let mut results_map = std::collections::HashMap::new();
+    for (cat, sum) in categories.iter().zip(sums.iter()) {
+        results_map.insert(*cat, *sum);
+    }
+
+    assert_eq!(results_map.get(&1), Some(&9));
+    assert_eq!(results_map.get(&2), Some(&12));
+    assert_eq!(results_map.get(&3), Some(&15));
+}
+
+#[tokio::test]
+async fn test_aggregate_specific_fragments() {
+    let tmp_dir = tempdir().unwrap();
+    let uri = tmp_dir.path().to_str().unwrap();
+    let ds = create_numeric_dataset(uri, 10, 10).await;
+
+    // Get fragments 3, 5, 7 (0-indexed)
+    let all_fragments = ds.get_fragments();
+    let subset: Vec<Fragment> = all_fragments
+        .into_iter()
+        .enumerate()
+        .filter(|(i, _)| *i == 3 || *i == 5 || *i == 7)
+        .map(|(_, f)| f.metadata)
+        .collect();
+
+    let agg_bytes = create_aggregate_rel(
+        vec![count_star_measure(1)],
+        vec![],
+        vec![],
+        vec![agg_extension(1, "count")],
+        vec![],
+    );
+
+    let results = execute_aggregate_on_fragments(&ds, &agg_bytes, subset)
+        .await
+        .unwrap();
+
+    assert_eq!(results.len(), 1);
+    let batch = &results[0];
+    // 3 fragments * 10 rows = 30 total
+    assert_eq!(batch.column(0).as_primitive::<Int64Type>().value(0), 30);
+}
+
+#[tokio::test]
+async fn test_sum_specific_fragments() {
+    let tmp_dir = tempdir().unwrap();
+    let uri = tmp_dir.path().to_str().unwrap();
+
+    // Create dataset where each fragment has distinct values
+    // Fragment 0: x = 0..9 (sum = 45)
+    // Fragment 1: x = 10..19 (sum = 145)
+    // Fragment 2: x = 20..29 (sum = 245)
+    // Fragment 3: x = 30..39 (sum = 345)
+    let ds = create_numeric_dataset(uri, 4, 10).await;
+
+    // Only scan fragments 1 and 2
+    let all_fragments = ds.get_fragments();
+    let subset: Vec<Fragment> = all_fragments
+        .into_iter()
+        .enumerate()
+        .filter(|(i, _)| *i == 1 || *i == 2)
+        .map(|(_, f)| f.metadata)
+        .collect();
+
+    let agg_bytes = create_aggregate_rel(
+        vec![simple_agg_measure(1, 0)], // SUM(x)
+        vec![],
+        vec![],
+        vec![agg_extension(1, "sum")],
+        vec![],
+    );
+
+    let results = execute_aggregate_on_fragments(&ds, &agg_bytes, subset)
+        .await
+        .unwrap();
+
+    assert_eq!(results.len(), 1);
+    let batch = &results[0];
+    // Fragment 1: sum(10..19) = 145
+    // Fragment 2: sum(20..29) = 245
+    // Total = 390
+    assert_eq!(batch.column(0).as_primitive::<Int64Type>().value(0), 390);
+}
+
+#[tokio::test]
+async fn test_aggregate_with_filter() {
+    let tmp_dir = tempdir().unwrap();
+    let uri = tmp_dir.path().to_str().unwrap();
+    let ds = create_numeric_dataset(uri, 1, 100).await;
+
+    let mut scanner = ds.scan();
+    scanner.filter("x >= 50").unwrap();
+
+    let agg_bytes = create_aggregate_rel(
+        vec![
+            count_star_measure(1),
+            simple_agg_measure(2, 0), // SUM(x)
+            simple_agg_measure(3, 0), // MIN(x)
+            simple_agg_measure(4, 0), // MAX(x)
+        ],
+        vec![],
+        vec![],
+        vec![
+            agg_extension(1, "count"),
+            agg_extension(2, "sum"),
+            agg_extension(3, "min"),
+            agg_extension(4, "max"),
+        ],
+        vec![],
+    );
+    scanner.aggregate(AggregateExpr::substrait(agg_bytes));
+
+    let plan = scanner.create_plan().await.unwrap();
+    let stream = execute_plan(plan, LanceExecutionOptions::default()).unwrap();
+    let results: Vec<RecordBatch> = stream.try_collect().await.unwrap();
+
+    assert_eq!(results.len(), 1);
+    let batch = &results[0];
+    assert_eq!(batch.num_rows(), 1);
+
+    // Filter x >= 50 matches rows 50..99 (50 rows)
+    assert_eq!(batch.column(0).as_primitive::<Int64Type>().value(0), 50); // COUNT
+    // SUM(50..99) = (50+99)*50/2 = 3725
+    assert_eq!(batch.column(1).as_primitive::<Int64Type>().value(0), 3725); // SUM
+    assert_eq!(batch.column(2).as_primitive::<Int64Type>().value(0), 50); // MIN
+    assert_eq!(batch.column(3).as_primitive::<Int64Type>().value(0), 99); // MAX
+}
+
+#[tokio::test]
+async fn test_aggregate_empty_result() {
+    let tmp_dir = tempdir().unwrap();
+    let uri = tmp_dir.path().to_str().unwrap();
+    let ds = create_numeric_dataset(uri, 1, 100).await;
+
+    // Apply filter that matches no rows, then aggregate
+    let mut scanner = ds.scan();
+    scanner.project::<&str>(&[]).unwrap();
+    scanner.with_row_id();
+    scanner.filter("x > 1000").unwrap(); // No rows match
+
+    let agg_bytes = create_aggregate_rel(
+        vec![count_star_measure(1)],
+        vec![],
+        vec![],
+        vec![agg_extension(1, "count")],
+        vec![],
+    );
+    scanner.aggregate(AggregateExpr::substrait(agg_bytes));
+
+    let plan = scanner.create_plan().await.unwrap();
+    let stream = execute_plan(plan, LanceExecutionOptions::default()).unwrap();
+    let results: Vec<RecordBatch> = stream.try_collect().await.unwrap();
+
+    assert_eq!(results.len(), 1);
+    let batch = &results[0];
+    assert_eq!(batch.num_rows(), 1);
+    // COUNT(*) of empty result should be 0
+    assert_eq!(batch.column(0).as_primitive::<Int64Type>().value(0), 0);
+}
+
+#[tokio::test]
+async fn test_aggregate_single_row() {
+    let tmp_dir = tempdir().unwrap();
+    let uri = tmp_dir.path().to_str().unwrap();
+
+    // Create dataset with single row using Int64 to avoid type coercion issues
+    let schema = Arc::new(ArrowSchema::new(vec![Field::new(
+        "x",
+        DataType::Int64,
+        false,
+    )]));
+
+    let batch = RecordBatch::try_new(
+        schema.clone(),
+        vec![Arc::new(arrow_array::Int64Array::from(vec![42]))],
+    )
+    .unwrap();
+
+    let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
+    let ds = Dataset::write(reader, uri, None).await.unwrap();
+
+    let agg_bytes = create_aggregate_rel(
+        vec![
+            count_star_measure(1),
+            simple_agg_measure(2, 0), // SUM(x)
+            simple_agg_measure(3, 0), // MIN(x)
+            simple_agg_measure(4, 0), // MAX(x)
+        ],
+        vec![],
+        vec![],
+        vec![
+            agg_extension(1, "count"),
+            agg_extension(2, "sum"),
+            agg_extension(3, "min"),
+            agg_extension(4, "max"),
+        ],
+        vec![],
+    );
+
+    let results = execute_aggregate(&ds, &agg_bytes).await.unwrap();
+    assert_eq!(results.len(), 1);
+    let batch = &results[0];
+
+    assert_eq!(batch.column(0).as_primitive::<Int64Type>().value(0), 1); // COUNT
+    assert_eq!(batch.column(1).as_primitive::<Int64Type>().value(0), 42); // SUM
+    assert_eq!(batch.column(2).as_primitive::<Int64Type>().value(0), 42); // MIN
+    assert_eq!(batch.column(3).as_primitive::<Int64Type>().value(0), 42); // MAX
+}
+
+#[tokio::test]
+async fn test_aggregate_with_aliases() {
+    let tmp_dir = tempdir().unwrap();
+    let uri = tmp_dir.path().to_str().unwrap();
+    let ds = create_numeric_dataset(uri, 1, 100).await;
+
+    // COUNT(*), SUM(x), MIN(x) with custom aliases
+    let agg_bytes = create_aggregate_rel(
+        vec![
+            count_star_measure(1),
+            simple_agg_measure(2, 0),
+            simple_agg_measure(3, 0),
+        ],
+        vec![],
+        vec![],
+        vec![
+            agg_extension(1, "count"),
+            agg_extension(2, "sum"),
+            agg_extension(3, "min"),
+        ],
+        vec![
+            "total_count".to_string(),
+            "sum_of_x".to_string(),
+            "min_x".to_string(),
+        ],
+    );
+
+    let results = execute_aggregate(&ds, &agg_bytes).await.unwrap();
+    assert_eq!(results.len(), 1);
+    let batch = &results[0];
+
+    // Verify output schema has the expected aliases
+    let schema = batch.schema();
+    assert_eq!(schema.fields().len(), 3);
+    assert_eq!(schema.field(0).name(), "total_count");
+    assert_eq!(schema.field(1).name(), "sum_of_x");
+    assert_eq!(schema.field(2).name(), "min_x");
+
+    // Verify values are correct
+    assert_eq!(batch.column(0).as_primitive::<Int64Type>().value(0), 100);
+    assert_eq!(batch.column(1).as_primitive::<Int64Type>().value(0), 4950);
+    assert_eq!(batch.column(2).as_primitive::<Int64Type>().value(0), 0);
+}
+
+#[tokio::test]
+async fn test_group_by_with_aliases() {
+    let tmp_dir = tempdir().unwrap();
+    let uri = tmp_dir.path().to_str().unwrap();
+    let ds = create_numeric_dataset(uri, 1, 9).await;
+
+    // SUM(x) GROUP BY category with aliases
+    let agg_bytes = create_aggregate_rel(
+        vec![simple_agg_measure(1, 0)],
+        vec![field_ref(2)],
+        vec![Grouping {
+            #[allow(deprecated)]
+            grouping_expressions: vec![],
+            expression_references: vec![0],
+        }],
+        vec![agg_extension(1, "sum")],
+        vec!["group_key".to_string(), "total_sum".to_string()],
+    );
+
+    let results = execute_aggregate(&ds, &agg_bytes).await.unwrap();
+    assert!(!results.is_empty());
+
+    let batch = arrow::compute::concat_batches(&results[0].schema(), &results).unwrap();
+
+    // Verify output schema has the expected aliases
+    let schema = batch.schema();
+    assert_eq!(schema.fields().len(), 2);
+    assert_eq!(schema.field(0).name(), "group_key");
+    assert_eq!(schema.field(1).name(), "total_sum");
+}
+
+#[tokio::test]
+async fn test_first_value_with_order_by() {
+    let tmp_dir = tempdir().unwrap();
+    let uri = tmp_dir.path().to_str().unwrap();
+    let ds = create_numeric_dataset(uri, 1, 9).await;
+
+    // FIRST_VALUE(x) ORDER BY x ASC GROUP BY category
+    // x = 0..8, category cycles 1,2,3,1,2,3,1,2,3
+    // category 1 has x values: 0, 3, 6 -> first_value(ORDER BY x ASC) = 0
+    // category 2 has x values: 1, 4, 7 -> first_value(ORDER BY x ASC) = 1
+    // category 3 has x values: 2, 5, 8 -> first_value(ORDER BY x ASC) = 2
+    let agg_bytes = create_aggregate_rel(
+        vec![ordered_agg_measure(1, 0, 0, true)], // FIRST_VALUE(x) ORDER BY x ASC
+        vec![field_ref(2)],                       // GROUP BY category
+        vec![Grouping {
+            #[allow(deprecated)]
+            grouping_expressions: vec![],
+            expression_references: vec![0],
+        }],
+        vec![agg_extension(1, "first_value")],
+        vec![],
+    );
+
+    let results = execute_aggregate(&ds, &agg_bytes).await.unwrap();
+    assert!(!results.is_empty());
+
+    let batch = arrow::compute::concat_batches(&results[0].schema(), &results).unwrap();
+    assert_eq!(batch.num_rows(), 3);
+
+    let categories: Vec<i64> = batch
+        .column(0)
+        .as_primitive::<Int64Type>()
+        .values()
+        .to_vec();
+    let first_values: Vec<i64> = batch
+        .column(1)
+        .as_primitive::<Int64Type>()
+        .values()
+        .to_vec();
+
+    let mut results_map = std::collections::HashMap::new();
+    for (cat, val) in categories.iter().zip(first_values.iter()) {
+        results_map.insert(*cat, *val);
+    }
+
+    assert_eq!(results_map.get(&1), Some(&0));
+    assert_eq!(results_map.get(&2), Some(&1));
+    assert_eq!(results_map.get(&3), Some(&2));
+}
+
+#[tokio::test]
+
+#[tokio::test]
+async fn test_first_value_with_order_by_desc() {
+    let tmp_dir = tempdir().unwrap();
+    let uri = tmp_dir.path().to_str().unwrap();
+    let ds = create_numeric_dataset(uri, 1, 9).await;
+
+    // FIRST_VALUE(x) ORDER BY x DESC, GROUP BY category
+    // category 1 has x values 0, 3, 6 -> first_value(ORDER BY x DESC) = 6
+    // category 2 has x values 1, 4, 7 -> first_value(ORDER BY x DESC) = 7
+    // category 3 has x values 2, 5, 8 -> first_value(ORDER BY x DESC) = 8
+    let agg_bytes = create_aggregate_rel(
+        vec![ordered_agg_measure(1, 0, 0, false)], // FIRST_VALUE(x) ORDER BY x DESC
+        vec![field_ref(2)],                        // GROUP BY category
+        vec![Grouping {
+            #[allow(deprecated)]
+            grouping_expressions: vec![],
+            expression_references: vec![0],
+        }],
+        vec![agg_extension(1, "first_value")],
+        vec![],
+    );
+
+    let results = execute_aggregate(&ds, &agg_bytes).await.unwrap();
+    assert!(!results.is_empty());
+
+    let batch = arrow::compute::concat_batches(&results[0].schema(), &results).unwrap();
+    assert_eq!(batch.num_rows(), 3);
+
+    let categories: Vec<i32> = batch
+        .column(0)
+        .as_primitive::<Int32Type>()
+        .values()
+        .to_vec();
+    let first_values: Vec<i32> = batch
+        .column(1)
+        .as_primitive::<Int32Type>()
+        .values()
+        .to_vec();
+
+    let mut results_map = std::collections::HashMap::new();
+    for (cat, val) in categories.iter().zip(first_values.iter()) {
+        results_map.insert(*cat, *val);
+    }
+
+    assert_eq!(results_map.get(&1), Some(&6));
+    assert_eq!(results_map.get(&2), Some(&7));
+    assert_eq!(results_map.get(&3), Some(&8));
+}
+
+/// Create a dataset with vectors, text, and category for vector search and FTS aggregate tests.
+/// Schema: id (i64), vec (fixed_size_list[4]), text (utf8), category (utf8)
+async fn create_vector_text_dataset(uri: &str, num_rows: i64) -> Dataset {
+    let schema = Arc::new(ArrowSchema::new(vec![
+        Field::new("id", DataType::Int64, false),
+        Field::new(
+            "vec",
+            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4),
+            true,
+        ),
+        Field::new("text", DataType::Utf8, false),
+        Field::new("category", DataType::Utf8, false),
+    ]));
+
+    let ids: Vec<i64> = (0..num_rows).collect();
+    let vectors: Vec<f32> = (0..num_rows).flat_map(|i| vec![i as f32; 4]).collect();
+    let texts: Vec<String> = (0..num_rows).map(|i| format!("document {}", i)).collect();
+    let categories: Vec<String> = (0..num_rows)
+        .map(|i| match i % 3 {
+            0 => "category_a".to_string(),
+            1 => "category_b".to_string(),
+            _ => "category_c".to_string(),
+        })
+        .collect();
+
+    let batch = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Int64Array::from(ids)),
+            Arc::new(
+                FixedSizeListArray::try_new_from_values(Float32Array::from(vectors), 4).unwrap(),
+            ),
+            Arc::new(StringArray::from(texts)),
+            Arc::new(StringArray::from(categories)),
+        ],
+    )
+    .unwrap();
+
+    let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
+    Dataset::write(reader, uri, None).await.unwrap()
+}
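+
+// The three tests below share one pattern: run an index-backed search (vector
+// or FTS), project just the columns the aggregate needs, then hand the
+// Substrait AggregateRel to the scanner. As a sketch (same helpers as above):
+//
+//     let mut scanner = dataset.scan();
+//     scanner
+//         .nearest("vec", &query_vector, k)?   // or .full_text_search(...)
+//         .project(&["id", "category"])?
+//         .aggregate(AggregateExpr::substrait(agg_bytes));
+//     let batch = scanner.try_into_batch().await?;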
+
+#[tokio::test]
+async fn test_vector_search_with_aggregate() {
+    let tmp_dir = tempdir().unwrap();
+    let uri = tmp_dir.path().to_str().unwrap();
+    let mut dataset = create_vector_text_dataset(uri, 100).await;
+
+    // Create a vector index
+    let params = VectorIndexParams::ivf_flat(2, MetricType::L2);
+    dataset
+        .create_index(&["vec"], IndexType::Vector, None, &params, true)
+        .await
+        .unwrap();
+
+    // Vector search for the top 30 results, then aggregate by category with COUNT(*).
+    // The query vector is close to id=50 (vec = [50, 50, 50, 50]).
+    let query_vector = Float32Array::from(vec![50.0f32, 50.0, 50.0, 50.0]);
+
+    // COUNT(*) GROUP BY category (column index 3)
+    let agg_bytes = create_aggregate_rel(
+        vec![count_star_measure(1)],
+        vec![field_ref(3)], // GROUP BY category
+        vec![Grouping {
+            #[allow(deprecated)]
+            grouping_expressions: vec![],
+            expression_references: vec![0],
+        }],
+        vec![agg_extension(1, "count")],
+        vec!["category".to_string(), "count".to_string()],
+    );
+
+    let mut scanner = dataset.scan();
+    scanner
+        .nearest("vec", &query_vector, 30)
+        .unwrap()
+        .project(&["id", "category"])
+        .unwrap()
+        .aggregate(AggregateExpr::substrait(agg_bytes));
+
+    let results = scanner.try_into_batch().await.unwrap();
+
+    // Should have 3 categories (or fewer if the search results don't cover all of them)
+    assert!(
+        results.num_rows() >= 1 && results.num_rows() <= 3,
+        "Expected 1-3 rows but got {}",
+        results.num_rows()
+    );
+
+    // The total count should be 30 (the top K results)
+    let counts: Vec<i64> = results
+        .column(1)
+        .as_primitive::<Int64Type>()
+        .values()
+        .to_vec();
+    let total: i64 = counts.iter().sum();
+    assert_eq!(total, 30);
+}
+
+#[tokio::test]
+async fn test_fts_with_aggregate() {
+    let tmp_dir = tempdir().unwrap();
+    let uri = tmp_dir.path().to_str().unwrap();
+    let mut dataset = create_vector_text_dataset(uri, 100).await;
+
+    // Create an FTS index on the text column
+    dataset
+        .create_index(
+            &["text"],
+            IndexType::Inverted,
+            None,
+            &InvertedIndexParams::default(),
+            true,
+        )
+        .await
+        .unwrap();
+
+    // FTS search for "document", then aggregate by category with COUNT(*).
+    // Every row's text matches "document", so all 100 rows feed the aggregate.
+    // COUNT(*) GROUP BY category (column index 3)
+    let agg_bytes = create_aggregate_rel(
+        vec![count_star_measure(1)],
+        vec![field_ref(3)], // GROUP BY category
+        vec![Grouping {
+            #[allow(deprecated)]
+            grouping_expressions: vec![],
+            expression_references: vec![0],
+        }],
+        vec![agg_extension(1, "count")],
+        vec!["category".to_string(), "count".to_string()],
+    );
+
+    let mut scanner = dataset.scan();
+    scanner
+        .full_text_search(FullTextSearchQuery::new("document".to_string()))
+        .unwrap()
+        .project(&["id", "category"])
+        .unwrap()
+        .aggregate(AggregateExpr::substrait(agg_bytes));
+
+    let results = scanner.try_into_batch().await.unwrap();
+
+    // Should have exactly 3 categories
+    assert_eq!(
+        results.num_rows(),
+        3,
+        "Expected 3 rows but got {}",
+        results.num_rows()
+    );
+
+    // The total count should be 100 (all documents match "document")
+    let counts: Vec<i64> = results
+        .column(1)
+        .as_primitive::<Int64Type>()
+        .values()
+        .to_vec();
+    let total: i64 = counts.iter().sum();
+    assert_eq!(total, 100);
+
+    // Each category should hold roughly a third of the rows
+    for count in &counts {
+        assert!(*count >= 33 && *count <= 34);
+    }
+}
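+
+// With ids 0..100, `i % 3 == 0` matches 34 rows (0, 3, ..., 99) while the
+// other two residues match 33 each, which is why the loop above accepts 33 or
+// 34 per category. The top-K vector test below cannot pin per-category counts,
+// since they depend on which ids land in the top K.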
+
+#[tokio::test]
+async fn test_vector_search_with_sum_aggregate() {
+    let tmp_dir = tempdir().unwrap();
+    let uri = tmp_dir.path().to_str().unwrap();
+    let mut dataset = create_vector_text_dataset(uri, 100).await;
+
+    // Create a vector index
+    let params = VectorIndexParams::ivf_flat(2, MetricType::L2);
+    dataset
+        .create_index(&["vec"], IndexType::Vector, None, &params, true)
+        .await
+        .unwrap();
+
+    // Vector search for the top 10 results, then SUM(id) GROUP BY category
+    let query_vector = Float32Array::from(vec![50.0f32, 50.0, 50.0, 50.0]);
+
+    // SUM(id) GROUP BY category
+    let agg_bytes = create_aggregate_rel(
+        vec![simple_agg_measure(1, 0)], // SUM(id) - column 0
+        vec![field_ref(3)],             // GROUP BY category
+        vec![Grouping {
+            #[allow(deprecated)]
+            grouping_expressions: vec![],
+            expression_references: vec![0],
+        }],
+        vec![agg_extension(1, "sum")],
+        vec!["category".to_string(), "sum_id".to_string()],
+    );
+
+    let mut scanner = dataset.scan();
+    scanner
+        .nearest("vec", &query_vector, 10)
+        .unwrap()
+        .project(&["id", "category"])
+        .unwrap()
+        .aggregate(AggregateExpr::substrait(agg_bytes));
+
+    let results = scanner.try_into_batch().await.unwrap();
+
+    // Results are grouped by category (1-3 groups, depending on which categories
+    // appear in the top K)
+    assert!(
+        results.num_rows() >= 1 && results.num_rows() <= 3,
+        "Expected 1-3 rows but got {}",
+        results.num_rows()
+    );
+
+    // Verify we have two columns: category and sum_id
+    assert_eq!(results.num_columns(), 2);
+}
diff --git a/rust/lance/src/dataset/tests/mod.rs b/rust/lance/src/dataset/tests/mod.rs
index a1197dcc198..c0a6002debb 100644
--- a/rust/lance/src/dataset/tests/mod.rs
+++ b/rust/lance/src/dataset/tests/mod.rs
@@ -1,6 +1,8 @@
 // SPDX-License-Identifier: Apache-2.0
 // SPDX-FileCopyrightText: Copyright The Lance Authors
 
+#[cfg(feature = "substrait")]
+mod dataset_aggregate;
 mod dataset_common;
 mod dataset_concurrency_store;
 mod dataset_geo;