From a49a0914692068aa3b43867e4c085736ce1eed21 Mon Sep 17 00:00:00 2001 From: Jack Ye Date: Thu, 12 Feb 2026 21:28:09 -0800 Subject: [PATCH 1/8] refactor: use scanner aggregate for count_rows --- rust/lance/src/dataset/scanner.rs | 121 +---------- .../src/dataset/tests/dataset_aggregate.rs | 191 +++++++++++++++++- 2 files changed, 201 insertions(+), 111 deletions(-) diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index 8923bc03e99..901dfa71f9e 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -14,7 +14,6 @@ use async_recursion::async_recursion; use chrono::Utc; use datafusion::common::{exec_datafusion_err, DFSchema, JoinType, NullEquality, SchemaExt}; use datafusion::functions_aggregate; -use datafusion::functions_aggregate::count::count_udaf; use datafusion::logical_expr::{col, lit, Expr, ScalarUDF}; use datafusion::physical_expr::PhysicalSortExpr; use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; @@ -24,7 +23,6 @@ use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::{ aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}, display::DisplayableExecutionPlan, - expressions::Literal, limit::GlobalLimitExec, repartition::RepartitionExec, union::UnionExec, @@ -1911,72 +1909,25 @@ impl Scanner { Ok(concat_batches(&schema, &batches)?) } - pub fn create_count_plan(&self) -> BoxFuture<'_, Result>> { + /// Scan and return the number of matching rows + /// + /// Note: calling [`Dataset::count_rows`] can be more efficient than calling this method + /// especially if there is no filter. + #[instrument(skip_all)] + pub fn count_rows(&self) -> BoxFuture<'_, Result> { // Future intentionally boxed here to avoid large futures on the stack async move { - if self.projection_plan.physical_projection.is_empty() { - return Err(Error::invalid_input("count_rows called but with_row_id is false".to_string(), location!())); - } - if !self.projection_plan.physical_projection.is_metadata_only() { - let physical_schema = self.projection_plan.physical_projection.to_schema(); - let columns: Vec<&str> = physical_schema.fields - .iter() - .map(|field| field.name.as_str()) - .collect(); - - let msg = format!( - "count_rows should not be called on a plan selecting columns. selected columns: [{}]", - columns.join(", ") - ); - - return Err(Error::invalid_input(msg, location!())); - } - if self.limit.is_some() || self.offset.is_some() { log::warn!( "count_rows called with limit or offset which could have surprising results" ); } - let plan = self.create_plan().await?; - // Datafusion interprets COUNT(*) as COUNT(1) - let one = Arc::new(Literal::new(ScalarValue::UInt8(Some(1)))); + let mut scanner = self.clone(); + scanner.aggregate(AggregateExpr::builder().count_star().build()); - let input_phy_exprs: &[Arc] = &[one]; - let schema = plan.schema(); - - let mut builder = datafusion_physical_expr::aggregate::AggregateExprBuilder::new( - count_udaf(), - input_phy_exprs.to_vec(), - ); - builder = builder.schema(schema); - builder = builder.alias("count_rows".to_string()); - - let count_expr = builder.build()?; - - let plan_schema = plan.schema(); - Ok(Arc::new(AggregateExec::try_new( - AggregateMode::Single, - PhysicalGroupBy::new_single(Vec::new()), - vec![Arc::new(count_expr)], - vec![None], - plan, - plan_schema, - )?) as Arc) - } - .boxed() - } - - /// Scan and return the number of matching rows - /// - /// Note: calling [`Dataset::count_rows`] can be more efficient than calling this method - /// especially if there is no filter. - #[instrument(skip_all)] - pub fn count_rows(&self) -> BoxFuture<'_, Result> { - // Future intentionally boxed here to avoid large futures on the stack - async move { - let count_plan = self.create_count_plan().await?; - let mut stream = execute_plan(count_plan, LanceExecutionOptions::default())?; + let plan = scanner.create_plan().await?; + let mut stream = execute_plan(plan, LanceExecutionOptions::default())?; // A count plan will always return a single batch with a single row. if let Some(first_batch) = stream.next().await { @@ -1986,7 +1937,7 @@ impl Scanner { .as_any() .downcast_ref::() .ok_or(Error::invalid_input( - "Count plan did not return a UInt64Array".to_string(), + "Count plan did not return an Int64Array".to_string(), location!(), ))?; Ok(array.value(0) as u64) @@ -7373,56 +7324,6 @@ mod test { assert_plan_node_equals(exec_plan, expected).await } - #[tokio::test] - async fn test_count_plan() { - // A count rows operation should load the minimal amount of data - let dim = 256; - let fixture = TestVectorDataset::new_with_dimension(LanceFileVersion::Stable, true, dim) - .await - .unwrap(); - - // By default, all columns are returned, this is bad for a count_rows op - let err = fixture - .dataset - .scan() - .create_count_plan() - .await - .unwrap_err(); - assert!(matches!(err, Error::InvalidInput { .. })); - - let mut scan = fixture.dataset.scan(); - scan.project(&Vec::::default()).unwrap(); - - // with_row_id needs to be specified - let err = scan.create_count_plan().await.unwrap_err(); - assert!(matches!(err, Error::InvalidInput { .. })); - - scan.with_row_id(); - - let plan = scan.create_count_plan().await.unwrap(); - - assert_plan_node_equals( - plan, - "AggregateExec: mode=Single, gby=[], aggr=[count_rows] - LanceRead: uri=..., projection=[], num_fragments=2, range_before=None, range_after=None, row_id=true, row_addr=false, full_filter=--, refine_filter=--", - ) - .await - .unwrap(); - - scan.filter("s == ''").unwrap(); - - let plan = scan.create_count_plan().await.unwrap(); - - assert_plan_node_equals( - plan, - "AggregateExec: mode=Single, gby=[], aggr=[count_rows] - ProjectionExec: expr=[_rowid@1 as _rowid] - LanceRead: uri=..., projection=[s], num_fragments=2, range_before=None, range_after=None, row_id=true, row_addr=false, full_filter=s = Utf8(\"\"), refine_filter=s = Utf8(\"\")", - ) - .await - .unwrap(); - } - #[tokio::test] async fn test_inexact_scalar_index_plans() { let data = gen_batch() diff --git a/rust/lance/src/dataset/tests/dataset_aggregate.rs b/rust/lance/src/dataset/tests/dataset_aggregate.rs index e75595a78ce..516ca1dbace 100644 --- a/rust/lance/src/dataset/tests/dataset_aggregate.rs +++ b/rust/lance/src/dataset/tests/dataset_aggregate.rs @@ -38,7 +38,7 @@ use tempfile::tempdir; use crate::dataset::scanner::AggregateExpr; use crate::index::vector::VectorIndexParams; -use crate::utils::test::{DatagenExt, FragmentCount, FragmentRowCount}; +use crate::utils::test::{assert_plan_node_equals, DatagenExt, FragmentCount, FragmentRowCount}; use crate::Dataset; use lance_arrow::FixedSizeListArrayExt; use lance_index::scalar::inverted::InvertedIndexParams; @@ -1142,3 +1142,192 @@ async fn test_vector_search_with_sum_aggregate() { // Verify we have 2 columns: category and sum_id assert_eq!(results.num_columns(), 2); } + +// ============================================================================ +// Scanner::count_rows() tests +// ============================================================================ + +#[tokio::test] +async fn test_scanner_count_rows() { + let ds = create_numeric_dataset("memory://test_count_rows", 2, 50).await; + + // Check plan structure + let mut scanner = ds.scan(); + scanner.aggregate(AggregateExpr::builder().count_star().build()); + let plan = scanner.create_plan().await.unwrap(); + + assert_plan_node_equals( + plan.clone(), + "AggregateExec: mode=Single, gby=[], aggr=[count(Int32(1))] + LanceRead: uri=..., projection=[x, y, category], num_fragments=2, range_before=None, range_after=None, row_id=false, row_addr=false, full_filter=--, refine_filter=--", + ) + .await + .unwrap(); + + // Execute and verify result + let stream = execute_plan(plan, LanceExecutionOptions::default()).unwrap(); + let batches: Vec = stream.try_collect().await.unwrap(); + assert_eq!(batches.len(), 1); + assert_eq!( + batches[0].column(0).as_primitive::().value(0), + 100 // 2 fragments * 50 rows + ); +} + +#[tokio::test] +async fn test_scanner_count_rows_with_filter() { + let ds = create_numeric_dataset("memory://test_count_rows_filter", 1, 100).await; + + // Check plan structure + let mut scanner = ds.scan(); + scanner.filter("x >= 50").unwrap(); + scanner.aggregate(AggregateExpr::builder().count_star().build()); + let plan = scanner.create_plan().await.unwrap(); + + assert_plan_node_equals( + plan.clone(), + "AggregateExec: mode=Single, gby=[], aggr=[count(Int32(1))] + LanceRead: uri=..., projection=[x, y, category], num_fragments=1, range_before=None, range_after=None, row_id=true, row_addr=false, full_filter=x >= Int64(50), refine_filter=x >= Int64(50)", + ) + .await + .unwrap(); + + // Execute and verify result + let stream = execute_plan(plan, LanceExecutionOptions::default()).unwrap(); + let batches: Vec = stream.try_collect().await.unwrap(); + assert_eq!(batches.len(), 1); + // x ranges from 0 to 99, so x >= 50 matches rows 50..99 (50 rows) + assert_eq!( + batches[0].column(0).as_primitive::().value(0), + 50 + ); +} + +#[tokio::test] +async fn test_scanner_count_rows_empty_result() { + let ds = create_numeric_dataset("memory://test_count_rows_empty", 1, 100).await; + + let mut scanner = ds.scan(); + scanner.filter("x > 1000").unwrap(); // No rows match + let count = scanner.count_rows().await.unwrap(); + + assert_eq!(count, 0); +} + +#[tokio::test] +async fn test_scanner_count_rows_with_vector_search() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + let mut dataset = create_vector_text_dataset(uri, 100).await; + + // Create vector index + let params = VectorIndexParams::ivf_flat(2, MetricType::L2); + dataset + .create_index(&["vec"], IndexType::Vector, None, ¶ms, true) + .await + .unwrap(); + + let query_vector = Float32Array::from(vec![50.0f32, 50.0, 50.0, 50.0]); + + // Check plan structure + let mut scanner = dataset.scan(); + scanner.nearest("vec", &query_vector, 30).unwrap(); + scanner.aggregate(AggregateExpr::builder().count_star().build()); + let plan = scanner.create_plan().await.unwrap(); + + assert_plan_node_equals( + plan.clone(), + "AggregateExec: mode=Single, gby=[], aggr=[count(Int32(1))] + Take: ... + CoalesceBatchesExec: ... + SortExec: TopK(fetch=30), ... + ANNSubIndex: ... + ANNIvfPartition: ...deltas=1", + ) + .await + .unwrap(); + + // Execute and verify result + let stream = execute_plan(plan, LanceExecutionOptions::default()).unwrap(); + let batches: Vec = stream.try_collect().await.unwrap(); + assert_eq!(batches.len(), 1); + assert_eq!( + batches[0].column(0).as_primitive::().value(0), + 30 // top K results + ); +} + +#[tokio::test] +async fn test_scanner_count_rows_with_fts() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + let mut dataset = create_vector_text_dataset(uri, 100).await; + + // Create FTS index on text column + dataset + .create_index( + &["text"], + IndexType::Inverted, + None, + &InvertedIndexParams::default(), + true, + ) + .await + .unwrap(); + + // Check plan structure + let mut scanner = dataset.scan(); + scanner + .full_text_search(FullTextSearchQuery::new("document".to_string())) + .unwrap(); + scanner.aggregate(AggregateExpr::builder().count_star().build()); + let plan = scanner.create_plan().await.unwrap(); + + assert_plan_node_equals( + plan.clone(), + "AggregateExec: mode=Single, gby=[], aggr=[count(Int32(1))] + Take: ... + ... + MatchQuery: column=text, query=document", + ) + .await + .unwrap(); + + // Execute and verify result + let stream = execute_plan(plan, LanceExecutionOptions::default()).unwrap(); + let batches: Vec = stream.try_collect().await.unwrap(); + assert_eq!(batches.len(), 1); + // All 100 documents contain "document" + assert_eq!( + batches[0].column(0).as_primitive::().value(0), + 100 + ); +} + +#[tokio::test] +async fn test_scanner_count_rows_with_vector_search_and_filter() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + let mut dataset = create_vector_text_dataset(uri, 100).await; + + // Create vector index + let params = VectorIndexParams::ivf_flat(2, MetricType::L2); + dataset + .create_index(&["vec"], IndexType::Vector, None, ¶ms, true) + .await + .unwrap(); + + // Vector search for top 50 results, then filter by category + let query_vector = Float32Array::from(vec![50.0f32, 50.0, 50.0, 50.0]); + + let mut scanner = dataset.scan(); + scanner + .nearest("vec", &query_vector, 50) + .unwrap() + .filter("category = 'category_a'") + .unwrap(); + let count = scanner.count_rows().await.unwrap(); + + // Only ~1/3 of the top 50 results should be in category_a + assert!(count > 0 && count <= 50); +} From cb1bfafb5891815a27d8054179fbcdd5f42fb2ea Mon Sep 17 00:00:00 2001 From: Jack Ye Date: Thu, 12 Feb 2026 21:31:53 -0800 Subject: [PATCH 2/8] cleanup --- rust/lance/src/dataset/tests/dataset_aggregate.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/rust/lance/src/dataset/tests/dataset_aggregate.rs b/rust/lance/src/dataset/tests/dataset_aggregate.rs index 516ca1dbace..0b09b3fd5c8 100644 --- a/rust/lance/src/dataset/tests/dataset_aggregate.rs +++ b/rust/lance/src/dataset/tests/dataset_aggregate.rs @@ -1143,10 +1143,6 @@ async fn test_vector_search_with_sum_aggregate() { assert_eq!(results.num_columns(), 2); } -// ============================================================================ -// Scanner::count_rows() tests -// ============================================================================ - #[tokio::test] async fn test_scanner_count_rows() { let ds = create_numeric_dataset("memory://test_count_rows", 2, 50).await; From 7c33453a80085f415c2f6053961d709941e73346 Mon Sep 17 00:00:00 2001 From: Jack Ye Date: Thu, 12 Feb 2026 23:04:55 -0800 Subject: [PATCH 3/8] fix optimization --- rust/lance/src/dataset/scanner.rs | 90 +++++++++++++---- .../src/dataset/tests/dataset_aggregate.rs | 97 +++++++++++++++---- 2 files changed, 146 insertions(+), 41 deletions(-) diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index 901dfa71f9e..26ef42a6c6f 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -504,10 +504,13 @@ impl AggregateExpr { } } - fn to_aggregate( - &self, - #[allow(unused_variables)] schema: Arc, - ) -> Result { + /// Parse into a unified Aggregate structure. + /// + /// For Substrait, this parses the bytes into DataFusion expressions. + /// For DataFusion, this just wraps the expressions. + /// + /// The schema is used to resolve field references in Substrait expressions. + fn parse(self, #[allow(unused_variables)] schema: Arc) -> Result { match self { #[cfg(feature = "substrait")] Self::Substrait(bytes) => { @@ -515,7 +518,7 @@ impl AggregateExpr { use lance_datafusion::substrait::parse_substrait_aggregate; let ctx = get_session_context(&LanceExecutionOptions::default()); - parse_substrait_aggregate(bytes, schema, &ctx.state()) + parse_substrait_aggregate(&bytes, schema, &ctx.state()) .now_or_never() .expect("could not parse the Substrait aggregate in a synchronous fashion") } @@ -523,13 +526,28 @@ impl AggregateExpr { group_by, aggregates, } => Ok(Aggregate { - group_by: group_by.clone(), - aggregates: aggregates.clone(), + group_by, + aggregates, }), } } } +/// Returns the column names required by an aggregate. +/// +/// For COUNT(*) / count(1), this returns an empty set since it doesn't need any columns. +/// For other aggregates like SUM(x), COUNT(x), GROUP BY y, etc., this returns the +/// columns referenced. +fn aggregate_required_columns(agg: &Aggregate) -> Vec { + let mut columns = Vec::new(); + for expr in agg.group_by.iter().chain(agg.aggregates.iter()) { + columns.extend(Planner::column_names_in_expr(expr)); + } + columns.sort(); + columns.dedup(); + columns +} + /// Builder for creating aggregate expressions without using DataFusion or Substrait directly. /// /// The const generic `HAS_PENDING` tracks whether there's a pending aggregate that can be aliased. @@ -786,7 +804,7 @@ pub struct Scanner { /// File reader options to use when reading data files. file_reader_options: Option, - aggregate: Option, + aggregate: Option, // Legacy fields to help migrate some old projection behavior to new behavior // @@ -1237,9 +1255,14 @@ impl Scanner { } /// Set aggregation. - pub fn aggregate(&mut self, aggregate: AggregateExpr) -> &mut Self { - self.aggregate = Some(aggregate); - self + /// + /// The aggregate expression is parsed immediately using the dataset schema. + /// For Substrait aggregates, this converts them to DataFusion expressions. + pub fn aggregate(&mut self, aggregate: AggregateExpr) -> Result<&mut Self> { + let schema: Arc = Arc::new(self.dataset.schema().into()); + let parsed = aggregate.parse(schema)?; + self.aggregate = Some(parsed); + Ok(self) } /// Set the batch size. @@ -1924,7 +1947,7 @@ impl Scanner { } let mut scanner = self.clone(); - scanner.aggregate(AggregateExpr::builder().count_star().build()); + scanner.aggregate(AggregateExpr::builder().count_star().build())?; let plan = scanner.create_plan().await?; let mut stream = execute_plan(plan, LanceExecutionOptions::default())?; @@ -1969,12 +1992,11 @@ impl Scanner { async fn apply_aggregate( &self, plan: Arc, - agg_spec: &AggregateExpr, + agg: &Aggregate, ) -> Result> { use datafusion_physical_expr::aggregate::AggregateFunctionExpr; let schema = plan.schema(); - let agg = agg_spec.to_aggregate(schema.clone())?; let df_schema = DFSchema::try_from(schema.as_ref().clone())?; let group_exprs: Vec<(Arc, String)> = agg @@ -2451,10 +2473,19 @@ impl Scanner { plan = filter_plan.refine_filter(plan, self).await?; // Aggregate (if set, applies aggregate and returns early) - if let Some(agg_spec) = &self.aggregate { - // Take columns needed for aggregation - plan = self.take(plan, self.projection_plan.physical_projection.clone())?; - return self.apply_aggregate(plan, agg_spec).await; + if let Some(agg) = &self.aggregate { + // Take only columns needed by the aggregate, not the full projection. + // For COUNT(*), this is empty. For SUM(x), this is just [x]. + let agg_columns = aggregate_required_columns(agg); + let agg_projection = if agg_columns.is_empty() { + self.dataset.empty_projection() + } else { + self.dataset + .empty_projection() + .union_columns(&agg_columns, OnMissing::Error)? + }; + plan = self.take(plan, agg_projection)?; + return self.apply_aggregate(plan, agg).await; } // Sort @@ -2770,16 +2801,35 @@ impl Scanner { filter_plan: &mut ExprFilterPlan, ) -> Result { log::trace!("source is a filtered read"); + + // Compute the effective projection based on what's actually needed. + // If we have an aggregate, we only need the columns referenced by the aggregate, + // not all the columns from the projection plan. + let effective_projection = if let Some(agg) = &self.aggregate { + let agg_columns = aggregate_required_columns(agg); + if agg_columns.is_empty() { + // COUNT(*) or similar - no columns needed + self.dataset.empty_projection() + } else { + // Aggregate needs specific columns + self.dataset + .empty_projection() + .union_columns(&agg_columns, OnMissing::Error)? + } + } else { + self.projection_plan.physical_projection.clone() + }; + let mut projection = if filter_plan.has_refine() { // If the filter plan has two steps (a scalar indexed portion and a refine portion) then // it makes sense to grab cheap columns during the first step to avoid taking them for // the second step. - self.calc_eager_projection(filter_plan, &self.projection_plan.physical_projection)? + self.calc_eager_projection(filter_plan, &effective_projection)? .with_row_id() } else { // If the filter plan only has one step then we just do a filtered read of all the // columns that the user asked for. - self.projection_plan.physical_projection.clone() + effective_projection }; if projection.is_empty() { diff --git a/rust/lance/src/dataset/tests/dataset_aggregate.rs b/rust/lance/src/dataset/tests/dataset_aggregate.rs index 0b09b3fd5c8..9e43ced0fe5 100644 --- a/rust/lance/src/dataset/tests/dataset_aggregate.rs +++ b/rust/lance/src/dataset/tests/dataset_aggregate.rs @@ -216,7 +216,7 @@ async fn execute_aggregate( aggregate_bytes: &[u8], ) -> crate::Result> { let mut scanner = dataset.scan(); - scanner.aggregate(AggregateExpr::substrait(aggregate_bytes)); + scanner.aggregate(AggregateExpr::substrait(aggregate_bytes))?; let plan = scanner.create_plan().await?; let stream = execute_plan(plan, LanceExecutionOptions::default())?; @@ -231,7 +231,7 @@ async fn execute_aggregate_on_fragments( ) -> crate::Result> { let mut scanner = dataset.scan(); scanner.with_fragments(fragments); - scanner.aggregate(AggregateExpr::substrait(aggregate_bytes)); + scanner.aggregate(AggregateExpr::substrait(aggregate_bytes))?; let plan = scanner.create_plan().await?; let stream = execute_plan(plan, LanceExecutionOptions::default())?; @@ -267,6 +267,20 @@ async fn test_count_star_single_fragment() { vec![], ); + // Verify COUNT(*) has empty projection optimization + let mut scanner = ds.scan(); + scanner + .aggregate(AggregateExpr::substrait(agg_bytes.clone())) + .unwrap(); + let plan = scanner.create_plan().await.unwrap(); + assert_plan_node_equals( + plan, + "AggregateExec: mode=Single, gby=[], aggr=[count(...)] + LanceRead: uri=..., projection=[], num_fragments=1, range_before=None, range_after=None, row_id=false, row_addr=true, full_filter=--, refine_filter=--", + ) + .await + .unwrap(); + let results = execute_aggregate(&ds, &agg_bytes).await.unwrap(); assert_eq!(results.len(), 1); let batch = &results[0]; @@ -343,6 +357,20 @@ async fn test_sum_single_fragment() { vec![], ); + // Verify SUM(x) only reads column x + let mut scanner = ds.scan(); + scanner + .aggregate(AggregateExpr::substrait(agg_bytes.clone())) + .unwrap(); + let plan = scanner.create_plan().await.unwrap(); + assert_plan_node_equals( + plan, + "AggregateExec: mode=Single, gby=[], aggr=[sum(...)] + LanceRead: uri=..., projection=[x], num_fragments=1, range_before=None, range_after=None, row_id=false, row_addr=false, full_filter=--, refine_filter=--", + ) + .await + .unwrap(); + let results = execute_aggregate(&ds, &agg_bytes).await.unwrap(); assert_eq!(results.len(), 1); let batch = &results[0]; @@ -486,6 +514,20 @@ async fn test_group_by_with_count() { vec![], ); + // Verify GROUP BY category only reads category column + let mut scanner = ds.scan(); + scanner + .aggregate(AggregateExpr::substrait(agg_bytes.clone())) + .unwrap(); + let plan = scanner.create_plan().await.unwrap(); + assert_plan_node_equals( + plan, + "AggregateExec: mode=Single, gby=[category@0 as category], aggr=[count(...)] + LanceRead: uri=..., projection=[category], num_fragments=4, range_before=None, range_after=None, row_id=false, row_addr=false, full_filter=--, refine_filter=--", + ) + .await + .unwrap(); + let results = execute_aggregate(&ds, &agg_bytes).await.unwrap(); assert!(!results.is_empty()); @@ -655,7 +697,9 @@ async fn test_aggregate_with_filter() { ], vec![], ); - scanner.aggregate(AggregateExpr::substrait(agg_bytes)); + scanner + .aggregate(AggregateExpr::substrait(agg_bytes)) + .unwrap(); let plan = scanner.create_plan().await.unwrap(); let stream = execute_plan(plan, LanceExecutionOptions::default()).unwrap(); @@ -692,7 +736,9 @@ async fn test_aggregate_empty_result() { vec![agg_extension(1, "count")], vec![], ); - scanner.aggregate(AggregateExpr::substrait(agg_bytes)); + scanner + .aggregate(AggregateExpr::substrait(agg_bytes)) + .unwrap(); let plan = scanner.create_plan().await.unwrap(); let stream = execute_plan(plan, LanceExecutionOptions::default()).unwrap(); @@ -1006,7 +1052,8 @@ async fn test_vector_search_with_aggregate() { .unwrap() .project(&["id", "category"]) .unwrap() - .aggregate(AggregateExpr::substrait(agg_bytes)); + .aggregate(AggregateExpr::substrait(agg_bytes)) + .unwrap(); let results = scanner.try_into_batch().await.unwrap(); @@ -1066,7 +1113,8 @@ async fn test_fts_with_aggregate() { .unwrap() .project(&["id", "category"]) .unwrap() - .aggregate(AggregateExpr::substrait(agg_bytes)); + .aggregate(AggregateExpr::substrait(agg_bytes)) + .unwrap(); let results = scanner.try_into_batch().await.unwrap(); @@ -1128,7 +1176,8 @@ async fn test_vector_search_with_sum_aggregate() { .unwrap() .project(&["id", "category"]) .unwrap() - .aggregate(AggregateExpr::substrait(agg_bytes)); + .aggregate(AggregateExpr::substrait(agg_bytes)) + .unwrap(); let results = scanner.try_into_batch().await.unwrap(); @@ -1149,13 +1198,16 @@ async fn test_scanner_count_rows() { // Check plan structure let mut scanner = ds.scan(); - scanner.aggregate(AggregateExpr::builder().count_star().build()); + scanner + .aggregate(AggregateExpr::builder().count_star().build()) + .unwrap(); let plan = scanner.create_plan().await.unwrap(); + // COUNT(*) should have empty projection (optimized to not read any columns) assert_plan_node_equals( plan.clone(), "AggregateExec: mode=Single, gby=[], aggr=[count(Int32(1))] - LanceRead: uri=..., projection=[x, y, category], num_fragments=2, range_before=None, range_after=None, row_id=false, row_addr=false, full_filter=--, refine_filter=--", + LanceRead: uri=..., projection=[], num_fragments=2, range_before=None, range_after=None, row_id=false, row_addr=true, full_filter=--, refine_filter=--", ) .await .unwrap(); @@ -1177,13 +1229,16 @@ async fn test_scanner_count_rows_with_filter() { // Check plan structure let mut scanner = ds.scan(); scanner.filter("x >= 50").unwrap(); - scanner.aggregate(AggregateExpr::builder().count_star().build()); + scanner + .aggregate(AggregateExpr::builder().count_star().build()) + .unwrap(); let plan = scanner.create_plan().await.unwrap(); + // COUNT(*) with filter: filter columns are needed, but no data columns for the aggregate assert_plan_node_equals( plan.clone(), "AggregateExec: mode=Single, gby=[], aggr=[count(Int32(1))] - LanceRead: uri=..., projection=[x, y, category], num_fragments=1, range_before=None, range_after=None, row_id=true, row_addr=false, full_filter=x >= Int64(50), refine_filter=x >= Int64(50)", + LanceRead: uri=..., projection=[x], num_fragments=1, range_before=None, range_after=None, row_id=true, row_addr=false, full_filter=x >= Int64(50), refine_filter=x >= Int64(50)", ) .await .unwrap(); @@ -1228,17 +1283,17 @@ async fn test_scanner_count_rows_with_vector_search() { // Check plan structure let mut scanner = dataset.scan(); scanner.nearest("vec", &query_vector, 30).unwrap(); - scanner.aggregate(AggregateExpr::builder().count_star().build()); + scanner + .aggregate(AggregateExpr::builder().count_star().build()) + .unwrap(); let plan = scanner.create_plan().await.unwrap(); assert_plan_node_equals( plan.clone(), "AggregateExec: mode=Single, gby=[], aggr=[count(Int32(1))] - Take: ... - CoalesceBatchesExec: ... - SortExec: TopK(fetch=30), ... - ANNSubIndex: ... - ANNIvfPartition: ...deltas=1", + SortExec: TopK(fetch=30), ... + ANNSubIndex: ... + ANNIvfPartition: ...deltas=1", ) .await .unwrap(); @@ -1276,15 +1331,15 @@ async fn test_scanner_count_rows_with_fts() { scanner .full_text_search(FullTextSearchQuery::new("document".to_string())) .unwrap(); - scanner.aggregate(AggregateExpr::builder().count_star().build()); + scanner + .aggregate(AggregateExpr::builder().count_star().build()) + .unwrap(); let plan = scanner.create_plan().await.unwrap(); assert_plan_node_equals( plan.clone(), "AggregateExec: mode=Single, gby=[], aggr=[count(Int32(1))] - Take: ... - ... - MatchQuery: column=text, query=document", + MatchQuery: column=text, query=document", ) .await .unwrap(); From 37c9cba295f97c0f7df30f1949532097a238b5ef Mon Sep 17 00:00:00 2001 From: Jack Ye Date: Thu, 12 Feb 2026 23:07:53 -0800 Subject: [PATCH 4/8] fix test --- python/python/tests/test_dataset.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/python/python/tests/test_dataset.py b/python/python/tests/test_dataset.py index 33fb828a3ab..da426ff77f1 100644 --- a/python/python/tests/test_dataset.py +++ b/python/python/tests/test_dataset.py @@ -1168,19 +1168,10 @@ def test_count_rows_via_scanner(tmp_path: Path): ds = lance.write_dataset(pa.table({"a": range(100), "b": range(100)}), tmp_path) assert ds.scanner(filter="a < 50", columns=[], with_row_id=True).count_rows() == 50 - - with pytest.raises( - ValueError, match="should not be called on a plan selecting columns" - ): - ds.scanner(filter="a < 50", columns=["a"], with_row_id=True).count_rows() - - with pytest.raises( - ValueError, match="should not be called on a plan selecting columns" - ): - ds.scanner(with_row_id=True).count_rows() - - with pytest.raises(ValueError, match="with_row_id is false"): - ds.scanner(columns=[]).count_rows() + assert ds.scanner(filter="a < 50", columns=["a"], with_row_id=True).count_rows() == 50 + assert ds.scanner(with_row_id=True).count_rows() == 100 + assert ds.scanner(columns=[]).count_rows() == 100 + assert ds.scanner().count_rows() == 100 def test_select_none(tmp_path: Path): From fcccfc608435c6707936eecfc973071b636ee2db Mon Sep 17 00:00:00 2001 From: Jack Ye Date: Thu, 12 Feb 2026 23:14:17 -0800 Subject: [PATCH 5/8] fix lint --- python/python/tests/test_dataset.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/python/tests/test_dataset.py b/python/python/tests/test_dataset.py index da426ff77f1..4e0ef9f92c0 100644 --- a/python/python/tests/test_dataset.py +++ b/python/python/tests/test_dataset.py @@ -1168,7 +1168,9 @@ def test_count_rows_via_scanner(tmp_path: Path): ds = lance.write_dataset(pa.table({"a": range(100), "b": range(100)}), tmp_path) assert ds.scanner(filter="a < 50", columns=[], with_row_id=True).count_rows() == 50 - assert ds.scanner(filter="a < 50", columns=["a"], with_row_id=True).count_rows() == 50 + assert ( + ds.scanner(filter="a < 50", columns=["a"], with_row_id=True).count_rows() == 50 + ) assert ds.scanner(with_row_id=True).count_rows() == 100 assert ds.scanner(columns=[]).count_rows() == 100 assert ds.scanner().count_rows() == 100 From 552edaeef7a44fd14722719e0f392f91c5b8e869 Mon Sep 17 00:00:00 2001 From: Jack Ye Date: Thu, 12 Feb 2026 23:28:43 -0800 Subject: [PATCH 6/8] avoid duplicated call to aggregate_required_columns --- rust/lance-datafusion/src/aggregate.rs | 22 +++++++++++++++++++ rust/lance-datafusion/src/substrait.rs | 5 +---- rust/lance/src/dataset/scanner.rs | 30 +++++--------------------- 3 files changed, 28 insertions(+), 29 deletions(-) diff --git a/rust/lance-datafusion/src/aggregate.rs b/rust/lance-datafusion/src/aggregate.rs index 5528104c044..cb692904833 100644 --- a/rust/lance-datafusion/src/aggregate.rs +++ b/rust/lance-datafusion/src/aggregate.rs @@ -5,6 +5,8 @@ use datafusion::logical_expr::Expr; +use crate::planner::Planner; + /// Aggregate specification with group by and aggregate expressions. #[derive(Debug, Clone)] pub struct Aggregate { @@ -13,4 +15,24 @@ pub struct Aggregate { /// Aggregate function expressions (e.g., SUM, COUNT, AVG). /// Use `.alias()` on the expression to set output column names. pub aggregates: Vec, + /// Column names required by this aggregate (computed at construction). + /// For COUNT(*), this is empty. For SUM(x), GROUP BY y, this contains [x, y]. + pub required_columns: Vec, +} + +impl Aggregate { + /// Create a new Aggregate, computing required columns from the expressions. + pub fn new(group_by: Vec, aggregates: Vec) -> Self { + let mut required_columns = Vec::new(); + for expr in group_by.iter().chain(aggregates.iter()) { + required_columns.extend(Planner::column_names_in_expr(expr)); + } + required_columns.sort(); + required_columns.dedup(); + Self { + group_by, + aggregates, + required_columns, + } + } } diff --git a/rust/lance-datafusion/src/substrait.rs b/rust/lance-datafusion/src/substrait.rs index 2f84a266f65..295da39d09c 100644 --- a/rust/lance-datafusion/src/substrait.rs +++ b/rust/lance-datafusion/src/substrait.rs @@ -414,10 +414,7 @@ pub async fn parse_aggregate_rel_with_extensions( let group_by = parse_groupings(aggregate_rel, &df_schema, &consumer).await?; let aggregates = parse_measures(aggregate_rel, &df_schema, &consumer).await?; - Ok(Aggregate { - group_by, - aggregates, - }) + Ok(Aggregate::new(group_by, aggregates)) } /// Parse an AggregateRel proto with default extensions. diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index 26ef42a6c6f..f977057d19b 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -525,29 +525,11 @@ impl AggregateExpr { Self::Datafusion { group_by, aggregates, - } => Ok(Aggregate { - group_by, - aggregates, - }), + } => Ok(Aggregate::new(group_by, aggregates)), } } } -/// Returns the column names required by an aggregate. -/// -/// For COUNT(*) / count(1), this returns an empty set since it doesn't need any columns. -/// For other aggregates like SUM(x), COUNT(x), GROUP BY y, etc., this returns the -/// columns referenced. -fn aggregate_required_columns(agg: &Aggregate) -> Vec { - let mut columns = Vec::new(); - for expr in agg.group_by.iter().chain(agg.aggregates.iter()) { - columns.extend(Planner::column_names_in_expr(expr)); - } - columns.sort(); - columns.dedup(); - columns -} - /// Builder for creating aggregate expressions without using DataFusion or Substrait directly. /// /// The const generic `HAS_PENDING` tracks whether there's a pending aggregate that can be aliased. @@ -2476,13 +2458,12 @@ impl Scanner { if let Some(agg) = &self.aggregate { // Take only columns needed by the aggregate, not the full projection. // For COUNT(*), this is empty. For SUM(x), this is just [x]. - let agg_columns = aggregate_required_columns(agg); - let agg_projection = if agg_columns.is_empty() { + let agg_projection = if agg.required_columns.is_empty() { self.dataset.empty_projection() } else { self.dataset .empty_projection() - .union_columns(&agg_columns, OnMissing::Error)? + .union_columns(&agg.required_columns, OnMissing::Error)? }; plan = self.take(plan, agg_projection)?; return self.apply_aggregate(plan, agg).await; @@ -2806,15 +2787,14 @@ impl Scanner { // If we have an aggregate, we only need the columns referenced by the aggregate, // not all the columns from the projection plan. let effective_projection = if let Some(agg) = &self.aggregate { - let agg_columns = aggregate_required_columns(agg); - if agg_columns.is_empty() { + if agg.required_columns.is_empty() { // COUNT(*) or similar - no columns needed self.dataset.empty_projection() } else { // Aggregate needs specific columns self.dataset .empty_projection() - .union_columns(&agg_columns, OnMissing::Error)? + .union_columns(&agg.required_columns, OnMissing::Error)? } } else { self.projection_plan.physical_projection.clone() From 009df298e595d18e84b9b5b89b2c8d2bbbec690c Mon Sep 17 00:00:00 2001 From: Jack Ye Date: Fri, 13 Feb 2026 13:09:41 -0800 Subject: [PATCH 7/8] address comments --- rust/lance-datafusion/src/aggregate.rs | 23 +++++++++++++---------- rust/lance/src/dataset/scanner.rs | 16 ++++++---------- 2 files changed, 19 insertions(+), 20 deletions(-) diff --git a/rust/lance-datafusion/src/aggregate.rs b/rust/lance-datafusion/src/aggregate.rs index cb692904833..3b4ee96b719 100644 --- a/rust/lance-datafusion/src/aggregate.rs +++ b/rust/lance-datafusion/src/aggregate.rs @@ -15,24 +15,27 @@ pub struct Aggregate { /// Aggregate function expressions (e.g., SUM, COUNT, AVG). /// Use `.alias()` on the expression to set output column names. pub aggregates: Vec, - /// Column names required by this aggregate (computed at construction). - /// For COUNT(*), this is empty. For SUM(x), GROUP BY y, this contains [x, y]. - pub required_columns: Vec, } impl Aggregate { - /// Create a new Aggregate, computing required columns from the expressions. + /// Create a new Aggregate. pub fn new(group_by: Vec, aggregates: Vec) -> Self { + Self { + group_by, + aggregates, + } + } + + /// Compute column names required by this aggregate. + /// + /// For COUNT(*), this returns empty. For SUM(x), GROUP BY y, this returns [x, y]. + pub fn required_columns(&self) -> Vec { let mut required_columns = Vec::new(); - for expr in group_by.iter().chain(aggregates.iter()) { + for expr in self.group_by.iter().chain(self.aggregates.iter()) { required_columns.extend(Planner::column_names_in_expr(expr)); } required_columns.sort(); required_columns.dedup(); - Self { - group_by, - aggregates, - required_columns, - } + required_columns } } diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index f977057d19b..f1d4ee6da17 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -1922,12 +1922,6 @@ impl Scanner { pub fn count_rows(&self) -> BoxFuture<'_, Result> { // Future intentionally boxed here to avoid large futures on the stack async move { - if self.limit.is_some() || self.offset.is_some() { - log::warn!( - "count_rows called with limit or offset which could have surprising results" - ); - } - let mut scanner = self.clone(); scanner.aggregate(AggregateExpr::builder().count_star().build())?; @@ -2458,12 +2452,13 @@ impl Scanner { if let Some(agg) = &self.aggregate { // Take only columns needed by the aggregate, not the full projection. // For COUNT(*), this is empty. For SUM(x), this is just [x]. - let agg_projection = if agg.required_columns.is_empty() { + let required_columns = agg.required_columns(); + let agg_projection = if required_columns.is_empty() { self.dataset.empty_projection() } else { self.dataset .empty_projection() - .union_columns(&agg.required_columns, OnMissing::Error)? + .union_columns(&required_columns, OnMissing::Error)? }; plan = self.take(plan, agg_projection)?; return self.apply_aggregate(plan, agg).await; @@ -2787,14 +2782,15 @@ impl Scanner { // If we have an aggregate, we only need the columns referenced by the aggregate, // not all the columns from the projection plan. let effective_projection = if let Some(agg) = &self.aggregate { - if agg.required_columns.is_empty() { + let required_columns = agg.required_columns(); + if required_columns.is_empty() { // COUNT(*) or similar - no columns needed self.dataset.empty_projection() } else { // Aggregate needs specific columns self.dataset .empty_projection() - .union_columns(&agg.required_columns, OnMissing::Error)? + .union_columns(&required_columns, OnMissing::Error)? } } else { self.projection_plan.physical_projection.clone() From 37be88612355a6bfc1abfd83bcaa65d2d8583a78 Mon Sep 17 00:00:00 2001 From: Jack Ye Date: Fri, 13 Feb 2026 14:31:11 -0800 Subject: [PATCH 8/8] fix clippy --- java/lance-jni/src/blocking_scanner.rs | 2 +- python/src/dataset.rs | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/java/lance-jni/src/blocking_scanner.rs b/java/lance-jni/src/blocking_scanner.rs index b9b87e1319a..262cbcb6489 100644 --- a/java/lance-jni/src/blocking_scanner.rs +++ b/java/lance-jni/src/blocking_scanner.rs @@ -381,7 +381,7 @@ fn inner_create_scanner<'local>( let substrait_aggregate_opt = env.get_bytes_opt(&substrait_aggregate_obj)?; if let Some(substrait_aggregate) = substrait_aggregate_opt { - scanner.aggregate(AggregateExpr::substrait(substrait_aggregate)); + scanner.aggregate(AggregateExpr::substrait(substrait_aggregate))?; } let scanner = BlockingScanner::create(scanner); diff --git a/python/src/dataset.rs b/python/src/dataset.rs index a453e38729a..f180a5dd145 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -1094,7 +1094,9 @@ impl Dataset { .map_err(|err| PyValueError::new_err(err.to_string()))?; } if let Some(aggregate_bytes) = substrait_aggregate { - scanner.aggregate(AggregateExpr::substrait(aggregate_bytes)); + scanner + .aggregate(AggregateExpr::substrait(aggregate_bytes)) + .map_err(|err| PyValueError::new_err(err.to_string()))?; } let scan = Arc::new(scanner); Ok(Scanner::new(scan))