Merged
Changes from 5 commits
5 changes: 5 additions & 0 deletions datafusion/expr-common/src/type_coercion/aggregates.rs
@@ -129,6 +129,11 @@ pub fn check_arg_count(
);
}
}
TypeSignature::Nullary => {
if !input_types.is_empty() {
return plan_err!("The function {func_name} expects no arguments");
}
}
TypeSignature::UserDefined
| TypeSignature::Numeric(_)
| TypeSignature::Coercible(_) => {
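
The new `Nullary` arm rejects any call that passes arguments to a zero-argument function. A minimal sketch of that arity check follows. It assumes a trimmed-down, hypothetical `Signature` enum and plain `String` errors in place of DataFusion's `TypeSignature` and `plan_err!`, and the wording of the `Exact` error is invented for illustration.

```rust
/// Hypothetical, trimmed-down stand-in for DataFusion's `TypeSignature`.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
pub enum Signature {
    /// The function takes no arguments, e.g. `count()`.
    Nullary,
    /// The function takes exactly this many arguments.
    Exact(usize),
}

/// Returns an error message when the argument count does not fit the signature.
pub fn check_arg_count(
    func_name: &str,
    arg_count: usize,
    signature: &Signature,
) -> Result<(), String> {
    match signature {
        Signature::Nullary => {
            if arg_count != 0 {
                // Same message as the arm added in the diff.
                return Err(format!("The function {func_name} expects no arguments"));
            }
        }
        Signature::Exact(n) => {
            if arg_count != *n {
                // Illustrative wording only, not taken from DataFusion.
                return Err(format!(
                    "The function {func_name} expects {n} arguments, got {arg_count}"
                ));
            }
        }
    }
    Ok(())
}
```

With this sketch, `check_arg_count("count", 0, &Signature::Nullary)` succeeds, while passing one argument returns the "expects no arguments" error.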

1 change: 1 addition & 0 deletions datafusion/expr/src/udaf.rs
@@ -100,6 +100,7 @@ impl fmt::Display for AggregateUDF {
}

/// Arguments passed to [`AggregateUDFImpl::value_from_stats`]
#[derive(Debug)]
pub struct StatisticsArgs<'a> {
/// The statistics of the aggregate input
pub statistics: &'a Statistics,

16 changes: 14 additions & 2 deletions datafusion/functions-aggregate/src/count.rs
@@ -148,6 +148,15 @@ impl AggregateUDFImpl for Count {
"count"
}

// In AggregateFunctionPlanner, wildcard is converted to count(1)
//
// count() -> count(1)

Member

We still can't run select count(), count(*).

> select count(), count(*);
Error during planning: Projections require unique expression names but the expression "count(*)" at position 0 and "count(*)" at position 1 have the same name. Consider aliasing ("AS") one of them

I suspect that using aliases to restore the original names is a simpler fix. I tried doing this on jonahgao@08206fd.

@jayzhan211 jayzhan211 Feb 24, 2025

Contributor Author

I think the issue here is quite different from the one covered by the extended tests.

The duplicated schema name case is executable now:

query error DataFusion error: Schema error: Schema contains duplicate unqualified field name "count\(\*\)"
select count(1) * count(2);

The duplicated name in the projection of select count(), count(*) is another issue.

But I agree that this query should be executable too. I think the way to fix it is different from the fix for the duplicated schema name.

Contributor

BTW, I verified that both of those queries run in DataFusion 44 and 45 but do not run on main. Thus this is a regression.

I agree with @jayzhan211 that the issue is different from what is causing the sqlite tests to fail on main.

I have filed a ticket to track this:

@jonahgao jonahgao Feb 24, 2025

Member

I think they have the same root cause, which is the rewriting by AggregateFunctionPlanner and Count::schema_name() introducing duplicate names, and they could all be fixed by using aliases. The old CountWildcardRule used NamePreserver to achieve a similar effect.
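
The naming collision discussed in this thread can be reproduced in a toy model. The sketch below is not DataFusion code: `Expr`, `schema_name`, `rewrite`, `rewrite_preserving_name`, and `names_are_unique` are hypothetical stand-ins, and the assumption that `count()` originally displays as `count()` is made only for illustration. It shows how rewriting both spellings to `count(1)` yields two columns named `count(*)`, and how wrapping the rewritten expression in an alias that carries the original name avoids the clash.

```rust
use std::collections::HashSet;

/// Toy expression type (an assumption for illustration, not DataFusion's `Expr`).
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    /// `count(*)` as written by the user.
    CountStar,
    /// `count()` as written by the user.
    CountEmpty,
    /// `count(<integer literal>)`.
    CountLit(i64),
    /// An expression with an explicit output name.
    Alias(Box<Expr>, String),
}

/// Name used for the output column of an expression.
pub fn schema_name(expr: &Expr) -> String {
    match expr {
        Expr::CountStar => "count(*)".to_string(),
        Expr::CountEmpty => "count()".to_string(),
        // Mirrors the behaviour described in the thread: count(1) is shown as count(*).
        Expr::CountLit(1) => "count(*)".to_string(),
        Expr::CountLit(n) => format!("count(Int64({n}))"),
        Expr::Alias(_, name) => name.clone(),
    }
}

/// Rewrites both wildcard spellings to `count(1)` without preserving names.
pub fn rewrite(expr: &Expr) -> Expr {
    match expr {
        Expr::CountStar | Expr::CountEmpty => Expr::CountLit(1),
        other => other.clone(),
    }
}

/// Same rewrite, but restores the original name with an alias if it changed.
pub fn rewrite_preserving_name(expr: &Expr) -> Expr {
    let original_name = schema_name(expr);
    let rewritten = rewrite(expr);
    if schema_name(&rewritten) == original_name {
        rewritten
    } else {
        Expr::Alias(Box::new(rewritten), original_name)
    }
}

/// True when every projected expression has a distinct output name.
pub fn names_are_unique(exprs: &[Expr]) -> bool {
    let mut seen = HashSet::new();
    exprs.iter().all(|e| seen.insert(schema_name(e)))
}
```

In this model, the projection `[CountEmpty, CountStar]` fails the uniqueness check after `rewrite` but passes it after `rewrite_preserving_name`, which is the effect the alias-based suggestion is after.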

// count(*) -> count(1)
// count(1) -> count(1)
// count(2) -> count(2)
//
// count(1) is named as count(*) in schema_name
// other constant remains the same
fn schema_name(&self, params: &AggregateFunctionParams) -> Result<String> {
let AggregateFunctionParams {
args,
@@ -511,6 +520,11 @@ impl AggregateUDFImpl for Count {
return None;
}
if let Precision::Exact(num_rows) = statistics_args.statistics.num_rows {
// handle count()
if statistics_args.exprs.is_empty() {
return Some(ScalarValue::Int64(Some(num_rows as i64)));
}

if statistics_args.exprs.len() == 1 {
// TODO optimize with exprs other than Column
if let Some(col_expr) = statistics_args.exprs[0]
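
The added branch answers `count()` straight from table statistics when the row count is exact. A minimal sketch of that shortcut is below. `Precision` here is a hypothetical simplification of DataFusion's type of the same name, `count_from_stats` is an invented helper, and the one-argument branch, which is truncated in this diff, is deliberately not modelled.

```rust
/// Hypothetical simplification of DataFusion's `Precision<usize>`.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Precision {
    /// The value is known exactly.
    Exact(usize),
    /// The value is only an estimate.
    Inexact(usize),
    /// No value is available.
    Absent,
}

/// Returns the result of `count()` from statistics alone, or `None` when the
/// query has to be executed to find out.
pub fn count_from_stats(num_rows: Precision, arg_count: usize) -> Option<i64> {
    match num_rows {
        // count() with no arguments counts every row, so an exact row count is the answer.
        Precision::Exact(rows) if arg_count == 0 => Some(rows as i64),
        // Estimates, missing statistics, and calls with arguments are not handled here.
        _ => None,
    }
}
```

Only an exact row count is safe to return: an inexact estimate would silently change query results.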
@@ -550,8 +564,6 @@ impl AggregateUDFImpl for Count {
fn is_count_wildcard(args: &[Expr]) -> bool {

Contributor

This function now feels a bit redundant, as it just checks whether args is empty.

match args {
[] => true, // count()
// All const should be coerced to int64 or rejected by the signature
[Expr::Literal(ScalarValue::Int64(Some(_)))] => true, // count(1)
_ => false, // More than one argument or non-matching cases
}
}
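
The hunk header reports two removed lines, but this capture does not mark which ones. The sketch below therefore shows both shapes using Rust slice patterns: one with every arm displayed above, and one with only the empty-slice arm, which is the form the reviewer's remark describes. `Arg` is a hypothetical stand-in for `Expr` and `ScalarValue`, and both function names are invented.

```rust
/// Toy argument type (an assumption, standing in for `Expr::Literal(ScalarValue::...)`).
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
pub enum Arg {
    /// An Int64 literal, possibly NULL.
    Int64(Option<i64>),
    /// A column reference.
    Column(String),
}

/// All arms as displayed in the hunk: `count()` and `count(<non-null Int64>)`.
pub fn is_count_wildcard_full(args: &[Arg]) -> bool {
    match args {
        [] => true,                     // count()
        [Arg::Int64(Some(_))] => true,  // count(1)
        _ => false,                     // anything else
    }
}

/// Only the empty-slice arm; this is the same test as `args.is_empty()`.
pub fn is_count_wildcard_empty_only(args: &[Arg]) -> bool {
    matches!(args, [])
}
```

If only the `[]` arm survives the change, the helper could be replaced by a direct `args.is_empty()` call at its call sites.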

4 changes: 1 addition & 3 deletions datafusion/functions-aggregate/src/planner.rs
@@ -20,9 +20,7 @@
use datafusion_common::Result;
use datafusion_expr::{
expr::AggregateFunction,
lit,
planner::{ExprPlanner, PlannerResult, RawAggregateExpr},
utils::COUNT_STAR_EXPANSION,
Expr,
};

@@ -49,7 +47,7 @@ impl ExprPlanner for AggregateFunctionPlanner {
return Ok(PlannerResult::Planned(Expr::AggregateFunction(
AggregateFunction::new_udf(
func,
vec![lit(COUNT_STAR_EXPANSION)],
vec![],
distinct,
filter,
order_by,

3 changes: 2 additions & 1 deletion datafusion/physical-expr/src/aggregate.rs
@@ -102,7 +102,8 @@ impl AggregateExprBuilder {
is_distinct,
is_reversed,
} = self;
if args.is_empty() {
// only count function can have empty args
if args.is_empty() && fun.name() != "count" {
return internal_err!("args should not be empty");
}
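
The relaxed builder check can be sketched in isolation. `check_builder_args` is a hypothetical helper that takes the function name and argument count directly, and it returns a `String` error rather than using `internal_err!`.

```rust
/// Rejects empty argument lists for every aggregate except one named "count".
pub fn check_builder_args(fun_name: &str, arg_count: usize) -> Result<(), String> {
    // Only the count function may be built without arguments.
    if arg_count == 0 && fun_name != "count" {
        return Err("args should not be empty".to_string());
    }
    Ok(())
}
```

Note that in this sketch the exemption is purely name-based, so any aggregate that reports the name `count` passes the check.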


11 changes: 10 additions & 1 deletion datafusion/physical-plan/src/aggregates/mod.rs
@@ -34,7 +34,9 @@ use crate::{
SendableRecordBatchStream, Statistics,
};

use arrow::array::{ArrayRef, UInt16Array, UInt32Array, UInt64Array, UInt8Array};
use arrow::array::{
ArrayRef, Int64Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
};
use arrow::datatypes::{Field, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use datafusion_common::stats::Precision;
@@ -1231,6 +1233,13 @@ fn evaluate(
expr: &[Arc<dyn PhysicalExpr>],
batch: &RecordBatch,
) -> Result<Vec<ArrayRef>> {
// handle count() case
if expr.is_empty() {
return Ok(vec![
Arc::new(Int64Array::from(vec![1; batch.num_rows()])) as ArrayRef

Contributor Author

This is equivalent to the count(1) case.

Member

It seems that this function is not only used by count. I'm not quite sure about the impact of this change.
Ideally, this function should not involve the logic of any specific aggregation function.

]);
}

expr.iter()
.map(|expr| {
expr.evaluate(batch)
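
The early return in `evaluate` substitutes a column of ones when there are no argument expressions. A self-contained sketch of that fallback follows, under loud assumptions: `Column` is a plain `Vec<i64>` instead of an Arrow `ArrayRef`, `Batch` is a toy struct, and an "expression" is just a column index.

```rust
/// Toy column type (an assumption; the real code uses Arrow `ArrayRef`).
pub type Column = Vec<i64>;

/// Toy record batch: some columns plus an explicit row count.
pub struct Batch {
    pub columns: Vec<Column>,
    pub num_rows: usize,
}

/// Evaluates column-index "expressions" against a batch.
/// With no expressions, returns a single column of ones, one per row.
pub fn evaluate(exprs: &[usize], batch: &Batch) -> Result<Vec<Column>, String> {
    // count() case: nothing to evaluate, so synthesize a non-null value per row.
    if exprs.is_empty() {
        return Ok(vec![vec![1; batch.num_rows]]);
    }
    exprs
        .iter()
        .map(|&idx| {
            batch
                .columns
                .get(idx)
                .cloned()
                .ok_or_else(|| format!("no column at index {idx}"))
        })
        .collect()
}
```

In this sketch, as in the hunk, the fallback applies to any caller that passes an empty expression list, not only to count, which is the coupling the review comment above raises.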

22 changes: 13 additions & 9 deletions datafusion/physical-plan/src/aggregates/no_grouping.rs
@@ -23,6 +23,7 @@ use crate::aggregates::{
};
use crate::metrics::{BaselineMetrics, RecordOutput};
use crate::{RecordBatchStream, SendableRecordBatchStream};
use arrow::array::{ArrayRef, Int64Array};
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use datafusion_common::Result;
@@ -219,23 +220,26 @@ fn aggregate_batch(
None => Cow::Borrowed(&batch),
};

let n_rows = batch.num_rows();

// 1.3
let values = &expr
.iter()
.map(|e| {
e.evaluate(&batch)
.and_then(|v| v.into_array(batch.num_rows()))
})
.collect::<Result<Vec<_>>>()?;
// Handle count(*) case
let values = if expr.is_empty() {
vec![Arc::new(Int64Array::from(vec![1; n_rows])) as ArrayRef]

Contributor Author

This is equivalent to the count(1) case.

} else {
expr.iter()
.map(|e| e.evaluate(&batch).and_then(|v| v.into_array(n_rows)))
.collect::<Result<Vec<_>>>()?
};

// 1.4
let size_pre = accum.size();
let res = match mode {
AggregateMode::Partial
| AggregateMode::Single
| AggregateMode::SinglePartitioned => accum.update_batch(values),
| AggregateMode::SinglePartitioned => accum.update_batch(&values),
AggregateMode::Final | AggregateMode::FinalPartitioned => {
accum.merge_batch(values)
accum.merge_batch(&values)
}
};
let size_post = accum.size();
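
To see why feeding a column of ones is equivalent to `count(1)`, here is a toy version of the update path. `CountAccumulator`, `Column`, and this `aggregate_batch` signature are hypothetical simplifications: the real accumulator works on Arrow arrays and returns `Result`, and the real function also handles filters and aggregate modes.

```rust
/// Toy nullable column (an assumption; the real code uses Arrow arrays).
pub type Column = Vec<Option<i64>>;

/// Counts the non-null values of its first input column.
#[derive(Debug, Default)]
pub struct CountAccumulator {
    count: i64,
}

impl CountAccumulator {
    pub fn update_batch(&mut self, values: &[Column]) {
        if let Some(first) = values.first() {
            self.count += first.iter().filter(|v| v.is_some()).count() as i64;
        }
    }

    pub fn value(&self) -> i64 {
        self.count
    }
}

/// Feeds one batch to the accumulator. With no input columns (the `count()`
/// case), a column of `n_rows` ones is synthesized so every row is counted.
pub fn aggregate_batch(accum: &mut CountAccumulator, inputs: &[Column], n_rows: usize) {
    let values: Vec<Column> = if inputs.is_empty() {
        vec![vec![Some(1); n_rows]]
    } else {
        inputs.to_vec()
    };
    accum.update_batch(&values);
}
```

Because the synthesized column contains no nulls, the accumulator adds exactly `n_rows` per batch, which is what `count(1)` would produce.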

51 changes: 51 additions & 0 deletions datafusion/sqllogictest/test_files/aggregate.slt
@@ -6276,6 +6276,9 @@ physical_plan
05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c5], file_type=csv, has_header=true

statement count 0
drop table aggregate_test_100;

# test count(null) case (null with type)

statement count 0
@@ -6296,6 +6299,54 @@ physical_plan
01)AggregateExec: mode=Single, gby=[], aggr=[count(NULL)]
02)--DataSourceExec: partitions=1, partition_sizes=[1]

statement count 0
drop table t;

# test duplicated shema name issue

statement count 0
create table t (a int) as values (1), (2);

query I
select count() from t;
----
2

query I
select count(1) * count(2) from t;

Contributor

Could you also please add a test that shows just the values of count(2)?

For example:

select count(1), count(2), count(1) * count(2) from t;

----
4

query I
select count(1) * count(*) from t;
----
4

query I
select count(*) * count(*) from t;
----
4

query I

Contributor

Likewise, it would be nice to have count(1) and count(2) tested individually here.

select count(1) * count(1) from t;
----
4

query TT
explain select count(1) * count(2) from t;
----
logical_plan
01)Projection: count(Int64(1)) * count(Int64(2))
02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1)), count(Int64(2))]]
03)----TableScan: t projection=[]
physical_plan
01)ProjectionExec: expr=[count(Int64(1))@0 * count(Int64(2))@1 as count(Int64(1)) * count(Int64(2))]
02)--AggregateExec: mode=Single, gby=[], aggr=[count(Int64(1)), count(Int64(2))]
03)----DataSourceExec: partitions=1, partition_sizes=[1]

statement count 0
drop table t;

#######
# Group median test
#######

8 changes: 4 additions & 4 deletions datafusion/sqllogictest/test_files/count_star_rule.slt
@@ -80,12 +80,12 @@ query TT
EXPLAIN SELECT a, COUNT() OVER (PARTITION BY a) AS count_a FROM t1;
----
logical_plan
01)Projection: t1.a, count(*) PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS count_a
02)--WindowAggr: windowExpr=[[count(*) PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]
01)Projection: t1.a, count(Int64(1)) PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS count_a

Contributor

Given that you implemented support for count(), I don't understand why this changed to count(1). Why isn't it count()?

02)--WindowAggr: windowExpr=[[count(Int64(1)) PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]
03)----TableScan: t1 projection=[a]
physical_plan
01)ProjectionExec: expr=[a@0 as a, count(*) PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as count_a]
02)--WindowAggExec: wdw=[count(*) PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "count(*) PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]
01)ProjectionExec: expr=[a@0 as a, count(Int64(1)) PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as count_a]
02)--WindowAggExec: wdw=[count(Int64(1)) PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "count(Int64(1)) PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]
03)----SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[false]
04)------DataSourceExec: partitions=1, partition_sizes=[1]
