Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 59 additions & 20 deletions rust/datafusion/src/sql/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ use crate::{
};

use arrow::datatypes::*;
use hashbrown::HashMap;

use crate::prelude::JoinType;
use sqlparser::ast::{
Expand Down Expand Up @@ -103,17 +104,31 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {

/// Generate a logic plan from an SQL query
pub fn query_to_plan(&self, query: &Query) -> Result<LogicalPlan> {
self.query_to_plan_with_alias(query, None)
self.query_to_plan_with_alias(query, None, &mut HashMap::new())
}

/// Generate a logic plan from an SQL query with optional alias
pub fn query_to_plan_with_alias(
&self,
query: &Query,
alias: Option<String>,
ctes: &mut HashMap<String, LogicalPlan>,
) -> Result<LogicalPlan> {
let set_expr = &query.body;
let plan = self.set_expr_to_plan(set_expr, alias)?;
if let Some(with) = &query.with {
// Process CTEs from top to bottom
// do not allow self-references
for cte in &with.cte_tables {
// create logical plan & pass backreferencing CTEs
let logical_plan = self.query_to_plan_with_alias(
&cte.query,
Some(cte.alias.name.value.clone()),
&mut ctes.clone(),
Copy link
Contributor Author

@Dandandan Dandandan Mar 24, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This clone is now relatively inefficient for very long and/or deep ctes.

Solutions I see for this problem.

  • use an immutable HashMap (O(1) clone, easier to program), example: https://docs.rs/im/15.0.0/im/struct.HashMap.html
  • use something like "frames", e.g. Vec<HashMap<String, LogicalPlan>> -> when looking up a reference first look up level 0, level-1, level -2, etc.
  • cleanup the variables before returning (removing / replacing the added references)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the frames idea makes the most sense to me -- namely a stack

We can handle it in the future if/when it shows itself to be a problem

)?;
ctes.insert(cte.alias.name.value.clone(), logical_plan);
}
}
let plan = self.set_expr_to_plan(set_expr, alias, ctes)?;

let plan = self.order_by(&plan, &query.order_by)?;

Expand All @@ -124,18 +139,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
&self,
set_expr: &SetExpr,
alias: Option<String>,
ctes: &mut HashMap<String, LogicalPlan>,
) -> Result<LogicalPlan> {
match set_expr {
SetExpr::Select(s) => self.select_to_plan(s.as_ref()),
SetExpr::Select(s) => self.select_to_plan(s.as_ref(), ctes),
SetExpr::SetOperation {
op,
left,
right,
all,
} => match (op, all) {
(SetOperator::Union, true) => {
let left_plan = self.set_expr_to_plan(left.as_ref(), None)?;
let right_plan = self.set_expr_to_plan(right.as_ref(), None)?;
let left_plan = self.set_expr_to_plan(left.as_ref(), None, ctes)?;
let right_plan = self.set_expr_to_plan(right.as_ref(), None, ctes)?;
let inputs = vec![left_plan, right_plan]
.into_iter()
.flat_map(|p| match p {
Expand Down Expand Up @@ -279,24 +295,32 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
}

fn plan_from_tables(&self, from: &[TableWithJoins]) -> Result<Vec<LogicalPlan>> {
fn plan_from_tables(
&self,
from: &[TableWithJoins],
ctes: &mut HashMap<String, LogicalPlan>,
) -> Result<Vec<LogicalPlan>> {
match from.len() {
0 => Ok(vec![LogicalPlanBuilder::empty(true).build()?]),
_ => from
.iter()
.map(|t| self.plan_table_with_joins(t))
.map(|t| self.plan_table_with_joins(t, ctes))
.collect::<Result<Vec<_>>>(),
}
}

fn plan_table_with_joins(&self, t: &TableWithJoins) -> Result<LogicalPlan> {
let left = self.create_relation(&t.relation)?;
fn plan_table_with_joins(
&self,
t: &TableWithJoins,
ctes: &mut HashMap<String, LogicalPlan>,
) -> Result<LogicalPlan> {
let left = self.create_relation(&t.relation, ctes)?;
match t.joins.len() {
0 => Ok(left),
n => {
let mut left = self.parse_relation_join(&left, &t.joins[0])?;
let mut left = self.parse_relation_join(&left, &t.joins[0], ctes)?;
for i in 1..n {
left = self.parse_relation_join(&left, &t.joins[i])?;
left = self.parse_relation_join(&left, &t.joins[i], ctes)?;
}
Ok(left)
}
Expand All @@ -307,8 +331,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
&self,
left: &LogicalPlan,
join: &Join,
ctes: &mut HashMap<String, LogicalPlan>,
) -> Result<LogicalPlan> {
let right = self.create_relation(&join.relation)?;
let right = self.create_relation(&join.relation, ctes)?;
match &join.join_operator {
JoinOperator::LeftOuter(constraint) => {
self.parse_join(left, &right, constraint, JoinType::Left)
Expand Down Expand Up @@ -371,16 +396,25 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
}

fn create_relation(&self, relation: &TableFactor) -> Result<LogicalPlan> {
fn create_relation(
&self,
relation: &TableFactor,
ctes: &mut HashMap<String, LogicalPlan>,
) -> Result<LogicalPlan> {
match relation {
TableFactor::Table { name, .. } => {
let table_name = name.to_string();
match self.schema_provider.get_table_provider(name.try_into()?) {
Some(provider) => {
let cte = ctes.get(&table_name);
match (
cte,
self.schema_provider.get_table_provider(name.try_into()?),
) {
(Some(cte_plan), _) => Ok(cte_plan.clone()),
(_, Some(provider)) => {
LogicalPlanBuilder::scan(&table_name, provider, None)?.build()
}
None => Err(DataFusionError::Plan(format!(
"no provider found for table {}",
(_, None) => Err(DataFusionError::Plan(format!(
"Table or CTE with name '{}' not found",
name
))),
}
Expand All @@ -390,9 +424,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
} => self.query_to_plan_with_alias(
subquery,
alias.as_ref().map(|a| a.name.value.to_string()),
ctes,
),
TableFactor::NestedJoin(table_with_joins) => {
self.plan_table_with_joins(table_with_joins)
self.plan_table_with_joins(table_with_joins, ctes)
}
// @todo Support TableFactory::TableFunction?
_ => Err(DataFusionError::NotImplemented(format!(
Expand All @@ -403,8 +438,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}

/// Generate a logic plan from an SQL select
fn select_to_plan(&self, select: &Select) -> Result<LogicalPlan> {
let plans = self.plan_from_tables(&select.from)?;
fn select_to_plan(
&self,
select: &Select,
ctes: &mut HashMap<String, LogicalPlan>,
) -> Result<LogicalPlan> {
let plans = self.plan_from_tables(&select.from, ctes)?;

let plan = match &select.selection {
Some(predicate_expr) => {
Expand Down
67 changes: 67 additions & 0 deletions rust/datafusion/tests/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1943,6 +1943,73 @@ async fn query_without_from() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn query_cte() -> Result<()> {
// Test for SELECT <expression> without FROM.
// Should evaluate expressions in project position.
let mut ctx = ExecutionContext::new();

// simple with
let sql = "WITH t AS (SELECT 1) SELECT * FROM t";
let actual = execute(&mut ctx, sql).await;
let expected = vec![vec!["1"]];
assert_eq!(expected, actual);

// with + union
let sql = "WITH t AS (SELECT 1 AS a), u AS (SELECT 2 AS a) SELECT * FROM t UNION ALL SELECT * FROM u";
let actual = execute(&mut ctx, sql).await;
let expected = vec![vec!["1"], vec!["2"]];
assert_eq!(expected, actual);

// with + join
let sql = "WITH t AS (SELECT 1 AS id1), u AS (SELECT 1 AS id2, 5 as x) SELECT x FROM t JOIN u ON (id1 = id2)";
let actual = execute(&mut ctx, sql).await;
let expected = vec![vec!["5"]];
assert_eq!(expected, actual);

// backward reference
let sql = "WITH t AS (SELECT 1 AS id1), u AS (SELECT * FROM t) SELECT * from u";
let actual = execute(&mut ctx, sql).await;
let expected = vec![vec!["1"]];
assert_eq!(expected, actual);

Ok(())
}

#[tokio::test]
async fn query_cte_incorrect() -> Result<()> {
let ctx = ExecutionContext::new();

// self reference
let sql = "WITH t AS (SELECT * FROM t) SELECT * from u";
let plan = ctx.create_logical_plan(&sql);
assert!(plan.is_err());
assert_eq!(
format!("{}", plan.unwrap_err()),
"Error during planning: Table or CTE with name \'t\' not found"
);

// forward referencing
let sql = "WITH t AS (SELECT * FROM u), u AS (SELECT 1) SELECT * from u";
let plan = ctx.create_logical_plan(&sql);
assert!(plan.is_err());
assert_eq!(
format!("{}", plan.unwrap_err()),
"Error during planning: Table or CTE with name \'u\' not found"
);

// wrapping should hide u
let sql = "WITH t AS (WITH u as (SELECT 1) SELECT 1) SELECT * from u";
let plan = ctx.create_logical_plan(&sql);
assert!(plan.is_err());
assert_eq!(
format!("{}", plan.unwrap_err()),
"Error during planning: Table or CTE with name \'u\' not found"
);

Ok(())
}

#[tokio::test]
async fn query_scalar_minus_array() -> Result<()> {
let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, true)]));
Expand Down