Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion java/lance-jni/src/blocking_scanner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ fn inner_create_scanner<'local>(

let substrait_aggregate_opt = env.get_bytes_opt(&substrait_aggregate_obj)?;
if let Some(substrait_aggregate) = substrait_aggregate_opt {
scanner.aggregate(AggregateExpr::substrait(substrait_aggregate));
scanner.aggregate(AggregateExpr::substrait(substrait_aggregate))?;
}

let scanner = BlockingScanner::create(scanner);
Expand Down
19 changes: 6 additions & 13 deletions python/python/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1168,19 +1168,12 @@ def test_count_rows_via_scanner(tmp_path: Path):
ds = lance.write_dataset(pa.table({"a": range(100), "b": range(100)}), tmp_path)

assert ds.scanner(filter="a < 50", columns=[], with_row_id=True).count_rows() == 50

with pytest.raises(
ValueError, match="should not be called on a plan selecting columns"
):
ds.scanner(filter="a < 50", columns=["a"], with_row_id=True).count_rows()
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it makes sense to actually return a value instead of failing in this case. With the latest way of optimizing, it will always just do a metadata projection and avoid a data scan.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is doing a data scan, isn't it? a is not indexed, and so there is no way to satisfy the count request without actually scanning the column.

I don't entirely agree that this shouldn't be an error but I also don't disagree enough to complain. I think the only valid concern I could have is that a user doing something like... ds.scanner(columns=["a"]).count_rows() might think this is the same as SELECT COUNT(a) FROM ... (i.e. that it returns the count of non-null rows) but that's a pretty weak argument.

So...feel free to ignore this comment 😛

Copy link
Copy Markdown
Contributor Author

@jackye1995 jackye1995 Feb 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh okay, I somehow ignored the filter part when trying to reason why it was asserting failure...

ds.scanner(columns=["a"]).count_rows() might think this is the same as SELECT COUNT(a) FROM ...

ohh I see the reasoning now, thanks for explaining. I think it is still clear that ds.scanner(columns=["a"]).count_rows() is not the same as SELECT COUNT(a) because it should be equivalent to something like ds.scanner(columns=["a"], filter="a IS NOT NULL").count_rows().


with pytest.raises(
ValueError, match="should not be called on a plan selecting columns"
):
ds.scanner(with_row_id=True).count_rows()

with pytest.raises(ValueError, match="with_row_id is false"):
ds.scanner(columns=[]).count_rows()
assert (
ds.scanner(filter="a < 50", columns=["a"], with_row_id=True).count_rows() == 50
)
assert ds.scanner(with_row_id=True).count_rows() == 100
assert ds.scanner(columns=[]).count_rows() == 100
assert ds.scanner().count_rows() == 100


def test_select_none(tmp_path: Path):
Expand Down
4 changes: 3 additions & 1 deletion python/src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1094,7 +1094,9 @@ impl Dataset {
.map_err(|err| PyValueError::new_err(err.to_string()))?;
}
if let Some(aggregate_bytes) = substrait_aggregate {
scanner.aggregate(AggregateExpr::substrait(aggregate_bytes));
scanner
.aggregate(AggregateExpr::substrait(aggregate_bytes))
.map_err(|err| PyValueError::new_err(err.to_string()))?;
}
let scan = Arc::new(scanner);
Ok(Scanner::new(scan))
Expand Down
25 changes: 25 additions & 0 deletions rust/lance-datafusion/src/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

use datafusion::logical_expr::Expr;

use crate::planner::Planner;

/// Aggregate specification with group by and aggregate expressions.
#[derive(Debug, Clone)]
pub struct Aggregate {
Expand All @@ -14,3 +16,26 @@ pub struct Aggregate {
/// Use `.alias()` on the expression to set output column names.
pub aggregates: Vec<Expr>,
}

impl Aggregate {
    /// Construct an [`Aggregate`] from group-by and aggregate expressions.
    ///
    /// Use `.alias()` on the aggregate expressions to set output column names.
    pub fn new(group_by: Vec<Expr>, aggregates: Vec<Expr>) -> Self {
        Self {
            group_by,
            aggregates,
        }
    }

