diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index 8d21e381e25..0550741a9f9 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -22,7 +22,7 @@ //! physical query plans and executed. use fmt::Debug; -use std::{any::Any, collections::HashSet, fmt, sync::Arc}; +use std::{any::Any, collections::HashMap, collections::HashSet, fmt, sync::Arc}; use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; @@ -1129,7 +1129,12 @@ impl LogicalPlanBuilder { })) } - /// Apply a projection + /// Apply a projection. + /// + /// # Errors + /// This function errors under any of the following conditions: + /// * Two or more expressions have the same name + /// * An invalid expression is used (e.g. a `sort` expression) pub fn project(&self, expr: Vec) -> Result { let input_schema = self.plan.schema(); let mut projected_expr = vec![]; @@ -1141,6 +1146,8 @@ impl LogicalPlanBuilder { _ => projected_expr.push(expr[i].clone()), }); + validate_unique_names("Projections", &projected_expr, input_schema)?; + let schema = Schema::new(exprlist_to_fields(&projected_expr, input_schema)?); Ok(Self::from(&LogicalPlan::Projection { @@ -1179,6 +1186,8 @@ impl LogicalPlanBuilder { let mut all_expr: Vec = group_expr.clone(); aggr_expr.iter().for_each(|x| all_expr.push(x.clone())); + validate_unique_names("Aggregations", &all_expr, self.plan.schema())?; + let aggr_schema = Schema::new(exprlist_to_fields(&all_expr, self.plan.schema())?); Ok(Self::from(&LogicalPlan::Aggregate { @@ -1212,6 +1221,33 @@ impl LogicalPlanBuilder { } } +/// Errors if one or more expressions have equal names. +fn validate_unique_names( + node_name: &str, + expressions: &[Expr], + input_schema: &Schema, +) -> Result<()> { + let mut unique_names = HashMap::new(); + expressions.iter().enumerate().map(|(position, expr)| { + let name = expr.name(input_schema)?; + match unique_names.get(&name) { + None => { + unique_names.insert(name, (position, expr)); + Ok(()) + }, + Some((existing_position, existing_expr)) => { + Err(ExecutionError::General( + format!("{} require unique expression names \ + but the expression \"{:?}\" at position {} and \"{:?}\" \ + at position {} have the same name. Consider aliasing (\"AS\") one of them.", + node_name, existing_expr, existing_position, expr, position, + ) + )) + } + } + }).collect::>() +} + /// Represents which type of plan #[derive(Debug, Clone, PartialEq)] pub enum PlanType { @@ -1333,7 +1369,6 @@ mod tests { Ok(()) } - #[test] #[test] fn plan_builder_sort() -> Result<()> { let plan = LogicalPlanBuilder::scan( @@ -1364,6 +1399,54 @@ mod tests { Ok(()) } + #[test] + fn projection_non_unique_names() -> Result<()> { + let plan = LogicalPlanBuilder::scan( + "default", + "employee.csv", + &employee_schema(), + Some(vec![0, 3]), + )? + // two columns with the same name => error + .project(vec![col("id"), col("first_name").alias("id")]); + + match plan { + Err(ExecutionError::General(e)) => { + assert_eq!(e, "Projections require unique expression names \ + but the expression \"#id\" at position 0 and \"#first_name AS id\" at \ + position 1 have the same name. Consider aliasing (\"AS\") one of them."); + Ok(()) + } + _ => Err(ExecutionError::General( + "Plan should have returned an ExecutionError::General".to_string(), + )), + } + } + + #[test] + fn aggregate_non_unique_names() -> Result<()> { + let plan = LogicalPlanBuilder::scan( + "default", + "employee.csv", + &employee_schema(), + Some(vec![0, 3]), + )? + // two columns with the same name => error + .aggregate(vec![col("state")], vec![sum(col("salary")).alias("state")]); + + match plan { + Err(ExecutionError::General(e)) => { + assert_eq!(e, "Aggregations require unique expression names \ + but the expression \"#state\" at position 0 and \"SUM(#salary) AS state\" at \ + position 1 have the same name. Consider aliasing (\"AS\") one of them."); + Ok(()) + } + _ => Err(ExecutionError::General( + "Plan should have returned an ExecutionError::General".to_string(), + )), + } + } + fn employee_schema() -> Schema { Schema::new(vec![ Field::new("id", DataType::Int32, false), diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 6c5e1e8da8d..841ee353f98 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -121,7 +121,7 @@ async fn parquet_single_nan_schema() { async fn csv_count_star() -> Result<()> { let mut ctx = ExecutionContext::new(); register_aggregate_csv(&mut ctx)?; - let sql = "SELECT COUNT(*), COUNT(1), COUNT(c1) FROM aggregate_test_100"; + let sql = "SELECT COUNT(*), COUNT(1) AS c, COUNT(c1) FROM aggregate_test_100"; let actual = execute(&mut ctx, sql).await.join("\n"); let expected = "100\t100\t100".to_string(); assert_eq!(expected, actual);