Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 8 additions & 12 deletions rust/datafusion/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -469,31 +469,27 @@ impl DefaultPhysicalPlanner {
input_schema: &Schema,
ctx_state: &ExecutionContextState,
) -> Result<Arc<dyn AggregateExpr>> {
// unpack aliased logical expressions, e.g. "sum(col) as total"
let (name, e) = match e {
Expr::Alias(sub_expr, alias) => (alias.clone(), sub_expr.as_ref()),
_ => (e.name(input_schema)?, e),
};

match e {
Expr::AggregateFunction { fun, args, .. } => {
let args = args
.iter()
.map(|e| self.create_physical_expr(e, input_schema, ctx_state))
.collect::<Result<Vec<_>>>()?;
aggregates::create_aggregate_expr(
fun,
&args,
input_schema,
e.name(input_schema)?,
)
aggregates::create_aggregate_expr(fun, &args, input_schema, name)
}
Expr::AggregateUDF { fun, args, .. } => {
let args = args
.iter()
.map(|e| self.create_physical_expr(e, input_schema, ctx_state))
.collect::<Result<Vec<_>>>()?;

udaf::create_aggregate_expr(
fun,
&args,
input_schema,
e.name(input_schema)?,
)
udaf::create_aggregate_expr(fun, &args, input_schema, name)
}
other => Err(ExecutionError::General(format!(
"Invalid aggregate expression '{:?}'",
Expand Down
1 change: 1 addition & 0 deletions rust/datafusion/src/sql/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,7 @@ impl<'a, S: SchemaProvider> SqlToRel<'a, S> {
fn is_aggregate_expr(e: &Expr) -> bool {
match e {
Expr::AggregateFunction { .. } | Expr::AggregateUDF { .. } => true,
Expr::Alias(expr, _) => is_aggregate_expr(expr),
_ => false,
}
}
Expand Down
12 changes: 12 additions & 0 deletions rust/datafusion/tests/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,18 @@ async fn csv_query_group_by_int_count() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn csv_query_group_with_aliased_aggregate() -> Result<()> {
let mut ctx = ExecutionContext::new();
register_aggregate_csv(&mut ctx)?;
let sql = "SELECT c1, count(c12) AS count FROM aggregate_test_100 GROUP BY c1";
let mut actual = execute(&mut ctx, sql).await;
actual.sort();
let expected = "\"a\"\t21\n\"b\"\t19\n\"c\"\t21\n\"d\"\t18\n\"e\"\t21".to_string();
assert_eq!(expected, actual.join("\n"));
Ok(())
}

#[tokio::test]
async fn csv_query_group_by_string_min_max() -> Result<()> {
let mut ctx = ExecutionContext::new();
Expand Down