    /// Compute the sorted, deduplicated set of column names referenced by
    /// this aggregate.
    ///
    /// For COUNT(*) this is empty; for SUM(x) GROUP BY y it is [x, y].
    pub fn required_columns(&self) -> Vec<String> {
        // Collect column references from both the group-by keys and the
        // aggregate expressions, then canonicalize the list.
        let mut columns: Vec<String> = self
            .group_by
            .iter()
            .chain(self.aggregates.iter())
            .flat_map(|expr| Planner::column_names_in_expr(expr))
            .collect();
        columns.sort();
        columns.dedup();
        columns
    }
}
5 changes: 1 addition & 4 deletions rust/lance-datafusion/src/substrait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -414,10 +414,7 @@ pub async fn parse_aggregate_rel_with_extensions(
let group_by = parse_groupings(aggregate_rel, &df_schema, &consumer).await?;
let aggregates = parse_measures(aggregate_rel, &df_schema, &consumer).await?;

Ok(Aggregate {
group_by,
aggregates,
})
Ok(Aggregate::new(group_by, aggregates))
}

/// Parse an AggregateRel proto with default extensions.
Expand Down
191 changes: 59 additions & 132 deletions rust/lance/src/dataset/scanner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ use async_recursion::async_recursion;
use chrono::Utc;
use datafusion::common::{exec_datafusion_err, DFSchema, JoinType, NullEquality, SchemaExt};
use datafusion::functions_aggregate;
use datafusion::functions_aggregate::count::count_udaf;
use datafusion::logical_expr::{col, lit, Expr, ScalarUDF};
use datafusion::physical_expr::PhysicalSortExpr;
use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
Expand All @@ -24,7 +23,6 @@ use datafusion::physical_plan::sorts::sort::SortExec;
use datafusion::physical_plan::{
aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy},
display::DisplayableExecutionPlan,
expressions::Literal,
limit::GlobalLimitExec,
repartition::RepartitionExec,
union::UnionExec,
Expand Down Expand Up @@ -506,28 +504,28 @@ impl AggregateExpr {
}
}

