Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 86 additions & 3 deletions rust/datafusion/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
//! physical query plans and executed.

use fmt::Debug;
use std::{any::Any, collections::HashSet, fmt, sync::Arc};
use std::{any::Any, collections::HashMap, collections::HashSet, fmt, sync::Arc};

use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
Expand Down Expand Up @@ -1129,7 +1129,12 @@ impl LogicalPlanBuilder {
}))
}

/// Apply a projection
/// Apply a projection.
///
/// # Errors
/// This function errors under any of the following conditions:
/// * Two or more expressions have the same name
/// * An invalid expression is used (e.g. a `sort` expression)
pub fn project(&self, expr: Vec<Expr>) -> Result<Self> {
let input_schema = self.plan.schema();
let mut projected_expr = vec![];
Expand All @@ -1141,6 +1146,8 @@ impl LogicalPlanBuilder {
_ => projected_expr.push(expr[i].clone()),
});

validate_unique_names("Projections", &projected_expr, input_schema)?;

let schema = Schema::new(exprlist_to_fields(&projected_expr, input_schema)?);

Ok(Self::from(&LogicalPlan::Projection {
Expand Down Expand Up @@ -1179,6 +1186,8 @@ impl LogicalPlanBuilder {
let mut all_expr: Vec<Expr> = group_expr.clone();
aggr_expr.iter().for_each(|x| all_expr.push(x.clone()));

validate_unique_names("Aggregations", &all_expr, self.plan.schema())?;

let aggr_schema = Schema::new(exprlist_to_fields(&all_expr, self.plan.schema())?);

Ok(Self::from(&LogicalPlan::Aggregate {
Expand Down Expand Up @@ -1212,6 +1221,33 @@ impl LogicalPlanBuilder {
}
}

/// Errors if one or more expressions have equal names.
fn validate_unique_names(
node_name: &str,
expressions: &[Expr],
input_schema: &Schema,
) -> Result<()> {
let mut unique_names = HashMap::new();
expressions.iter().enumerate().map(|(position, expr)| {
let name = expr.name(input_schema)?;
match unique_names.get(&name) {
None => {
unique_names.insert(name, (position, expr));
Ok(())
},
Some((existing_position, existing_expr)) => {
Err(ExecutionError::General(
format!("{} require unique expression names \
but the expression \"{:?}\" at position {} and \"{:?}\" \
at position {} have the same name. Consider aliasing (\"AS\") one of them.",
node_name, existing_expr, existing_position, expr, position,
)
))
}
}
}).collect::<Result<()>>()
}

/// Represents which type of plan
#[derive(Debug, Clone, PartialEq)]
pub enum PlanType {
Expand Down Expand Up @@ -1333,7 +1369,6 @@ mod tests {
Ok(())
}

#[test]
#[test]
fn plan_builder_sort() -> Result<()> {
let plan = LogicalPlanBuilder::scan(
Expand Down Expand Up @@ -1364,6 +1399,54 @@ mod tests {
Ok(())
}

#[test]
fn projection_non_unique_names() -> Result<()> {
let plan = LogicalPlanBuilder::scan(
"default",
"employee.csv",
&employee_schema(),
Some(vec![0, 3]),
)?
// two columns with the same name => error
.project(vec![col("id"), col("first_name").alias("id")]);

match plan {
Err(ExecutionError::General(e)) => {
assert_eq!(e, "Projections require unique expression names \
but the expression \"#id\" at position 0 and \"#first_name AS id\" at \
position 1 have the same name. Consider aliasing (\"AS\") one of them.");
Ok(())
}
_ => Err(ExecutionError::General(
"Plan should have returned an ExecutionError::General".to_string(),
)),
}
}

#[test]
fn aggregate_non_unique_names() -> Result<()> {
let plan = LogicalPlanBuilder::scan(
"default",
"employee.csv",
&employee_schema(),
Some(vec![0, 3]),
)?
// two columns with the same name => error
.aggregate(vec![col("state")], vec![sum(col("salary")).alias("state")]);

match plan {
Err(ExecutionError::General(e)) => {
assert_eq!(e, "Aggregations require unique expression names \
but the expression \"#state\" at position 0 and \"SUM(#salary) AS state\" at \
position 1 have the same name. Consider aliasing (\"AS\") one of them.");
Ok(())
}
_ => Err(ExecutionError::General(
"Plan should have returned an ExecutionError::General".to_string(),
)),
}
}

fn employee_schema() -> Schema {
Schema::new(vec![
Field::new("id", DataType::Int32, false),
Expand Down
2 changes: 1 addition & 1 deletion rust/datafusion/tests/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ async fn parquet_single_nan_schema() {
async fn csv_count_star() -> Result<()> {
let mut ctx = ExecutionContext::new();
register_aggregate_csv(&mut ctx)?;
let sql = "SELECT COUNT(*), COUNT(1), COUNT(c1) FROM aggregate_test_100";
let sql = "SELECT COUNT(*), COUNT(1) AS c, COUNT(c1) FROM aggregate_test_100";
let actual = execute(&mut ctx, sql).await.join("\n");
let expected = "100\t100\t100".to_string();
assert_eq!(expected, actual);
Expand Down