diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs index 45ad6891866..e6eca199b24 100644 --- a/rust/datafusion/src/sql/planner.rs +++ b/rust/datafusion/src/sql/planner.rs @@ -40,6 +40,7 @@ use crate::{ }; use arrow::datatypes::*; +use hashbrown::HashMap; use crate::prelude::JoinType; use sqlparser::ast::{ @@ -103,7 +104,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// Generate a logic plan from an SQL query pub fn query_to_plan(&self, query: &Query) -> Result { - self.query_to_plan_with_alias(query, None) + self.query_to_plan_with_alias(query, None, &mut HashMap::new()) } /// Generate a logic plan from an SQL query with optional alias @@ -111,9 +112,23 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &self, query: &Query, alias: Option, + ctes: &mut HashMap, ) -> Result { let set_expr = &query.body; - let plan = self.set_expr_to_plan(set_expr, alias)?; + if let Some(with) = &query.with { + // Process CTEs from top to bottom + // do not allow self-references + for cte in &with.cte_tables { + // create logical plan & pass backreferencing CTEs + let logical_plan = self.query_to_plan_with_alias( + &cte.query, + Some(cte.alias.name.value.clone()), + &mut ctes.clone(), + )?; + ctes.insert(cte.alias.name.value.clone(), logical_plan); + } + } + let plan = self.set_expr_to_plan(set_expr, alias, ctes)?; let plan = self.order_by(&plan, &query.order_by)?; @@ -124,9 +139,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &self, set_expr: &SetExpr, alias: Option, + ctes: &mut HashMap, ) -> Result { match set_expr { - SetExpr::Select(s) => self.select_to_plan(s.as_ref()), + SetExpr::Select(s) => self.select_to_plan(s.as_ref(), ctes), SetExpr::SetOperation { op, left, @@ -134,8 +150,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { all, } => match (op, all) { (SetOperator::Union, true) => { - let left_plan = self.set_expr_to_plan(left.as_ref(), None)?; - let right_plan = self.set_expr_to_plan(right.as_ref(), None)?; + let left_plan = 
self.set_expr_to_plan(left.as_ref(), None, ctes)?; + let right_plan = self.set_expr_to_plan(right.as_ref(), None, ctes)?; let inputs = vec![left_plan, right_plan] .into_iter() .flat_map(|p| match p { @@ -279,24 +295,32 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - fn plan_from_tables(&self, from: &[TableWithJoins]) -> Result> { + fn plan_from_tables( + &self, + from: &[TableWithJoins], + ctes: &mut HashMap, + ) -> Result> { match from.len() { 0 => Ok(vec![LogicalPlanBuilder::empty(true).build()?]), _ => from .iter() - .map(|t| self.plan_table_with_joins(t)) + .map(|t| self.plan_table_with_joins(t, ctes)) .collect::>>(), } } - fn plan_table_with_joins(&self, t: &TableWithJoins) -> Result { - let left = self.create_relation(&t.relation)?; + fn plan_table_with_joins( + &self, + t: &TableWithJoins, + ctes: &mut HashMap, + ) -> Result { + let left = self.create_relation(&t.relation, ctes)?; match t.joins.len() { 0 => Ok(left), n => { - let mut left = self.parse_relation_join(&left, &t.joins[0])?; + let mut left = self.parse_relation_join(&left, &t.joins[0], ctes)?; for i in 1..n { - left = self.parse_relation_join(&left, &t.joins[i])?; + left = self.parse_relation_join(&left, &t.joins[i], ctes)?; } Ok(left) } @@ -307,8 +331,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &self, left: &LogicalPlan, join: &Join, + ctes: &mut HashMap, ) -> Result { - let right = self.create_relation(&join.relation)?; + let right = self.create_relation(&join.relation, ctes)?; match &join.join_operator { JoinOperator::LeftOuter(constraint) => { self.parse_join(left, &right, constraint, JoinType::Left) @@ -371,16 +396,25 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - fn create_relation(&self, relation: &TableFactor) -> Result { + fn create_relation( + &self, + relation: &TableFactor, + ctes: &mut HashMap, + ) -> Result { match relation { TableFactor::Table { name, .. 
} => { let table_name = name.to_string(); - match self.schema_provider.get_table_provider(name.try_into()?) { - Some(provider) => { + let cte = ctes.get(&table_name); + match ( + cte, + self.schema_provider.get_table_provider(name.try_into()?), + ) { + (Some(cte_plan), _) => Ok(cte_plan.clone()), + (_, Some(provider)) => { LogicalPlanBuilder::scan(&table_name, provider, None)?.build() } - None => Err(DataFusionError::Plan(format!( - "no provider found for table {}", + (_, None) => Err(DataFusionError::Plan(format!( + "Table or CTE with name '{}' not found", name ))), } @@ -390,9 +424,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } => self.query_to_plan_with_alias( subquery, alias.as_ref().map(|a| a.name.value.to_string()), + ctes, ), TableFactor::NestedJoin(table_with_joins) => { - self.plan_table_with_joins(table_with_joins) + self.plan_table_with_joins(table_with_joins, ctes) } // @todo Support TableFactory::TableFunction? _ => Err(DataFusionError::NotImplemented(format!( @@ -403,8 +438,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } /// Generate a logic plan from an SQL select - fn select_to_plan(&self, select: &Select) -> Result { - let plans = self.plan_from_tables(&select.from)?; + fn select_to_plan( + &self, + select: &Select, + ctes: &mut HashMap, + ) -> Result { + let plans = self.plan_from_tables(&select.from, ctes)?; let plan = match &select.selection { Some(predicate_expr) => { diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index ba8230a442f..66916698626 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -1943,6 +1943,73 @@ async fn query_without_from() -> Result<()> { Ok(()) } +#[tokio::test] +async fn query_cte() -> Result<()> { + // Test for common table expressions (WITH clauses). + // CTEs must resolve in FROM, UNION ALL and JOIN, and may reference earlier CTEs. 
+ let mut ctx = ExecutionContext::new(); + + // simple with + let sql = "WITH t AS (SELECT 1) SELECT * FROM t"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["1"]]; + assert_eq!(expected, actual); + + // with + union + let sql = "WITH t AS (SELECT 1 AS a), u AS (SELECT 2 AS a) SELECT * FROM t UNION ALL SELECT * FROM u"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["1"], vec!["2"]]; + assert_eq!(expected, actual); + + // with + join + let sql = "WITH t AS (SELECT 1 AS id1), u AS (SELECT 1 AS id2, 5 as x) SELECT x FROM t JOIN u ON (id1 = id2)"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["5"]]; + assert_eq!(expected, actual); + + // backward reference + let sql = "WITH t AS (SELECT 1 AS id1), u AS (SELECT * FROM t) SELECT * from u"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["1"]]; + assert_eq!(expected, actual); + + Ok(()) +} + +#[tokio::test] +async fn query_cte_incorrect() -> Result<()> { + let ctx = ExecutionContext::new(); + + // self reference + let sql = "WITH t AS (SELECT * FROM t) SELECT * from u"; + let plan = ctx.create_logical_plan(&sql); + assert!(plan.is_err()); + assert_eq!( + format!("{}", plan.unwrap_err()), + "Error during planning: Table or CTE with name \'t\' not found" + ); + + // forward referencing + let sql = "WITH t AS (SELECT * FROM u), u AS (SELECT 1) SELECT * from u"; + let plan = ctx.create_logical_plan(&sql); + assert!(plan.is_err()); + assert_eq!( + format!("{}", plan.unwrap_err()), + "Error during planning: Table or CTE with name \'u\' not found" + ); + + // wrapping should hide u + let sql = "WITH t AS (WITH u as (SELECT 1) SELECT 1) SELECT * from u"; + let plan = ctx.create_logical_plan(&sql); + assert!(plan.is_err()); + assert_eq!( + format!("{}", plan.unwrap_err()), + "Error during planning: Table or CTE with name \'u\' not found" + ); + + Ok(()) +} + #[tokio::test] async fn query_scalar_minus_array() -> Result<()> { let 
schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, true)]));