Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions dask_planner/src/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,21 @@ impl PyExpr {

#[pyo3(name = "getFilterExpr")]
pub fn get_filter_expr(&self) -> PyResult<Option<PyExpr>> {
// TODO refactor to avoid duplication
match &self.expr {
Expr::Alias(expr, _) => match expr.as_ref() {
Expr::AggregateFunction { filter, .. } | Expr::AggregateUDF { filter, .. } => {
match filter {
Some(filter) => {
Ok(Some(PyExpr::from(*filter.clone(), self.input_plan.clone())))
}
None => Ok(None),
}
}
_ => Err(py_type_err(
"getFilterExpr() - Non-aggregate expression encountered",
)),
},
Expr::AggregateFunction { filter, .. } | Expr::AggregateUDF { filter, .. } => {
match filter {
Some(filter) => {
Expand Down
10 changes: 9 additions & 1 deletion dask_planner/src/sql/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use datafusion_optimizer::{
projection_push_down::ProjectionPushDown, subquery_filter_to_join::SubqueryFilterToJoin,
OptimizerConfig,
};
use log::trace;

mod eliminate_agg_distinct;
use eliminate_agg_distinct::EliminateAggDistinct;
Expand Down Expand Up @@ -55,7 +56,14 @@ impl DaskSqlOptimizer {
let mut resulting_plan: LogicalPlan = plan;
for optimization in &self.optimizations {
match optimization.optimize(&resulting_plan, &mut OptimizerConfig::new()) {
Ok(optimized_plan) => resulting_plan = optimized_plan,
Ok(optimized_plan) => {
trace!(
"== AFTER APPLYING RULE {} ==\n{}",
optimization.name(),
optimized_plan.display_indent()
);
resulting_plan = optimized_plan
}
Err(e) => {
println!(
"Skipping optimizer rule {} due to unexpected error: {}",
Expand Down
67 changes: 52 additions & 15 deletions dask_planner/src/sql/optimizer/eliminate_agg_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
//! Aggregate: groupBy=[[#a.d]], aggr=[[COUNT(UInt64(1)) AS __dask_sql_count__4]]\
//! TableScan: a

use datafusion_common::{DFSchema, Result, ScalarValue};
use datafusion_common::{Column, DFSchema, Result};
use datafusion_expr::logical_plan::Projection;
use datafusion_expr::utils::exprlist_to_fields;
use datafusion_expr::{
Expand Down Expand Up @@ -136,7 +136,7 @@ impl OptimizerRule for EliminateAggDistinct {
.collect::<Result<Vec<_>>>()?;

for plan in &plans {
trace!("FINAL PLAN:\n{}", plan.display_indent());
trace!("{}", plan.display_indent());
}

match plans.len() {
Expand All @@ -151,7 +151,9 @@ impl OptimizerRule for EliminateAggDistinct {
for plan in plans.iter().skip(1) {
builder = builder.cross_join(plan)?;
}
builder.build()
let join_plan = builder.build()?;
trace!("{}", join_plan.display_indent_schema());
Ok(join_plan)
}
}
}
Expand Down Expand Up @@ -214,7 +216,9 @@ fn create_plan(
let mut group_expr = group_expr.clone();
group_expr.push(expr.clone());
let alias = format!("__dask_sql_count__{}", optimizer_config.next_id());
let aggr_expr = vec![count(Expr::Literal(ScalarValue::UInt64(Some(1)))).alias(&alias)];
let expr_name = expr.name()?;
let count_expr = Expr::Column(Column::from_qualified_name(&expr_name));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the main fix - COUNT(col) instead of COUNT(1)

let aggr_expr = vec![count(count_expr).alias(&alias)];
let mut schema_expr = group_expr.clone();
schema_expr.extend_from_slice(&aggr_expr);
let schema = DFSchema::new_with_metadata(
Expand Down Expand Up @@ -461,20 +465,31 @@ fn unique_set_without_aliases(unique_expressions: &HashSet<Expr>) -> HashSet<Exp
#[cfg(test)]
mod tests {
use super::*;
use crate::sql::optimizer::DaskSqlOptimizer;
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_expr::{
col, count, count_distinct,
logical_plan::{builder::LogicalTableSource, LogicalPlanBuilder},
};
use std::sync::Arc;

/// Optimize with just the eliminate_agg_distinct rule
fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) {
let rule = EliminateAggDistinct::new();
let optimized_plan = rule
.optimize(plan, &mut OptimizerConfig::new())
.expect("failed to optimize plan");
let formatted_plan = format!("{}", optimized_plan.display_indent());
assert_eq!(expected, formatted_plan);
}

/// Optimize with all of the optimizer rules, including eliminate_agg_distinct
fn assert_fully_optimized_plan_eq(plan: &LogicalPlan, expected: &str) {
let optimizer = DaskSqlOptimizer::new();
let optimized_plan = optimizer
.run_optimizations(plan.clone())
.expect("failed to optimize plan");
let formatted_plan = format!("{}", optimized_plan.display_indent());
assert_eq!(expected, formatted_plan);
}

Expand Down Expand Up @@ -575,7 +590,7 @@ mod tests {

let expected = "Projection: #a.b, #a.b AS COUNT(a.a), #SUM(__dask_sql_count__1) AS COUNT(DISTINCT a.a)\
\n Aggregate: groupBy=[[#a.b]], aggr=[[SUM(#__dask_sql_count__1), COUNT(#a.a)]]\
\n Aggregate: groupBy=[[#a.b, #a.a]], aggr=[[COUNT(UInt64(1)) AS __dask_sql_count__1]]\
\n Aggregate: groupBy=[[#a.b, #a.a]], aggr=[[COUNT(#a.a) AS __dask_sql_count__1]]\
\n TableScan: a";
assert_optimized_plan_eq(&plan, expected);
Ok(())
Expand All @@ -593,7 +608,7 @@ mod tests {

let expected = "Projection: #SUM(__dask_sql_count__1) AS COUNT(a.a), #COUNT(a.a) AS COUNT(DISTINCT a.a)\
\n Aggregate: groupBy=[[]], aggr=[[SUM(#__dask_sql_count__1), COUNT(#a.a)]]\
\n Aggregate: groupBy=[[#a.a]], aggr=[[COUNT(UInt64(1)) AS __dask_sql_count__1]]\
\n Aggregate: groupBy=[[#a.a]], aggr=[[COUNT(#a.a) AS __dask_sql_count__1]]\
\n TableScan: a";
assert_optimized_plan_eq(&plan, expected);
Ok(())
Expand All @@ -614,7 +629,7 @@ mod tests {

let expected = "Projection: #SUM(__dask_sql_count__1) AS c_a, #COUNT(a.a) AS cd_a\
\n Aggregate: groupBy=[[]], aggr=[[SUM(#__dask_sql_count__1), COUNT(#a.a)]]\
\n Aggregate: groupBy=[[#a.a]], aggr=[[COUNT(UInt64(1)) AS __dask_sql_count__1]]\
\n Aggregate: groupBy=[[#a.a]], aggr=[[COUNT(#a.a) AS __dask_sql_count__1]]\
\n TableScan: a";
assert_optimized_plan_eq(&plan, expected);
Ok(())
Expand Down Expand Up @@ -644,19 +659,19 @@ mod tests {
\n CrossJoin:\
\n Projection: #SUM(__dask_sql_count__1) AS COUNT(a.a), #COUNT(a.a) AS COUNT(DISTINCT a.a)\
\n Aggregate: groupBy=[[]], aggr=[[SUM(#__dask_sql_count__1), COUNT(#a.a)]]\
\n Aggregate: groupBy=[[#a.a]], aggr=[[COUNT(UInt64(1)) AS __dask_sql_count__1]]\
\n Aggregate: groupBy=[[#a.a]], aggr=[[COUNT(#a.a) AS __dask_sql_count__1]]\
\n TableScan: a\
\n Projection: #SUM(__dask_sql_count__2) AS COUNT(a.b), #COUNT(a.b) AS COUNT(DISTINCT a.b)\
\n Aggregate: groupBy=[[]], aggr=[[SUM(#__dask_sql_count__2), COUNT(#a.b)]]\
\n Aggregate: groupBy=[[#a.b]], aggr=[[COUNT(UInt64(1)) AS __dask_sql_count__2]]\
\n Aggregate: groupBy=[[#a.b]], aggr=[[COUNT(#a.b) AS __dask_sql_count__2]]\
\n TableScan: a\
\n Projection: #SUM(__dask_sql_count__3) AS COUNT(a.c), #COUNT(a.c) AS COUNT(DISTINCT a.c)\
\n Aggregate: groupBy=[[]], aggr=[[SUM(#__dask_sql_count__3), COUNT(#a.c)]]\
\n Aggregate: groupBy=[[#a.c]], aggr=[[COUNT(UInt64(1)) AS __dask_sql_count__3]]\
\n Aggregate: groupBy=[[#a.c]], aggr=[[COUNT(#a.c) AS __dask_sql_count__3]]\
\n TableScan: a\
\n Projection: #SUM(__dask_sql_count__4) AS COUNT(a.d), #COUNT(a.d) AS COUNT(DISTINCT a.d)\
\n Aggregate: groupBy=[[]], aggr=[[SUM(#__dask_sql_count__4), COUNT(#a.d)]]\
\n Aggregate: groupBy=[[#a.d]], aggr=[[COUNT(UInt64(1)) AS __dask_sql_count__4]]\
\n Aggregate: groupBy=[[#a.d]], aggr=[[COUNT(#a.d) AS __dask_sql_count__4]]\
\n TableScan: a";
assert_optimized_plan_eq(&plan, expected);
Ok(())
Expand Down Expand Up @@ -686,21 +701,43 @@ mod tests {
\n CrossJoin:\
\n Projection: #SUM(__dask_sql_count__1) AS c_a, #COUNT(a.a) AS cd_a\
\n Aggregate: groupBy=[[]], aggr=[[SUM(#__dask_sql_count__1), COUNT(#a.a)]]\
\n Aggregate: groupBy=[[#a.a]], aggr=[[COUNT(UInt64(1)) AS __dask_sql_count__1]]\
\n Aggregate: groupBy=[[#a.a]], aggr=[[COUNT(#a.a) AS __dask_sql_count__1]]\
\n TableScan: a\
\n Projection: #SUM(__dask_sql_count__2) AS c_b, #COUNT(a.b) AS cd_b\
\n Aggregate: groupBy=[[]], aggr=[[SUM(#__dask_sql_count__2), COUNT(#a.b)]]\
\n Aggregate: groupBy=[[#a.b]], aggr=[[COUNT(UInt64(1)) AS __dask_sql_count__2]]\
\n Aggregate: groupBy=[[#a.b]], aggr=[[COUNT(#a.b) AS __dask_sql_count__2]]\
\n TableScan: a\
\n Projection: #SUM(__dask_sql_count__3) AS c_c, #COUNT(a.c) AS cd_c\
\n Aggregate: groupBy=[[]], aggr=[[SUM(#__dask_sql_count__3), COUNT(#a.c)]]\
\n Aggregate: groupBy=[[#a.c]], aggr=[[COUNT(UInt64(1)) AS __dask_sql_count__3]]\
\n Aggregate: groupBy=[[#a.c]], aggr=[[COUNT(#a.c) AS __dask_sql_count__3]]\
\n TableScan: a\
\n Projection: #SUM(__dask_sql_count__4) AS c_d, #COUNT(a.d) AS cd_d\
\n Aggregate: groupBy=[[]], aggr=[[SUM(#__dask_sql_count__4), COUNT(#a.d)]]\
\n Aggregate: groupBy=[[#a.d]], aggr=[[COUNT(UInt64(1)) AS __dask_sql_count__4]]\
\n Aggregate: groupBy=[[#a.d]], aggr=[[COUNT(#a.d) AS __dask_sql_count__4]]\
\n TableScan: a";
assert_optimized_plan_eq(&plan, expected);

let expected = "CrossJoin:\
\n CrossJoin:\
\n CrossJoin:\
\n Projection: #SUM(__dask_sql_count__1) AS c_a, #COUNT(a.a) AS cd_a\
\n Aggregate: groupBy=[[]], aggr=[[SUM(#__dask_sql_count__1), COUNT(#a.a)]]\
\n Aggregate: groupBy=[[#a.a]], aggr=[[COUNT(#a.a) AS __dask_sql_count__1]]\
\n TableScan: a projection=[a, b, c, d]\
\n Projection: #SUM(__dask_sql_count__2) AS c_b, #COUNT(a.b) AS cd_b\
\n Aggregate: groupBy=[[]], aggr=[[SUM(#__dask_sql_count__2), COUNT(#a.b)]]\
\n Aggregate: groupBy=[[#a.b]], aggr=[[COUNT(#a.b) AS __dask_sql_count__2]]\
\n TableScan: a projection=[a, b, c, d]\
\n Projection: #SUM(__dask_sql_count__3) AS c_c, #COUNT(a.c) AS cd_c\
\n Aggregate: groupBy=[[]], aggr=[[SUM(#__dask_sql_count__3), COUNT(#a.c)]]\
\n Aggregate: groupBy=[[#a.c]], aggr=[[COUNT(#a.c) AS __dask_sql_count__3]]\
\n TableScan: a projection=[a, b, c, d]\
\n Projection: #SUM(__dask_sql_count__4) AS c_d, #COUNT(a.d) AS cd_d\
\n Aggregate: groupBy=[[]], aggr=[[SUM(#__dask_sql_count__4), COUNT(#a.d)]]\
\n Aggregate: groupBy=[[#a.d]], aggr=[[COUNT(#a.d) AS __dask_sql_count__4]]\
\n TableScan: a projection=[a, b, c, d]";
assert_fully_optimized_plan_eq(&plan, expected);

Ok(())
}
}
42 changes: 41 additions & 1 deletion tests/integration/test_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,13 +276,52 @@ def test_join_multi():
)


def test_single_agg_count_no_group_by():
a = make_rand_df(
100, a=(int, 50), b=(str, 50), c=(int, 30), d=(str, 40), e=(float, 40)
)
eq_sqlite(
"""
SELECT
COUNT(a) AS c_a,
COUNT(DISTINCT a) AS cd_a
FROM a
""",
a=a,
)


def test_multi_agg_count_no_group_by():
a = make_rand_df(
100, a=(int, 50), b=(str, 50), c=(int, 30), d=(str, 40), e=(float, 40)
)
eq_sqlite(
"""
SELECT
COUNT(a) AS c_a,
COUNT(DISTINCT a) AS cd_a,
COUNT(b) AS c_b,
COUNT(DISTINCT b) AS cd_b,
COUNT(c) AS c_c,
COUNT(DISTINCT c) AS cd_c,
COUNT(d) AS c_d,
COUNT(DISTINCT d) AS cd_d,
COUNT(e) AS c_e,
COUNT(DISTINCT e) AS cd_e
FROM a
""",
a=a,
)


@pytest.mark.skip(
reason="conflicting aggregation functions: [('count', 'a'), ('count', 'a')]"
)
def test_multi_agg_count_no_group_by():
def test_multi_agg_count_no_group_by_dupe_distinct():
a = make_rand_df(
100, a=(int, 50), b=(str, 50), c=(int, 30), d=(str, 40), e=(float, 40)
)
# note that this test repeats the expression `COUNT(DISTINCT a)`
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to open an issue to follow up on the dupe case later on?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done: #787

eq_sqlite(
"""
SELECT
Expand Down Expand Up @@ -354,6 +393,7 @@ def test_agg_count():
a = make_rand_df(
100, a=(int, 50), b=(str, 50), c=(int, 30), d=(str, 40), e=(float, 40)
)
# note that this test repeats the expression `COUNT(DISTINCT a)`
eq_sqlite(
"""
SELECT
Expand Down