diff --git a/java/lance-jni/src/blocking_scanner.rs b/java/lance-jni/src/blocking_scanner.rs index b9b87e1319a..262cbcb6489 100644 --- a/java/lance-jni/src/blocking_scanner.rs +++ b/java/lance-jni/src/blocking_scanner.rs @@ -381,7 +381,7 @@ fn inner_create_scanner<'local>( let substrait_aggregate_opt = env.get_bytes_opt(&substrait_aggregate_obj)?; if let Some(substrait_aggregate) = substrait_aggregate_opt { - scanner.aggregate(AggregateExpr::substrait(substrait_aggregate)); + scanner.aggregate(AggregateExpr::substrait(substrait_aggregate))?; } let scanner = BlockingScanner::create(scanner); diff --git a/python/python/tests/test_dataset.py b/python/python/tests/test_dataset.py index 33fb828a3ab..4e0ef9f92c0 100644 --- a/python/python/tests/test_dataset.py +++ b/python/python/tests/test_dataset.py @@ -1168,19 +1168,12 @@ def test_count_rows_via_scanner(tmp_path: Path): ds = lance.write_dataset(pa.table({"a": range(100), "b": range(100)}), tmp_path) assert ds.scanner(filter="a < 50", columns=[], with_row_id=True).count_rows() == 50 - - with pytest.raises( - ValueError, match="should not be called on a plan selecting columns" - ): - ds.scanner(filter="a < 50", columns=["a"], with_row_id=True).count_rows() - - with pytest.raises( - ValueError, match="should not be called on a plan selecting columns" - ): - ds.scanner(with_row_id=True).count_rows() - - with pytest.raises(ValueError, match="with_row_id is false"): - ds.scanner(columns=[]).count_rows() + assert ( + ds.scanner(filter="a < 50", columns=["a"], with_row_id=True).count_rows() == 50 + ) + assert ds.scanner(with_row_id=True).count_rows() == 100 + assert ds.scanner(columns=[]).count_rows() == 100 + assert ds.scanner().count_rows() == 100 def test_select_none(tmp_path: Path): diff --git a/python/src/dataset.rs b/python/src/dataset.rs index a453e38729a..f180a5dd145 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -1094,7 +1094,9 @@ impl Dataset { .map_err(|err| PyValueError::new_err(err.to_string()))?; } if let Some(aggregate_bytes) = substrait_aggregate { - scanner.aggregate(AggregateExpr::substrait(aggregate_bytes)); + scanner + .aggregate(AggregateExpr::substrait(aggregate_bytes)) + .map_err(|err| PyValueError::new_err(err.to_string()))?; } let scan = Arc::new(scanner); Ok(Scanner::new(scan)) diff --git a/rust/lance-datafusion/src/aggregate.rs b/rust/lance-datafusion/src/aggregate.rs index 5528104c044..3b4ee96b719 100644 --- a/rust/lance-datafusion/src/aggregate.rs +++ b/rust/lance-datafusion/src/aggregate.rs @@ -5,6 +5,8 @@ use datafusion::logical_expr::Expr; +use crate::planner::Planner; + /// Aggregate specification with group by and aggregate expressions. #[derive(Debug, Clone)] pub struct Aggregate { @@ -14,3 +16,26 @@ pub struct Aggregate { /// Use `.alias()` on the expression to set output column names. pub aggregates: Vec, } + +impl Aggregate { + /// Create a new Aggregate. + pub fn new(group_by: Vec, aggregates: Vec) -> Self { + Self { + group_by, + aggregates, + } + } + + /// Compute column names required by this aggregate. + /// + /// For COUNT(*), this returns empty. For SUM(x), GROUP BY y, this returns [x, y]. + pub fn required_columns(&self) -> Vec { + let mut required_columns = Vec::new(); + for expr in self.group_by.iter().chain(self.aggregates.iter()) { + required_columns.extend(Planner::column_names_in_expr(expr)); + } + required_columns.sort(); + required_columns.dedup(); + required_columns + } +} diff --git a/rust/lance-datafusion/src/substrait.rs b/rust/lance-datafusion/src/substrait.rs index 2f84a266f65..295da39d09c 100644 --- a/rust/lance-datafusion/src/substrait.rs +++ b/rust/lance-datafusion/src/substrait.rs @@ -414,10 +414,7 @@ pub async fn parse_aggregate_rel_with_extensions( let group_by = parse_groupings(aggregate_rel, &df_schema, &consumer).await?; let aggregates = parse_measures(aggregate_rel, &df_schema, &consumer).await?; - Ok(Aggregate { - group_by, - aggregates, - }) + Ok(Aggregate::new(group_by, aggregates)) } /// Parse an AggregateRel proto with default extensions. diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index 8923bc03e99..f1d4ee6da17 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -14,7 +14,6 @@ use async_recursion::async_recursion; use chrono::Utc; use datafusion::common::{exec_datafusion_err, DFSchema, JoinType, NullEquality, SchemaExt}; use datafusion::functions_aggregate; -use datafusion::functions_aggregate::count::count_udaf; use datafusion::logical_expr::{col, lit, Expr, ScalarUDF}; use datafusion::physical_expr::PhysicalSortExpr; use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; @@ -24,7 +23,6 @@ use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::{ aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}, display::DisplayableExecutionPlan, - expressions::Literal, limit::GlobalLimitExec, repartition::RepartitionExec, union::UnionExec, @@ -506,10 +504,13 @@ impl AggregateExpr { } } - fn to_aggregate( - &self, - #[allow(unused_variables)] schema: Arc, - ) -> Result { + /// Parse into a unified Aggregate structure. + /// + /// For Substrait, this parses the bytes into DataFusion expressions. + /// For DataFusion, this just wraps the expressions. + /// + /// The schema is used to resolve field references in Substrait expressions. + fn parse(self, #[allow(unused_variables)] schema: Arc) -> Result { match self { #[cfg(feature = "substrait")] Self::Substrait(bytes) => { @@ -517,17 +518,14 @@ impl AggregateExpr { use lance_datafusion::substrait::parse_substrait_aggregate; let ctx = get_session_context(&LanceExecutionOptions::default()); - parse_substrait_aggregate(bytes, schema, &ctx.state()) + parse_substrait_aggregate(&bytes, schema, &ctx.state()) .now_or_never() .expect("could not parse the Substrait aggregate in a synchronous fashion") } Self::Datafusion { group_by, aggregates, - } => Ok(Aggregate { - group_by: group_by.clone(), - aggregates: aggregates.clone(), - }), + } => Ok(Aggregate::new(group_by, aggregates)), } } } @@ -788,7 +786,7 @@ pub struct Scanner { /// File reader options to use when reading data files. file_reader_options: Option, - aggregate: Option, + aggregate: Option, // Legacy fields to help migrate some old projection behavior to new behavior // @@ -1239,9 +1237,14 @@ impl Scanner { } /// Set aggregation. - pub fn aggregate(&mut self, aggregate: AggregateExpr) -> &mut Self { - self.aggregate = Some(aggregate); - self + /// + /// The aggregate expression is parsed immediately using the dataset schema. + /// For Substrait aggregates, this converts them to DataFusion expressions. + pub fn aggregate(&mut self, aggregate: AggregateExpr) -> Result<&mut Self> { + let schema: Arc = Arc::new(self.dataset.schema().into()); + let parsed = aggregate.parse(schema)?; + self.aggregate = Some(parsed); + Ok(self) } /// Set the batch size. @@ -1911,62 +1914,6 @@ impl Scanner { Ok(concat_batches(&schema, &batches)?) } - pub fn create_count_plan(&self) -> BoxFuture<'_, Result>> { - // Future intentionally boxed here to avoid large futures on the stack - async move { - if self.projection_plan.physical_projection.is_empty() { - return Err(Error::invalid_input("count_rows called but with_row_id is false".to_string(), location!())); - } - if !self.projection_plan.physical_projection.is_metadata_only() { - let physical_schema = self.projection_plan.physical_projection.to_schema(); - let columns: Vec<&str> = physical_schema.fields - .iter() - .map(|field| field.name.as_str()) - .collect(); - - let msg = format!( - "count_rows should not be called on a plan selecting columns. selected columns: [{}]", - columns.join(", ") - ); - - return Err(Error::invalid_input(msg, location!())); - } - - if self.limit.is_some() || self.offset.is_some() { - log::warn!( - "count_rows called with limit or offset which could have surprising results" - ); - } - - let plan = self.create_plan().await?; - // Datafusion interprets COUNT(*) as COUNT(1) - let one = Arc::new(Literal::new(ScalarValue::UInt8(Some(1)))); - - let input_phy_exprs: &[Arc] = &[one]; - let schema = plan.schema(); - - let mut builder = datafusion_physical_expr::aggregate::AggregateExprBuilder::new( - count_udaf(), - input_phy_exprs.to_vec(), - ); - builder = builder.schema(schema); - builder = builder.alias("count_rows".to_string()); - - let count_expr = builder.build()?; - - let plan_schema = plan.schema(); - Ok(Arc::new(AggregateExec::try_new( - AggregateMode::Single, - PhysicalGroupBy::new_single(Vec::new()), - vec![Arc::new(count_expr)], - vec![None], - plan, - plan_schema, - )?) as Arc) - } - .boxed() - } - /// Scan and return the number of matching rows /// /// Note: calling [`Dataset::count_rows`] can be more efficient than calling this method @@ -1975,8 +1922,11 @@ impl Scanner { pub fn count_rows(&self) -> BoxFuture<'_, Result> { // Future intentionally boxed here to avoid large futures on the stack async move { - let count_plan = self.create_count_plan().await?; - let mut stream = execute_plan(count_plan, LanceExecutionOptions::default())?; + let mut scanner = self.clone(); + scanner.aggregate(AggregateExpr::builder().count_star().build())?; + + let plan = scanner.create_plan().await?; + let mut stream = execute_plan(plan, LanceExecutionOptions::default())?; // A count plan will always return a single batch with a single row. if let Some(first_batch) = stream.next().await { @@ -1986,7 +1936,7 @@ impl Scanner { .as_any() .downcast_ref::() .ok_or(Error::invalid_input( - "Count plan did not return a UInt64Array".to_string(), + "Count plan did not return an Int64Array".to_string(), location!(), ))?; Ok(array.value(0) as u64) @@ -2018,12 +1968,11 @@ impl Scanner { async fn apply_aggregate( &self, plan: Arc, - agg_spec: &AggregateExpr, + agg: &Aggregate, ) -> Result> { use datafusion_physical_expr::aggregate::AggregateFunctionExpr; let schema = plan.schema(); - let agg = agg_spec.to_aggregate(schema.clone())?; let df_schema = DFSchema::try_from(schema.as_ref().clone())?; let group_exprs: Vec<(Arc, String)> = agg @@ -2500,10 +2449,19 @@ impl Scanner { plan = filter_plan.refine_filter(plan, self).await?; // Aggregate (if set, applies aggregate and returns early) - if let Some(agg_spec) = &self.aggregate { - // Take columns needed for aggregation - plan = self.take(plan, self.projection_plan.physical_projection.clone())?; - return self.apply_aggregate(plan, agg_spec).await; + if let Some(agg) = &self.aggregate { + // Take only columns needed by the aggregate, not the full projection. + // For COUNT(*), this is empty. For SUM(x), this is just [x]. + let required_columns = agg.required_columns(); + let agg_projection = if required_columns.is_empty() { + self.dataset.empty_projection() + } else { + self.dataset + .empty_projection() + .union_columns(&required_columns, OnMissing::Error)? + }; + plan = self.take(plan, agg_projection)?; + return self.apply_aggregate(plan, agg).await; } // Sort @@ -2819,16 +2777,35 @@ impl Scanner { filter_plan: &mut ExprFilterPlan, ) -> Result { log::trace!("source is a filtered read"); + + // Compute the effective projection based on what's actually needed. + // If we have an aggregate, we only need the columns referenced by the aggregate, + // not all the columns from the projection plan. + let effective_projection = if let Some(agg) = &self.aggregate { + let required_columns = agg.required_columns(); + if required_columns.is_empty() { + // COUNT(*) or similar - no columns needed + self.dataset.empty_projection() + } else { + // Aggregate needs specific columns + self.dataset + .empty_projection() + .union_columns(&required_columns, OnMissing::Error)? + } + } else { + self.projection_plan.physical_projection.clone() + }; + let mut projection = if filter_plan.has_refine() { // If the filter plan has two steps (a scalar indexed portion and a refine portion) then // it makes sense to grab cheap columns during the first step to avoid taking them for // the second step. - self.calc_eager_projection(filter_plan, &self.projection_plan.physical_projection)? + self.calc_eager_projection(filter_plan, &effective_projection)? .with_row_id() } else { // If the filter plan only has one step then we just do a filtered read of all the // columns that the user asked for. - self.projection_plan.physical_projection.clone() + effective_projection }; if projection.is_empty() { @@ -7373,56 +7350,6 @@ mod test { assert_plan_node_equals(exec_plan, expected).await } - #[tokio::test] - async fn test_count_plan() { - // A count rows operation should load the minimal amount of data - let dim = 256; - let fixture = TestVectorDataset::new_with_dimension(LanceFileVersion::Stable, true, dim) - .await - .unwrap(); - - // By default, all columns are returned, this is bad for a count_rows op - let err = fixture - .dataset - .scan() - .create_count_plan() - .await - .unwrap_err(); - assert!(matches!(err, Error::InvalidInput { .. })); - - let mut scan = fixture.dataset.scan(); - scan.project(&Vec::::default()).unwrap(); - - // with_row_id needs to be specified - let err = scan.create_count_plan().await.unwrap_err(); - assert!(matches!(err, Error::InvalidInput { .. })); - - scan.with_row_id(); - - let plan = scan.create_count_plan().await.unwrap(); - - assert_plan_node_equals( - plan, - "AggregateExec: mode=Single, gby=[], aggr=[count_rows] - LanceRead: uri=..., projection=[], num_fragments=2, range_before=None, range_after=None, row_id=true, row_addr=false, full_filter=--, refine_filter=--", - ) - .await - .unwrap(); - - scan.filter("s == ''").unwrap(); - - let plan = scan.create_count_plan().await.unwrap(); - - assert_plan_node_equals( - plan, - "AggregateExec: mode=Single, gby=[], aggr=[count_rows] - ProjectionExec: expr=[_rowid@1 as _rowid] - LanceRead: uri=..., projection=[s], num_fragments=2, range_before=None, range_after=None, row_id=true, row_addr=false, full_filter=s = Utf8(\"\"), refine_filter=s = Utf8(\"\")", - ) - .await - .unwrap(); - } - #[tokio::test] async fn test_inexact_scalar_index_plans() { let data = gen_batch() diff --git a/rust/lance/src/dataset/tests/dataset_aggregate.rs b/rust/lance/src/dataset/tests/dataset_aggregate.rs index e75595a78ce..9e43ced0fe5 100644 --- a/rust/lance/src/dataset/tests/dataset_aggregate.rs +++ b/rust/lance/src/dataset/tests/dataset_aggregate.rs @@ -38,7 +38,7 @@ use tempfile::tempdir; use crate::dataset::scanner::AggregateExpr; use crate::index::vector::VectorIndexParams; -use crate::utils::test::{DatagenExt, FragmentCount, FragmentRowCount}; +use crate::utils::test::{assert_plan_node_equals, DatagenExt, FragmentCount, FragmentRowCount}; use crate::Dataset; use lance_arrow::FixedSizeListArrayExt; use lance_index::scalar::inverted::InvertedIndexParams; @@ -216,7 +216,7 @@ async fn execute_aggregate( aggregate_bytes: &[u8], ) -> crate::Result> { let mut scanner = dataset.scan(); - scanner.aggregate(AggregateExpr::substrait(aggregate_bytes)); + scanner.aggregate(AggregateExpr::substrait(aggregate_bytes))?; let plan = scanner.create_plan().await?; let stream = execute_plan(plan, LanceExecutionOptions::default())?; @@ -231,7 +231,7 @@ async fn execute_aggregate_on_fragments( ) -> crate::Result> { let mut scanner = dataset.scan(); scanner.with_fragments(fragments); - scanner.aggregate(AggregateExpr::substrait(aggregate_bytes)); + scanner.aggregate(AggregateExpr::substrait(aggregate_bytes))?; let plan = scanner.create_plan().await?; let stream = execute_plan(plan, LanceExecutionOptions::default())?; @@ -267,6 +267,20 @@ async fn test_count_star_single_fragment() { vec![], ); + // Verify COUNT(*) has empty projection optimization + let mut scanner = ds.scan(); + scanner + .aggregate(AggregateExpr::substrait(agg_bytes.clone())) + .unwrap(); + let plan = scanner.create_plan().await.unwrap(); + assert_plan_node_equals( + plan, + "AggregateExec: mode=Single, gby=[], aggr=[count(...)] + LanceRead: uri=..., projection=[], num_fragments=1, range_before=None, range_after=None, row_id=false, row_addr=true, full_filter=--, refine_filter=--", + ) + .await + .unwrap(); + let results = execute_aggregate(&ds, &agg_bytes).await.unwrap(); assert_eq!(results.len(), 1); let batch = &results[0]; @@ -343,6 +357,20 @@ async fn test_sum_single_fragment() { vec![], ); + // Verify SUM(x) only reads column x + let mut scanner = ds.scan(); + scanner + .aggregate(AggregateExpr::substrait(agg_bytes.clone())) + .unwrap(); + let plan = scanner.create_plan().await.unwrap(); + assert_plan_node_equals( + plan, + "AggregateExec: mode=Single, gby=[], aggr=[sum(...)] + LanceRead: uri=..., projection=[x], num_fragments=1, range_before=None, range_after=None, row_id=false, row_addr=false, full_filter=--, refine_filter=--", + ) + .await + .unwrap(); + let results = execute_aggregate(&ds, &agg_bytes).await.unwrap(); assert_eq!(results.len(), 1); let batch = &results[0]; @@ -486,6 +514,20 @@ async fn test_group_by_with_count() { vec![], ); + // Verify GROUP BY category only reads category column + let mut scanner = ds.scan(); + scanner + .aggregate(AggregateExpr::substrait(agg_bytes.clone())) + .unwrap(); + let plan = scanner.create_plan().await.unwrap(); + assert_plan_node_equals( + plan, + "AggregateExec: mode=Single, gby=[category@0 as category], aggr=[count(...)] + LanceRead: uri=..., projection=[category], num_fragments=4, range_before=None, range_after=None, row_id=false, row_addr=false, full_filter=--, refine_filter=--", + ) + .await + .unwrap(); + let results = execute_aggregate(&ds, &agg_bytes).await.unwrap(); assert!(!results.is_empty()); @@ -655,7 +697,9 @@ async fn test_aggregate_with_filter() { ], vec![], ); - scanner.aggregate(AggregateExpr::substrait(agg_bytes)); + scanner + .aggregate(AggregateExpr::substrait(agg_bytes)) + .unwrap(); let plan = scanner.create_plan().await.unwrap(); let stream = execute_plan(plan, LanceExecutionOptions::default()).unwrap(); @@ -692,7 +736,9 @@ async fn test_aggregate_empty_result() { vec![agg_extension(1, "count")], vec![], ); - scanner.aggregate(AggregateExpr::substrait(agg_bytes)); + scanner + .aggregate(AggregateExpr::substrait(agg_bytes)) + .unwrap(); let plan = scanner.create_plan().await.unwrap(); let stream = execute_plan(plan, LanceExecutionOptions::default()).unwrap(); @@ -1006,7 +1052,8 @@ async fn test_vector_search_with_aggregate() { .unwrap() .project(&["id", "category"]) .unwrap() - .aggregate(AggregateExpr::substrait(agg_bytes)); + .aggregate(AggregateExpr::substrait(agg_bytes)) + .unwrap(); let results = scanner.try_into_batch().await.unwrap(); @@ -1066,7 +1113,8 @@ async fn test_fts_with_aggregate() { .unwrap() .project(&["id", "category"]) .unwrap() - .aggregate(AggregateExpr::substrait(agg_bytes)); + .aggregate(AggregateExpr::substrait(agg_bytes)) + .unwrap(); let results = scanner.try_into_batch().await.unwrap(); @@ -1128,7 +1176,8 @@ async fn test_vector_search_with_sum_aggregate() { .unwrap() .project(&["id", "category"]) .unwrap() - .aggregate(AggregateExpr::substrait(agg_bytes)); + .aggregate(AggregateExpr::substrait(agg_bytes)) + .unwrap(); let results = scanner.try_into_batch().await.unwrap(); @@ -1142,3 +1191,194 @@ async fn test_vector_search_with_sum_aggregate() { // Verify we have 2 columns: category and sum_id assert_eq!(results.num_columns(), 2); } + +#[tokio::test] +async fn test_scanner_count_rows() { + let ds = create_numeric_dataset("memory://test_count_rows", 2, 50).await; + + // Check plan structure + let mut scanner = ds.scan(); + scanner + .aggregate(AggregateExpr::builder().count_star().build()) + .unwrap(); + let plan = scanner.create_plan().await.unwrap(); + + // COUNT(*) should have empty projection (optimized to not read any columns) + assert_plan_node_equals( + plan.clone(), + "AggregateExec: mode=Single, gby=[], aggr=[count(Int32(1))] + LanceRead: uri=..., projection=[], num_fragments=2, range_before=None, range_after=None, row_id=false, row_addr=true, full_filter=--, refine_filter=--", + ) + .await + .unwrap(); + + // Execute and verify result + let stream = execute_plan(plan, LanceExecutionOptions::default()).unwrap(); + let batches: Vec = stream.try_collect().await.unwrap(); + assert_eq!(batches.len(), 1); + assert_eq!( + batches[0].column(0).as_primitive::().value(0), + 100 // 2 fragments * 50 rows + ); +} + +#[tokio::test] +async fn test_scanner_count_rows_with_filter() { + let ds = create_numeric_dataset("memory://test_count_rows_filter", 1, 100).await; + + // Check plan structure + let mut scanner = ds.scan(); + scanner.filter("x >= 50").unwrap(); + scanner + .aggregate(AggregateExpr::builder().count_star().build()) + .unwrap(); + let plan = scanner.create_plan().await.unwrap(); + + // COUNT(*) with filter: filter columns are needed, but no data columns for the aggregate + assert_plan_node_equals( + plan.clone(), + "AggregateExec: mode=Single, gby=[], aggr=[count(Int32(1))] + LanceRead: uri=..., projection=[x], num_fragments=1, range_before=None, range_after=None, row_id=true, row_addr=false, full_filter=x >= Int64(50), refine_filter=x >= Int64(50)", + ) + .await + .unwrap(); + + // Execute and verify result + let stream = execute_plan(plan, LanceExecutionOptions::default()).unwrap(); + let batches: Vec = stream.try_collect().await.unwrap(); + assert_eq!(batches.len(), 1); + // x ranges from 0 to 99, so x >= 50 matches rows 50..99 (50 rows) + assert_eq!( + batches[0].column(0).as_primitive::().value(0), + 50 + ); +} + +#[tokio::test] +async fn test_scanner_count_rows_empty_result() { + let ds = create_numeric_dataset("memory://test_count_rows_empty", 1, 100).await; + + let mut scanner = ds.scan(); + scanner.filter("x > 1000").unwrap(); // No rows match + let count = scanner.count_rows().await.unwrap(); + + assert_eq!(count, 0); +} + +#[tokio::test] +async fn test_scanner_count_rows_with_vector_search() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + let mut dataset = create_vector_text_dataset(uri, 100).await; + + // Create vector index + let params = VectorIndexParams::ivf_flat(2, MetricType::L2); + dataset + .create_index(&["vec"], IndexType::Vector, None, ¶ms, true) + .await + .unwrap(); + + let query_vector = Float32Array::from(vec![50.0f32, 50.0, 50.0, 50.0]); + + // Check plan structure + let mut scanner = dataset.scan(); + scanner.nearest("vec", &query_vector, 30).unwrap(); + scanner + .aggregate(AggregateExpr::builder().count_star().build()) + .unwrap(); + let plan = scanner.create_plan().await.unwrap(); + + assert_plan_node_equals( + plan.clone(), + "AggregateExec: mode=Single, gby=[], aggr=[count(Int32(1))] + SortExec: TopK(fetch=30), ... + ANNSubIndex: ... + ANNIvfPartition: ...deltas=1", + ) + .await + .unwrap(); + + // Execute and verify result + let stream = execute_plan(plan, LanceExecutionOptions::default()).unwrap(); + let batches: Vec = stream.try_collect().await.unwrap(); + assert_eq!(batches.len(), 1); + assert_eq!( + batches[0].column(0).as_primitive::().value(0), + 30 // top K results + ); +} + +#[tokio::test] +async fn test_scanner_count_rows_with_fts() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + let mut dataset = create_vector_text_dataset(uri, 100).await; + + // Create FTS index on text column + dataset + .create_index( + &["text"], + IndexType::Inverted, + None, + &InvertedIndexParams::default(), + true, + ) + .await + .unwrap(); + + // Check plan structure + let mut scanner = dataset.scan(); + scanner + .full_text_search(FullTextSearchQuery::new("document".to_string())) + .unwrap(); + scanner + .aggregate(AggregateExpr::builder().count_star().build()) + .unwrap(); + let plan = scanner.create_plan().await.unwrap(); + + assert_plan_node_equals( + plan.clone(), + "AggregateExec: mode=Single, gby=[], aggr=[count(Int32(1))] + MatchQuery: column=text, query=document", + ) + .await + .unwrap(); + + // Execute and verify result + let stream = execute_plan(plan, LanceExecutionOptions::default()).unwrap(); + let batches: Vec = stream.try_collect().await.unwrap(); + assert_eq!(batches.len(), 1); + // All 100 documents contain "document" + assert_eq!( + batches[0].column(0).as_primitive::().value(0), + 100 + ); +} + +#[tokio::test] +async fn test_scanner_count_rows_with_vector_search_and_filter() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + let mut dataset = create_vector_text_dataset(uri, 100).await; + + // Create vector index + let params = VectorIndexParams::ivf_flat(2, MetricType::L2); + dataset + .create_index(&["vec"], IndexType::Vector, None, ¶ms, true) + .await + .unwrap(); + + // Vector search for top 50 results, then filter by category + let query_vector = Float32Array::from(vec![50.0f32, 50.0, 50.0, 50.0]); + + let mut scanner = dataset.scan(); + scanner + .nearest("vec", &query_vector, 50) + .unwrap() + .filter("category = 'category_a'") + .unwrap(); + let count = scanner.count_rows().await.unwrap(); + + // Only ~1/3 of the top 50 results should be in category_a + assert!(count > 0 && count <= 50); +}