diff --git a/rust/datafusion/src/physical_plan/planner.rs b/rust/datafusion/src/physical_plan/planner.rs index 3335f938544..b9403824289 100644 --- a/rust/datafusion/src/physical_plan/planner.rs +++ b/rust/datafusion/src/physical_plan/planner.rs @@ -469,18 +469,19 @@ impl DefaultPhysicalPlanner { input_schema: &Schema, ctx_state: &ExecutionContextState, ) -> Result> { + // unpack aliased logical expressions, e.g. "sum(col) as total" + let (name, e) = match e { + Expr::Alias(sub_expr, alias) => (alias.clone(), sub_expr.as_ref()), + _ => (e.name(input_schema)?, e), + }; + match e { Expr::AggregateFunction { fun, args, .. } => { let args = args .iter() .map(|e| self.create_physical_expr(e, input_schema, ctx_state)) .collect::>>()?; - aggregates::create_aggregate_expr( - fun, - &args, - input_schema, - e.name(input_schema)?, - ) + aggregates::create_aggregate_expr(fun, &args, input_schema, name) } Expr::AggregateUDF { fun, args, .. } => { let args = args @@ -488,12 +489,7 @@ impl DefaultPhysicalPlanner { .map(|e| self.create_physical_expr(e, input_schema, ctx_state)) .collect::>>()?; - udaf::create_aggregate_expr( - fun, - &args, - input_schema, - e.name(input_schema)?, - ) + udaf::create_aggregate_expr(fun, &args, input_schema, name) } other => Err(ExecutionError::General(format!( "Invalid aggregate expression '{:?}'", diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs index 08d63513e72..b2ed920a8be 100644 --- a/rust/datafusion/src/sql/planner.rs +++ b/rust/datafusion/src/sql/planner.rs @@ -591,6 +591,7 @@ impl<'a, S: SchemaProvider> SqlToRel<'a, S> { fn is_aggregate_expr(e: &Expr) -> bool { match e { Expr::AggregateFunction { .. } | Expr::AggregateUDF { .. } => true, + Expr::Alias(expr, _) => is_aggregate_expr(expr), _ => false, } } diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index b4e07b644e8..5e1c0725af8 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -400,6 +400,18 @@ async fn csv_query_group_by_int_count() -> Result<()> { Ok(()) } +#[tokio::test] +async fn csv_query_group_with_aliased_aggregate() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx)?; + let sql = "SELECT c1, count(c12) AS count FROM aggregate_test_100 GROUP BY c1"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = "\"a\"\t21\n\"b\"\t19\n\"c\"\t21\n\"d\"\t18\n\"e\"\t21".to_string(); + assert_eq!(expected, actual.join("\n")); + Ok(()) +} + #[tokio::test] async fn csv_query_group_by_string_min_max() -> Result<()> { let mut ctx = ExecutionContext::new();