Alias the DISTINCT argument as `alias1` in the inner GROUP BY built by SingleDistinctToGroupBy,
and make the outer aggregate reference that alias by column name. New optimizer and SQL tests
cover COUNT(DISTINCT <column>) and COUNT(DISTINCT <expression>).

diff --git a/datafusion/src/optimizer/single_distinct_to_groupby.rs b/datafusion/src/optimizer/single_distinct_to_groupby.rs
index 3232fa03ce80f..9bddec997db6d 100644
--- a/datafusion/src/optimizer/single_distinct_to_groupby.rs
+++ b/datafusion/src/optimizer/single_distinct_to_groupby.rs
@@ -20,7 +20,7 @@
 use crate::error::Result;
 use crate::execution::context::ExecutionProps;
 use crate::logical_plan::plan::{Aggregate, Projection};
-use crate::logical_plan::{columnize_expr, DFSchema, Expr, LogicalPlan};
+use crate::logical_plan::{col, columnize_expr, DFSchema, Expr, LogicalPlan};
 use crate::optimizer::optimizer::OptimizerRule;
 use crate::optimizer::utils;
 use hashbrown::HashSet;
@@ -34,14 +34,16 @@ use std::sync::Arc;
 ///
 ///    Into
 ///
-///    SELECT F1(s),F2(s)
+///    SELECT F1(alias1),F2(alias1)
 ///    FROM (
-///      SELECT s, k ... GROUP BY s, k
+///      SELECT s as alias1, k ... GROUP BY s, k
 ///    )
 ///    GROUP BY k
 ///  ```
 pub struct SingleDistinctToGroupBy {}
 
+const SINGLE_DISTINCT_ALIAS: &str = "alias1";
+
 impl SingleDistinctToGroupBy {
     #[allow(missing_docs)]
     pub fn new() -> Self {
@@ -69,11 +71,12 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
                             if group_fields_set
                                 .insert(args[0].name(input.schema()).unwrap())
                             {
-                                all_group_args.push(args[0].clone());
+                                all_group_args
+                                    .push(args[0].clone().alias(SINGLE_DISTINCT_ALIAS));
                             }
                             Expr::AggregateFunction {
                                 fun: fun.clone(),
-                                args: args.clone(),
+                                args: vec![col(SINGLE_DISTINCT_ALIAS)],
                                 distinct: false,
                             }
                         }
@@ -104,7 +107,6 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
                     )
                     .unwrap(),
                 );
