diff --git a/dask_planner/src/expression.rs b/dask_planner/src/expression.rs index f1732d46f..9c85b5324 100644 --- a/dask_planner/src/expression.rs +++ b/dask_planner/src/expression.rs @@ -528,7 +528,21 @@ impl PyExpr { #[pyo3(name = "getFilterExpr")] pub fn get_filter_expr(&self) -> PyResult> { + // TODO refactor to avoid duplication match &self.expr { + Expr::Alias(expr, _) => match expr.as_ref() { + Expr::AggregateFunction { filter, .. } | Expr::AggregateUDF { filter, .. } => { + match filter { + Some(filter) => { + Ok(Some(PyExpr::from(*filter.clone(), self.input_plan.clone()))) + } + None => Ok(None), + } + } + _ => Err(py_type_err( + "getFilterExpr() - Non-aggregate expression encountered", + )), + }, Expr::AggregateFunction { filter, .. } | Expr::AggregateUDF { filter, .. } => { match filter { Some(filter) => { diff --git a/dask_planner/src/sql/optimizer.rs b/dask_planner/src/sql/optimizer.rs index 85a5865ee..ce86e0390 100644 --- a/dask_planner/src/sql/optimizer.rs +++ b/dask_planner/src/sql/optimizer.rs @@ -11,6 +11,7 @@ use datafusion_optimizer::{ projection_push_down::ProjectionPushDown, subquery_filter_to_join::SubqueryFilterToJoin, OptimizerConfig, }; +use log::trace; mod eliminate_agg_distinct; use eliminate_agg_distinct::EliminateAggDistinct; @@ -55,7 +56,14 @@ impl DaskSqlOptimizer { let mut resulting_plan: LogicalPlan = plan; for optimization in &self.optimizations { match optimization.optimize(&resulting_plan, &mut OptimizerConfig::new()) { - Ok(optimized_plan) => resulting_plan = optimized_plan, + Ok(optimized_plan) => { + trace!( + "== AFTER APPLYING RULE {} ==\n{}", + optimization.name(), + optimized_plan.display_indent() + ); + resulting_plan = optimized_plan + } Err(e) => { println!( "Skipping optimizer rule {} due to unexpected error: {}", diff --git a/dask_planner/src/sql/optimizer/eliminate_agg_distinct.rs b/dask_planner/src/sql/optimizer/eliminate_agg_distinct.rs index 14af5d193..803ad67fc 100644 --- a/dask_planner/src/sql/optimizer/eliminate_agg_distinct.rs +++ b/dask_planner/src/sql/optimizer/eliminate_agg_distinct.rs @@ -63,7 +63,7 @@ //! Aggregate: groupBy=[[#a.d]], aggr=[[COUNT(UInt64(1)) AS __dask_sql_count__4]]\ //! TableScan: a -use datafusion_common::{DFSchema, Result, ScalarValue}; +use datafusion_common::{Column, DFSchema, Result}; use datafusion_expr::logical_plan::Projection; use datafusion_expr::utils::exprlist_to_fields; use datafusion_expr::{ @@ -136,7 +136,7 @@ impl OptimizerRule for EliminateAggDistinct { .collect::>>()?; for plan in &plans { - trace!("FINAL PLAN:\n{}", plan.display_indent()); + trace!("{}", plan.display_indent()); } match plans.len() { @@ -151,7 +151,9 @@ impl OptimizerRule for EliminateAggDistinct { for plan in plans.iter().skip(1) { builder = builder.cross_join(plan)?; } - builder.build() + let join_plan = builder.build()?; + trace!("{}", join_plan.display_indent_schema()); + Ok(join_plan) } } } @@ -214,7 +216,9 @@ fn create_plan( let mut group_expr = group_expr.clone(); group_expr.push(expr.clone()); let alias = format!("__dask_sql_count__{}", optimizer_config.next_id()); - let aggr_expr = vec![count(Expr::Literal(ScalarValue::UInt64(Some(1)))).alias(&alias)]; + let expr_name = expr.name()?; + let count_expr = Expr::Column(Column::from_qualified_name(&expr_name)); + let aggr_expr = vec![count(count_expr).alias(&alias)]; let mut schema_expr = group_expr.clone(); schema_expr.extend_from_slice(&aggr_expr); let schema = DFSchema::new_with_metadata( @@ -461,6 +465,7 @@ fn unique_set_without_aliases(unique_expressions: &HashSet) -> HashSet