-
Notifications
You must be signed in to change notification settings - Fork 647
fix: remove unnecessary column projection for count aggregate #5950
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
jackye1995
merged 8 commits into
lance-format:main
from
jackye1995:deprecate-count-rows
Feb 14, 2026
Merged
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
a49a091
refactor: use scanner aggregate for count_rows
jackye1995 cb1bfaf
cleanup
jackye1995 7c33453
fix optimization
jackye1995 37c9cba
fix test
jackye1995 fcccfc6
fix lint
jackye1995 552edae
avoid duplicated call to aggregate_required_columns
jackye1995 009df29
address comments
jackye1995 37be886
fix clippy
jackye1995 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,7 +14,6 @@ use async_recursion::async_recursion; | |
| use chrono::Utc; | ||
| use datafusion::common::{exec_datafusion_err, DFSchema, JoinType, NullEquality, SchemaExt}; | ||
| use datafusion::functions_aggregate; | ||
| use datafusion::functions_aggregate::count::count_udaf; | ||
| use datafusion::logical_expr::{col, lit, Expr, ScalarUDF}; | ||
| use datafusion::physical_expr::PhysicalSortExpr; | ||
| use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; | ||
|
|
@@ -24,7 +23,6 @@ use datafusion::physical_plan::sorts::sort::SortExec; | |
| use datafusion::physical_plan::{ | ||
| aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}, | ||
| display::DisplayableExecutionPlan, | ||
| expressions::Literal, | ||
| limit::GlobalLimitExec, | ||
| repartition::RepartitionExec, | ||
| union::UnionExec, | ||
|
|
@@ -506,28 +504,28 @@ impl AggregateExpr { | |
| } | ||
| } | ||
|
|
||
| fn to_aggregate( | ||
| &self, | ||
| #[allow(unused_variables)] schema: Arc<ArrowSchema>, | ||
| ) -> Result<Aggregate> { | ||
| /// Parse into a unified Aggregate structure. | ||
| /// | ||
| /// For Substrait, this parses the bytes into DataFusion expressions. | ||
| /// For DataFusion, this just wraps the expressions. | ||
| /// | ||
| /// The schema is used to resolve field references in Substrait expressions. | ||
| fn parse(self, #[allow(unused_variables)] schema: Arc<ArrowSchema>) -> Result<Aggregate> { | ||
| match self { | ||
| #[cfg(feature = "substrait")] | ||
| Self::Substrait(bytes) => { | ||
| use lance_datafusion::exec::{get_session_context, LanceExecutionOptions}; | ||
| use lance_datafusion::substrait::parse_substrait_aggregate; | ||
|
|
||
| let ctx = get_session_context(&LanceExecutionOptions::default()); | ||
| parse_substrait_aggregate(bytes, schema, &ctx.state()) | ||
| parse_substrait_aggregate(&bytes, schema, &ctx.state()) | ||
| .now_or_never() | ||
| .expect("could not parse the Substrait aggregate in a synchronous fashion") | ||
| } | ||
| Self::Datafusion { | ||
| group_by, | ||
| aggregates, | ||
| } => Ok(Aggregate { | ||
| group_by: group_by.clone(), | ||
| aggregates: aggregates.clone(), | ||
| }), | ||
| } => Ok(Aggregate::new(group_by, aggregates)), | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -788,7 +786,7 @@ pub struct Scanner { | |
| /// File reader options to use when reading data files. | ||
| file_reader_options: Option<FileReaderOptions>, | ||
|
|
||
| aggregate: Option<AggregateExpr>, | ||
| aggregate: Option<Aggregate>, | ||
|
|
||
| // Legacy fields to help migrate some old projection behavior to new behavior | ||
| // | ||
|
|
@@ -1239,9 +1237,14 @@ impl Scanner { | |
| } | ||
|
|
||
| /// Set aggregation. | ||
| pub fn aggregate(&mut self, aggregate: AggregateExpr) -> &mut Self { | ||
| self.aggregate = Some(aggregate); | ||
| self | ||
| /// | ||
| /// The aggregate expression is parsed immediately using the dataset schema. | ||
| /// For Substrait aggregates, this converts them to DataFusion expressions. | ||
| pub fn aggregate(&mut self, aggregate: AggregateExpr) -> Result<&mut Self> { | ||
| let schema: Arc<ArrowSchema> = Arc::new(self.dataset.schema().into()); | ||
| let parsed = aggregate.parse(schema)?; | ||
| self.aggregate = Some(parsed); | ||
| Ok(self) | ||
| } | ||
|
|
||
| /// Set the batch size. | ||
|
|
@@ -1911,62 +1914,6 @@ impl Scanner { | |
| Ok(concat_batches(&schema, &batches)?) | ||
| } | ||
|
|
||
| pub fn create_count_plan(&self) -> BoxFuture<'_, Result<Arc<dyn ExecutionPlan>>> { | ||
| // Future intentionally boxed here to avoid large futures on the stack | ||
| async move { | ||
| if self.projection_plan.physical_projection.is_empty() { | ||
| return Err(Error::invalid_input("count_rows called but with_row_id is false".to_string(), location!())); | ||
| } | ||
| if !self.projection_plan.physical_projection.is_metadata_only() { | ||
| let physical_schema = self.projection_plan.physical_projection.to_schema(); | ||
| let columns: Vec<&str> = physical_schema.fields | ||
| .iter() | ||
| .map(|field| field.name.as_str()) | ||
| .collect(); | ||
|
|
||
| let msg = format!( | ||
| "count_rows should not be called on a plan selecting columns. selected columns: [{}]", | ||
| columns.join(", ") | ||
| ); | ||
|
|
||
| return Err(Error::invalid_input(msg, location!())); | ||
| } | ||
|
|
||
| if self.limit.is_some() || self.offset.is_some() { | ||
| log::warn!( | ||
| "count_rows called with limit or offset which could have surprising results" | ||
| ); | ||
| } | ||
|
|
||
| let plan = self.create_plan().await?; | ||
| // Datafusion interprets COUNT(*) as COUNT(1) | ||
| let one = Arc::new(Literal::new(ScalarValue::UInt8(Some(1)))); | ||
|
|
||
| let input_phy_exprs: &[Arc<dyn PhysicalExpr>] = &[one]; | ||
| let schema = plan.schema(); | ||
|
|
||
| let mut builder = datafusion_physical_expr::aggregate::AggregateExprBuilder::new( | ||
| count_udaf(), | ||
| input_phy_exprs.to_vec(), | ||
| ); | ||
| builder = builder.schema(schema); | ||
| builder = builder.alias("count_rows".to_string()); | ||
|
|
||
| let count_expr = builder.build()?; | ||
|
|
||
| let plan_schema = plan.schema(); | ||
| Ok(Arc::new(AggregateExec::try_new( | ||
| AggregateMode::Single, | ||
| PhysicalGroupBy::new_single(Vec::new()), | ||
| vec![Arc::new(count_expr)], | ||
| vec![None], | ||
| plan, | ||
| plan_schema, | ||
| )?) as Arc<dyn ExecutionPlan>) | ||
| } | ||
| .boxed() | ||
| } | ||
|
|
||
| /// Scan and return the number of matching rows | ||
| /// | ||
| /// Note: calling [`Dataset::count_rows`] can be more efficient than calling this method | ||
|
|
@@ -1975,8 +1922,11 @@ impl Scanner { | |
| pub fn count_rows(&self) -> BoxFuture<'_, Result<u64>> { | ||
| // Future intentionally boxed here to avoid large futures on the stack | ||
| async move { | ||
| let count_plan = self.create_count_plan().await?; | ||
| let mut stream = execute_plan(count_plan, LanceExecutionOptions::default())?; | ||
| let mut scanner = self.clone(); | ||
| scanner.aggregate(AggregateExpr::builder().count_star().build())?; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice cleanup! |
||
|
|
||
| let plan = scanner.create_plan().await?; | ||
| let mut stream = execute_plan(plan, LanceExecutionOptions::default())?; | ||
|
|
||
| // A count plan will always return a single batch with a single row. | ||
| if let Some(first_batch) = stream.next().await { | ||
|
|
@@ -1986,7 +1936,7 @@ impl Scanner { | |
| .as_any() | ||
| .downcast_ref::<Int64Array>() | ||
| .ok_or(Error::invalid_input( | ||
| "Count plan did not return a UInt64Array".to_string(), | ||
| "Count plan did not return an Int64Array".to_string(), | ||
| location!(), | ||
| ))?; | ||
| Ok(array.value(0) as u64) | ||
|
|
@@ -2018,12 +1968,11 @@ impl Scanner { | |
| async fn apply_aggregate( | ||
| &self, | ||
| plan: Arc<dyn ExecutionPlan>, | ||
| agg_spec: &AggregateExpr, | ||
| agg: &Aggregate, | ||
| ) -> Result<Arc<dyn ExecutionPlan>> { | ||
| use datafusion_physical_expr::aggregate::AggregateFunctionExpr; | ||
|
|
||
| let schema = plan.schema(); | ||
| let agg = agg_spec.to_aggregate(schema.clone())?; | ||
| let df_schema = DFSchema::try_from(schema.as_ref().clone())?; | ||
|
|
||
| let group_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> = agg | ||
|
|
@@ -2500,10 +2449,19 @@ impl Scanner { | |
| plan = filter_plan.refine_filter(plan, self).await?; | ||
|
|
||
| // Aggregate (if set, applies aggregate and returns early) | ||
| if let Some(agg_spec) = &self.aggregate { | ||
| // Take columns needed for aggregation | ||
| plan = self.take(plan, self.projection_plan.physical_projection.clone())?; | ||
| return self.apply_aggregate(plan, agg_spec).await; | ||
| if let Some(agg) = &self.aggregate { | ||
| // Take only columns needed by the aggregate, not the full projection. | ||
| // For COUNT(*), this is empty. For SUM(x), this is just [x]. | ||
| let required_columns = agg.required_columns(); | ||
| let agg_projection = if required_columns.is_empty() { | ||
| self.dataset.empty_projection() | ||
| } else { | ||
| self.dataset | ||
| .empty_projection() | ||
| .union_columns(&required_columns, OnMissing::Error)? | ||
| }; | ||
| plan = self.take(plan, agg_projection)?; | ||
| return self.apply_aggregate(plan, agg).await; | ||
| } | ||
|
|
||
| // Sort | ||
|
|
@@ -2819,16 +2777,35 @@ impl Scanner { | |
| filter_plan: &mut ExprFilterPlan, | ||
| ) -> Result<PlannedFilteredScan> { | ||
| log::trace!("source is a filtered read"); | ||
|
|
||
| // Compute the effective projection based on what's actually needed. | ||
| // If we have an aggregate, we only need the columns referenced by the aggregate, | ||
| // not all the columns from the projection plan. | ||
| let effective_projection = if let Some(agg) = &self.aggregate { | ||
| let required_columns = agg.required_columns(); | ||
| if required_columns.is_empty() { | ||
| // COUNT(*) or similar - no columns needed | ||
| self.dataset.empty_projection() | ||
| } else { | ||
| // Aggregate needs specific columns | ||
| self.dataset | ||
| .empty_projection() | ||
| .union_columns(&required_columns, OnMissing::Error)? | ||
| } | ||
| } else { | ||
| self.projection_plan.physical_projection.clone() | ||
| }; | ||
|
|
||
| let mut projection = if filter_plan.has_refine() { | ||
| // If the filter plan has two steps (a scalar indexed portion and a refine portion) then | ||
| // it makes sense to grab cheap columns during the first step to avoid taking them for | ||
| // the second step. | ||
| self.calc_eager_projection(filter_plan, &self.projection_plan.physical_projection)? | ||
| self.calc_eager_projection(filter_plan, &effective_projection)? | ||
| .with_row_id() | ||
| } else { | ||
| // If the filter plan only has one step then we just do a filtered read of all the | ||
| // columns that the user asked for. | ||
| self.projection_plan.physical_projection.clone() | ||
| effective_projection | ||
| }; | ||
|
|
||
| if projection.is_empty() { | ||
|
|
@@ -7373,56 +7350,6 @@ mod test { | |
| assert_plan_node_equals(exec_plan, expected).await | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn test_count_plan() { | ||
| // A count rows operation should load the minimal amount of data | ||
| let dim = 256; | ||
| let fixture = TestVectorDataset::new_with_dimension(LanceFileVersion::Stable, true, dim) | ||
| .await | ||
| .unwrap(); | ||
|
|
||
| // By default, all columns are returned, this is bad for a count_rows op | ||
| let err = fixture | ||
| .dataset | ||
| .scan() | ||
| .create_count_plan() | ||
| .await | ||
| .unwrap_err(); | ||
| assert!(matches!(err, Error::InvalidInput { .. })); | ||
|
|
||
| let mut scan = fixture.dataset.scan(); | ||
| scan.project(&Vec::<String>::default()).unwrap(); | ||
|
|
||
| // with_row_id needs to be specified | ||
| let err = scan.create_count_plan().await.unwrap_err(); | ||
| assert!(matches!(err, Error::InvalidInput { .. })); | ||
|
|
||
| scan.with_row_id(); | ||
|
|
||
| let plan = scan.create_count_plan().await.unwrap(); | ||
|
|
||
| assert_plan_node_equals( | ||
| plan, | ||
| "AggregateExec: mode=Single, gby=[], aggr=[count_rows] | ||
| LanceRead: uri=..., projection=[], num_fragments=2, range_before=None, range_after=None, row_id=true, row_addr=false, full_filter=--, refine_filter=--", | ||
| ) | ||
| .await | ||
| .unwrap(); | ||
|
|
||
| scan.filter("s == ''").unwrap(); | ||
|
|
||
| let plan = scan.create_count_plan().await.unwrap(); | ||
|
|
||
| assert_plan_node_equals( | ||
| plan, | ||
| "AggregateExec: mode=Single, gby=[], aggr=[count_rows] | ||
| ProjectionExec: expr=[_rowid@1 as _rowid] | ||
| LanceRead: uri=..., projection=[s], num_fragments=2, range_before=None, range_after=None, row_id=true, row_addr=false, full_filter=s = Utf8(\"\"), refine_filter=s = Utf8(\"\")", | ||
| ) | ||
| .await | ||
| .unwrap(); | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn test_inexact_scalar_index_plans() { | ||
| let data = gen_batch() | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it makes sense to actually return a value instead of failing this case. With the latest way of optimizing, it will always just do a metadata projection and avoid data scan.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is doing a data scan isn't it?
ais not indexed and so there is no way to satisfy the count request without actually scanning the column.I don't entirely agree that this shouldn't be an error but I also don't disagree enough to complain. I think the only valid concern I could have is that a user doing something like...
ds.scanner(columns=["a"]).count_rows()might think this is the same asSELECT COUNT(a) FROM ...(i.e. that it returns the count of non-null rows) but that's a pretty weak argument.So...feel free to ignore this comment 😛
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh okay, I somehow ignored the filter part when trying to reason why it was asserting failure...
ohh I see the reasoning now, thanks for explaining. I think it is still clear, that
ds.scanner(columns=["a"]).count_rows()is not the same asSELECT COUNT(a)because it should be equivalent to something likeds.scanner(columns=["a"], filter="a IS NOT NULL").count_rows().