-
                 let final_agg = LogicalPlan::Aggregate(Aggregate {
                     input: Arc::new(grouped_agg.unwrap()),
                     group_expr: group_expr.clone(),
@@ -191,7 +193,7 @@ impl OptimizerRule for SingleDistinctToGroupBy {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::logical_plan::{col, count, count_distinct, max, LogicalPlanBuilder};
+    use crate::logical_plan::{col, count, count_distinct, lit, max, LogicalPlanBuilder};
     use crate::physical_plan::aggregates;
     use crate::test::*;
 
@@ -229,9 +231,26 @@ mod tests {
             .build()?;
 
         // Should work
-        let expected = "Projection: #COUNT(test.b) AS COUNT(DISTINCT test.b) [COUNT(DISTINCT test.b):UInt64;N]\
-                            \n  Aggregate: groupBy=[[]], aggr=[[COUNT(#test.b)]] [COUNT(test.b):UInt64;N]\
-                            \n    Aggregate: groupBy=[[#test.b]], aggr=[[]] [b:UInt32]\
+        let expected = "Projection: #COUNT(alias1) AS COUNT(DISTINCT test.b) [COUNT(DISTINCT test.b):UInt64;N]\
+                            \n  Aggregate: groupBy=[[]], aggr=[[COUNT(#alias1)]] [COUNT(alias1):UInt64;N]\
+                            \n    Aggregate: groupBy=[[#test.b AS alias1]], aggr=[[]] [alias1:UInt32]\
+                            \n      TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]";
+
+        assert_optimized_plan_eq(&plan, expected);
+        Ok(())
+    }
+
+    #[test]
+    fn single_distinct_expr() -> Result<()> {
+        let table_scan = test_table_scan()?;
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .aggregate(Vec::<Expr>::new(), vec![count_distinct(lit(2) * col("b"))])?
+            .build()?;
+
+        let expected = "Projection: #COUNT(alias1) AS COUNT(DISTINCT Int32(2) * test.b) [COUNT(DISTINCT Int32(2) * test.b):UInt64;N]\
+                            \n  Aggregate: groupBy=[[]], aggr=[[COUNT(#alias1)]] [COUNT(alias1):UInt64;N]\
+                            \n    Aggregate: groupBy=[[Int32(2) * #test.b AS alias1]], aggr=[[]] [alias1:Int32]\
                             \n      TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]";
 
         assert_optimized_plan_eq(&plan, expected);
@@ -247,9 +266,9 @@ mod tests {
             .build()?;
 
         // Should work
-        let expected = "Projection: #test.a AS a, #COUNT(test.b) AS COUNT(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):UInt64;N]\
-                            \n  Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(#test.b)]] [a:UInt32, COUNT(test.b):UInt64;N]\
-                            \n    Aggregate: groupBy=[[#test.a, #test.b]], aggr=[[]] [a:UInt32, b:UInt32]\
+        let expected = "Projection: #test.a AS a, #COUNT(alias1) AS COUNT(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):UInt64;N]\
+                            \n  Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(#alias1)]] [a:UInt32, COUNT(alias1):UInt64;N]\
+                            \n    Aggregate: groupBy=[[#test.a, #test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\
                             \n      TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]";
 
         assert_optimized_plan_eq(&plan, expected);
@@ -293,9 +312,9 @@ mod tests {
             )?
             .build()?;
         // Should work
-        let expected = "Projection: #test.a AS a, #COUNT(test.b) AS COUNT(DISTINCT test.b), #MAX(test.b) AS MAX(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):UInt64;N, MAX(DISTINCT test.b):UInt32;N]\
-                            \n  Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(#test.b), MAX(#test.b)]] [a:UInt32, COUNT(test.b):UInt64;N, MAX(test.b):UInt32;N]\
-                            \n    Aggregate: groupBy=[[#test.a, #test.b]], aggr=[[]] [a:UInt32, b:UInt32]\
+        let expected = "Projection: #test.a AS a, #COUNT(alias1) AS COUNT(DISTINCT test.b), #MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):UInt64;N, MAX(DISTINCT test.b):UInt32;N]\
+                            \n  Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(#alias1), MAX(#alias1)]] [a:UInt32, COUNT(alias1):UInt64;N, MAX(alias1):UInt32;N]\
+                            \n    Aggregate: groupBy=[[#test.a, #test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\
                             \n      TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]";
 
         assert_optimized_plan_eq(&plan, expected);
diff --git a/datafusion/tests/sql/aggregates.rs b/datafusion/tests/sql/aggregates.rs
index 243d0084d890e..8073862c8d6e5 100644
--- a/datafusion/tests/sql/aggregates.rs
+++ b/datafusion/tests/sql/aggregates.rs
@@ -101,6 +101,40 @@ async fn csv_query_count() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn csv_query_count_distinct() -> Result<()> {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx).await?;
+    let sql = "SELECT count(distinct c2) FROM aggregate_test_100";
+    let actual = execute_to_batches(&mut ctx, sql).await;
+    let expected = vec![
+        "+---------------------------------------+",
+        "| COUNT(DISTINCT aggregate_test_100.c2) |",
+        "+---------------------------------------+",
+        "| 5                                     |",
+        "+---------------------------------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}
+
+#[tokio::test]
+async fn csv_query_count_distinct_expr() -> Result<()> {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx).await?;
+    let sql = "SELECT count(distinct c2 % 2) FROM aggregate_test_100";
+    let actual = execute_to_batches(&mut ctx, sql).await;
+    let expected = vec![
+        "+--------------------------------------------------+",
+        "| COUNT(DISTINCT aggregate_test_100.c2 % Int64(2)) |",
+        "+--------------------------------------------------+",
+        "| 2                                                |",
+        "+--------------------------------------------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}
+
 #[tokio::test]
 async fn csv_query_count_star() {
     let mut ctx = ExecutionContext::new();