fn to_aggregate(
&self,
#[allow(unused_variables)] schema: Arc<ArrowSchema>,
) -> Result<Aggregate> {
/// Parse into a unified Aggregate structure.
///
/// For Substrait, this parses the bytes into DataFusion expressions.
/// For DataFusion, this just wraps the expressions.
///
/// The schema is used to resolve field references in Substrait expressions.
fn parse(self, #[allow(unused_variables)] schema: Arc<ArrowSchema>) -> Result<Aggregate> {
match self {
#[cfg(feature = "substrait")]
Self::Substrait(bytes) => {
use lance_datafusion::exec::{get_session_context, LanceExecutionOptions};
use lance_datafusion::substrait::parse_substrait_aggregate;

let ctx = get_session_context(&LanceExecutionOptions::default());
parse_substrait_aggregate(bytes, schema, &ctx.state())
parse_substrait_aggregate(&bytes, schema, &ctx.state())
.now_or_never()
.expect("could not parse the Substrait aggregate in a synchronous fashion")
}
Self::Datafusion {
group_by,
aggregates,
} => Ok(Aggregate {
group_by: group_by.clone(),
aggregates: aggregates.clone(),
}),
} => Ok(Aggregate::new(group_by, aggregates)),
}
}
}
Expand Down Expand Up @@ -788,7 +786,7 @@ pub struct Scanner {
/// File reader options to use when reading data files.
file_reader_options: Option<FileReaderOptions>,

aggregate: Option<AggregateExpr>,
aggregate: Option<Aggregate>,

// Legacy fields to help migrate some old projection behavior to new behavior
//
Expand Down Expand Up @@ -1239,9 +1237,14 @@ impl Scanner {
}

/// Set aggregation.
pub fn aggregate(&mut self, aggregate: AggregateExpr) -> &mut Self {
self.aggregate = Some(aggregate);
self
///
/// The aggregate expression is parsed immediately using the dataset schema.
/// For Substrait aggregates, this converts them to DataFusion expressions.
pub fn aggregate(&mut self, aggregate: AggregateExpr) -> Result<&mut Self> {
let schema: Arc<ArrowSchema> = Arc::new(self.dataset.schema().into());
let parsed = aggregate.parse(schema)?;
self.aggregate = Some(parsed);
Ok(self)
}

/// Set the batch size.
Expand Down Expand Up @@ -1911,62 +1914,6 @@ impl Scanner {
Ok(concat_batches(&schema, &batches)?)
}

pub fn create_count_plan(&self) -> BoxFuture<'_, Result<Arc<dyn ExecutionPlan>>> {
// Future intentionally boxed here to avoid large futures on the stack
async move {
if self.projection_plan.physical_projection.is_empty() {
return Err(Error::invalid_input("count_rows called but with_row_id is false".to_string(), location!()));
}
if !self.projection_plan.physical_projection.is_metadata_only() {
let physical_schema = self.projection_plan.physical_projection.to_schema();
let columns: Vec<&str> = physical_schema.fields
.iter()
.map(|field| field.name.as_str())
.collect();

let msg = format!(
"count_rows should not be called on a plan selecting columns. selected columns: [{}]",
columns.join(", ")
);

return Err(Error::invalid_input(msg, location!()));
}

if self.limit.is_some() || self.offset.is_some() {
log::warn!(
"count_rows called with limit or offset which could have surprising results"
);
}

let plan = self.create_plan().await?;
// Datafusion interprets COUNT(*) as COUNT(1)
let one = Arc::new(Literal::new(ScalarValue::UInt8(Some(1))));

let input_phy_exprs: &[Arc<dyn PhysicalExpr>] = &[one];
let schema = plan.schema();

let mut builder = datafusion_physical_expr::aggregate::AggregateExprBuilder::new(
count_udaf(),
input_phy_exprs.to_vec(),
);
builder = builder.schema(schema);
builder = builder.alias("count_rows".to_string());

let count_expr = builder.build()?;

let plan_schema = plan.schema();
Ok(Arc::new(AggregateExec::try_new(
AggregateMode::Single,
PhysicalGroupBy::new_single(Vec::new()),
vec![Arc::new(count_expr)],
vec![None],
plan,
plan_schema,
)?) as Arc<dyn ExecutionPlan>)
}
.boxed()
}

/// Scan and return the number of matching rows
///
/// Note: calling [`Dataset::count_rows`] can be more efficient than calling this method
Expand All @@ -1975,8 +1922,11 @@ impl Scanner {
pub fn count_rows(&self) -> BoxFuture<'_, Result<u64>> {
// Future intentionally boxed here to avoid large futures on the stack
async move {
let count_plan = self.create_count_plan().await?;
let mut stream = execute_plan(count_plan, LanceExecutionOptions::default())?;
let mut scanner = self.clone();
scanner.aggregate(AggregateExpr::builder().count_star().build())?;
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice cleanup!


let plan = scanner.create_plan().await?;
let mut stream = execute_plan(plan, LanceExecutionOptions::default())?;

// A count plan will always return a single batch with a single row.
if let Some(first_batch) = stream.next().await {
Expand All @@ -1986,7 +1936,7 @@ impl Scanner {
.as_any()
.downcast_ref::<Int64Array>()
.ok_or(Error::invalid_input(
"Count plan did not return a UInt64Array".to_string(),
"Count plan did not return an Int64Array".to_string(),
location!(),
))?;
Ok(array.value(0) as u64)
Expand Down Expand Up @@ -2018,12 +1968,11 @@ impl Scanner {
async fn apply_aggregate(
&self,
plan: Arc<dyn ExecutionPlan>,
agg_spec: &AggregateExpr,
agg: &Aggregate,
) -> Result<Arc<dyn ExecutionPlan>> {
use datafusion_physical_expr::aggregate::AggregateFunctionExpr;

let schema = plan.schema();
let agg = agg_spec.to_aggregate(schema.clone())?;
let df_schema = DFSchema::try_from(schema.as_ref().clone())?;

let group_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> = agg
Expand Down Expand Up @@ -2500,10 +2449,19 @@ impl Scanner {
plan = filter_plan.refine_filter(plan, self).await?;

// Aggregate (if set, applies aggregate and returns early)
if let Some(agg_spec) = &self.aggregate {
// Take columns needed for aggregation
plan = self.take(plan, self.projection_plan.physical_projection.clone())?;
return self.apply_aggregate(plan, agg_spec).await;
if let Some(agg) = &self.aggregate {
// Take only columns needed by the aggregate, not the full projection.
// For COUNT(*), this is empty. For SUM(x), this is just [x].
let required_columns = agg.required_columns();
let agg_projection = if required_columns.is_empty() {
self.dataset.empty_projection()
} else {
self.dataset
.empty_projection()
.union_columns(&required_columns, OnMissing::Error)?
};
plan = self.take(plan, agg_projection)?;
return self.apply_aggregate(plan, agg).await;
}

// Sort
Expand Down Expand Up @@ -2819,16 +2777,35 @@ impl Scanner {
filter_plan: &mut ExprFilterPlan,
) -> Result<PlannedFilteredScan> {
log::trace!("source is a filtered read");

// Compute the effective projection based on what's actually needed.
// If we have an aggregate, we only need the columns referenced by the aggregate,
// not all the columns from the projection plan.
let effective_projection = if let Some(agg) = &self.aggregate {
let required_columns = agg.required_columns();
if required_columns.is_empty() {
// COUNT(*) or similar - no columns needed
self.dataset.empty_projection()
} else {
// Aggregate needs specific columns
self.dataset
.empty_projection()
.union_columns(&required_columns, OnMissing::Error)?
}
} else {
self.projection_plan.physical_projection.clone()
};

let mut projection = if filter_plan.has_refine() {
// If the filter plan has two steps (a scalar indexed portion and a refine portion) then
// it makes sense to grab cheap columns during the first step to avoid taking them for
// the second step.
self.calc_eager_projection(filter_plan, &self.projection_plan.physical_projection)?
self.calc_eager_projection(filter_plan, &effective_projection)?
.with_row_id()
} else {
// If the filter plan only has one step then we just do a filtered read of all the
// columns that the user asked for.
self.projection_plan.physical_projection.clone()
effective_projection
};

if projection.is_empty() {
Expand Down Expand Up @@ -7373,56 +7350,6 @@ mod test {
assert_plan_node_equals(exec_plan, expected).await
}

// NOTE(review): this test validates the dedicated count plan built by
// `create_count_plan`, which requires an empty (metadata-only) projection
// plus `with_row_id` before it will produce a plan.
#[tokio::test]
async fn test_count_plan() {
// A count rows operation should load the minimal amount of data
let dim = 256;
// Build a two-fragment test dataset with a vector column of `dim` dimensions.
let fixture = TestVectorDataset::new_with_dimension(LanceFileVersion::Stable, true, dim)
.await
.unwrap();

// By default, all columns are returned, this is bad for a count_rows op
let err = fixture
.dataset
.scan()
.create_count_plan()
.await
.unwrap_err();
assert!(matches!(err, Error::InvalidInput { .. }));

// Select no data columns so the scan is metadata-only.
let mut scan = fixture.dataset.scan();
scan.project(&Vec::<String>::default()).unwrap();

// with_row_id needs to be specified
let err = scan.create_count_plan().await.unwrap_err();
assert!(matches!(err, Error::InvalidInput { .. }));

scan.with_row_id();

let plan = scan.create_count_plan().await.unwrap();

// Without a filter the plan should read only row ids (projection=[]).
assert_plan_node_equals(
plan,
"AggregateExec: mode=Single, gby=[], aggr=[count_rows]
LanceRead: uri=..., projection=[], num_fragments=2, range_before=None, range_after=None, row_id=true, row_addr=false, full_filter=--, refine_filter=--",
)
.await
.unwrap();

scan.filter("s == ''").unwrap();

let plan = scan.create_count_plan().await.unwrap();

// With a filter on `s`, the scan must materialize `s` to evaluate the
// predicate, but the aggregate still only consumes the row id column.
assert_plan_node_equals(
plan,
"AggregateExec: mode=Single, gby=[], aggr=[count_rows]
ProjectionExec: expr=[_rowid@1 as _rowid]
LanceRead: uri=..., projection=[s], num_fragments=2, range_before=None, range_after=None, row_id=true, row_addr=false, full_filter=s = Utf8(\"\"), refine_filter=s = Utf8(\"\")",
)
.await
.unwrap();
}

#[tokio::test]
async fn test_inexact_scalar_index_plans() {
let data = gen_batch()
Expand Down
Loading