Skip to content

Commit

Permalink
Planner: support LATERAL subqueries
Browse files Browse the repository at this point in the history
  • Loading branch information
aalexandrov committed Jul 14, 2024
1 parent a7041fe commit ee34089
Show file tree
Hide file tree
Showing 5 changed files with 230 additions and 14 deletions.
27 changes: 27 additions & 0 deletions datafusion/sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ pub struct PlannerContext {
ctes: HashMap<String, Arc<LogicalPlan>>,
/// The query schema of the outer query plan, used to resolve the columns in subquery
outer_query_schema: Option<DFSchemaRef>,
/// The joined schemas of all FROM clauses planned so far. When planning LATERAL
/// FROM clauses, this should become a suffix of the `outer_query_schema`.
outer_from_schema: Option<DFSchemaRef>,
}

impl Default for PlannerContext {
Expand All @@ -124,6 +127,7 @@ impl PlannerContext {
prepare_param_data_types: Arc::new(vec![]),
ctes: HashMap::new(),
outer_query_schema: None,
outer_from_schema: None,
}
}

Expand Down Expand Up @@ -151,6 +155,29 @@ impl PlannerContext {
schema
}

/// Returns a cloned handle to the outer FROM schema, or `None` when no
/// outer FROM schema is currently recorded in this context.
pub fn outer_from_schema(&self) -> Option<Arc<DFSchema>> {
    self.outer_from_schema.as_ref().map(Arc::clone)
}

/// Installs `schema` as the outer FROM schema and hands back the value that
/// was stored before the call (`None` if nothing was set).
pub fn set_outer_from_schema(
    &mut self,
    schema: Option<DFSchemaRef>,
) -> Option<DFSchemaRef> {
    std::mem::replace(&mut self.outer_from_schema, schema)
}

/// extends the FROM schema, returning the existing one, if any
pub fn extend_outer_from_schema(&mut self, schema: &DFSchemaRef) -> Result<()> {
self.outer_from_schema = match self.outer_from_schema.as_ref() {
Some(from_schema) => Some(Arc::new(from_schema.join(schema)?)),
None => Some(Arc::clone(schema)),
};
Ok(())
}

/// Return the types of parameters (`$1`, `$2`, etc) if known
pub fn prepare_param_data_types(&self) -> &[DataType] {
&self.prepare_param_data_types
Expand Down
46 changes: 42 additions & 4 deletions datafusion/sql/src/relation/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
use datafusion_common::{not_impl_err, Column, Result};
use datafusion_expr::{JoinType, LogicalPlan, LogicalPlanBuilder};
use sqlparser::ast::{Join, JoinConstraint, JoinOperator, TableWithJoins};
use sqlparser::ast::{Join, JoinConstraint, JoinOperator, TableFactor, TableWithJoins};
use std::collections::HashSet;

impl<'a, S: ContextProvider> SqlToRel<'a, S> {
Expand All @@ -27,10 +27,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
t: TableWithJoins,
planner_context: &mut PlannerContext,
) -> Result<LogicalPlan> {
let mut left = self.create_relation(t.relation, planner_context)?;
for join in t.joins.into_iter() {
let mut left = if is_lateral(&t.relation) {
self.create_relation_subquery(t.relation, planner_context)?
} else {
self.create_relation(t.relation, planner_context)?
};
let old_outer_from_schema = planner_context.outer_from_schema();
for join in t.joins {
planner_context.extend_outer_from_schema(left.schema())?;
left = self.parse_relation_join(left, join, planner_context)?;
}
planner_context.set_outer_from_schema(old_outer_from_schema);
Ok(left)
}

Expand All @@ -40,7 +47,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
join: Join,
planner_context: &mut PlannerContext,
) -> Result<LogicalPlan> {
let right = self.create_relation(join.relation, planner_context)?;
let right = if is_lateral_join(&join)? {
self.create_relation_subquery(join.relation, planner_context)?
} else {
self.create_relation(join.relation, planner_context)?
};
match join.join_operator {
JoinOperator::LeftOuter(constraint) => {
self.parse_join(left, right, constraint, JoinType::Left, planner_context)
Expand Down Expand Up @@ -144,3 +155,30 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
}
}

/// Return `true` iff the given [`TableFactor`] is lateral.
///
/// Only derived tables and table functions carry a `lateral` flag; every
/// other kind of table factor is reported as non-lateral.
pub(crate) fn is_lateral(factor: &TableFactor) -> bool {
    matches!(
        factor,
        TableFactor::Derived { lateral: true, .. }
            | TableFactor::Function { lateral: true, .. }
    )
}

/// Return `true` iff the given [`Join`] is lateral.
///
/// A join counts as lateral when its right-hand relation is flagged
/// `LATERAL` (see [`is_lateral`]) or when it uses the APPLY syntax
/// (`CROSS APPLY` / `OUTER APPLY`).
///
/// # Errors
///
/// Returns a "not implemented" error when a `LATERAL` relation is combined
/// with a FULL OUTER, RIGHT OUTER, RIGHT ANTI or RIGHT SEMI join.
pub(crate) fn is_lateral_join(join: &Join) -> Result<bool> {
    let is_lateral_syntax = is_lateral(&join.relation);
    let is_apply_syntax = match join.join_operator {
        JoinOperator::FullOuter(..)
        | JoinOperator::RightOuter(..)
        | JoinOperator::RightAnti(..)
        | JoinOperator::RightSemi(..)
            if is_lateral_syntax =>
        {
            // The message names the construct that is actually rejected.
            return not_impl_err!(
                "LATERAL syntax is not supported for FULL OUTER and RIGHT [OUTER | ANTI | SEMI] joins"
            );
        }
        JoinOperator::CrossApply | JoinOperator::OuterApply => true,
        _ => false,
    };
    Ok(is_lateral_syntax || is_apply_syntax)
}
45 changes: 45 additions & 0 deletions datafusion/sql/src/relation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
use datafusion_common::{not_impl_err, plan_err, DFSchema, Result, TableReference};
use datafusion_expr::builder::subquery_alias;
use datafusion_expr::{expr::Unnest, Expr, LogicalPlan, LogicalPlanBuilder};
use datafusion_expr::{Subquery, SubqueryAlias};
use sqlparser::ast::{FunctionArg, FunctionArgExpr, TableFactor};

mod join;
Expand Down Expand Up @@ -143,4 +147,45 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
Ok(plan)
}
}

pub(crate) fn create_relation_subquery(
&self,
subquery: TableFactor,
planner_context: &mut PlannerContext,
) -> Result<LogicalPlan> {
// At this point for a syntacitally valid query the outer_from_schema is
// guaranteed to be set, so the `.unwrap()` call will never panic. This
// is the case because we only call this method for lateral table
// factors, and those can never be the first factor in a FROM list. This
// means we arrived here through the `for` loop in `plan_from_tables` or
// the `for` loop in `plan_table_with_joins`.
let old_from_schema = planner_context.set_outer_from_schema(None).unwrap();
let new_query_schema = match planner_context.outer_query_schema() {
Some(lhs) => Some(Arc::new(lhs.join(&old_from_schema)?)),
None => Some(Arc::clone(&old_from_schema)),
};
let old_query_schema = planner_context.set_outer_query_schema(new_query_schema);

let plan = self.create_relation(subquery, planner_context)?;
let outer_ref_columns = plan.all_out_ref_exprs();

planner_context.set_outer_query_schema(old_query_schema);
planner_context.set_outer_from_schema(Some(old_from_schema));

match plan {
LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. }) => {
subquery_alias(
LogicalPlan::Subquery(Subquery {
subquery: input,
outer_ref_columns,
}),
alias,
)
}
plan => Ok(LogicalPlan::Subquery(Subquery {
subquery: Arc::new(plan),
outer_ref_columns,
})),
}
}
}
31 changes: 21 additions & 10 deletions datafusion/sql/src/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -396,19 +396,30 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
match from.len() {
0 => Ok(LogicalPlanBuilder::empty(true).build()?),
1 => {
let from = from.remove(0);
self.plan_table_with_joins(from, planner_context)
let input = from.remove(0);
self.plan_table_with_joins(input, planner_context)
}
_ => {
let mut plans = from
.into_iter()
.map(|t| self.plan_table_with_joins(t, planner_context));

let mut left = LogicalPlanBuilder::from(plans.next().unwrap()?);

for right in plans {
left = left.cross_join(right?)?;
let mut from = from.into_iter();

let mut left = LogicalPlanBuilder::from({
let input = from.next().unwrap();
self.plan_table_with_joins(input, planner_context)?
});
let old_outer_from_schema = {
let left_schema = Some(Arc::clone(left.schema()));
planner_context.set_outer_from_schema(left_schema)
};
for input in from {
// Join `input` with the current result (`left`).
let right = self.plan_table_with_joins(input, planner_context)?;
left = left.cross_join(right)?;
// Update the outer FROM schema.
let left_schema = Some(Arc::clone(left.schema()));
planner_context.set_outer_from_schema(left_schema);
}
planner_context.set_outer_from_schema(old_outer_from_schema);

Ok(left.build()?)
}
}
Expand Down
95 changes: 95 additions & 0 deletions datafusion/sql/tests/sql_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3177,6 +3177,101 @@ fn join_on_complex_condition() {
quick_test(sql, expected);
}

/// A `LATERAL` derived table that follows `j1` in a comma-separated FROM
/// list may reference `j1_id`. The expected plan is a `CrossJoin` whose
/// right side is an aliased `Subquery` filtering on `outer_ref(j1.j1_id)`.
#[test]
fn lateral_comma_join() {
let sql = "SELECT j1_string, j2_string FROM
j1, \
LATERAL (SELECT * FROM j2 WHERE j1_id < j2_id) AS j2";
let expected = "Projection: j1.j1_string, j2.j2_string\
\n CrossJoin:\
\n TableScan: j1\
\n SubqueryAlias: j2\
\n Subquery:\
\n Projection: j2.j2_id, j2.j2_string\
\n Filter: outer_ref(j1.j1_id) < j2.j2_id\
\n TableScan: j2";
quick_test(sql, expected);
}

/// The `LATERAL` subquery references `j2_string`, a column that comes from
/// the right-hand side of a nested join in the preceding FROM item. The
/// expected plan resolves it as `outer_ref(j2.j2_string)`.
#[test]
fn lateral_comma_join_referencing_join_rhs() {
let sql = "SELECT * FROM\
\n j1 JOIN (j2 JOIN j3 ON(j2_id = j3_id - 2)) ON(j1_id = j2_id),\
\n LATERAL (SELECT * FROM j3 WHERE j3_string = j2_string) as j4;";
let expected = "Projection: j1.j1_id, j1.j1_string, j2.j2_id, j2.j2_string, j3.j3_id, j3.j3_string, j4.j3_id, j4.j3_string\
\n CrossJoin:\
\n Inner Join: Filter: j1.j1_id = j2.j2_id\
\n TableScan: j1\
\n Inner Join: Filter: j2.j2_id = j3.j3_id - Int64(2)\
\n TableScan: j2\
\n TableScan: j3\
\n SubqueryAlias: j4\
\n Subquery:\
\n Projection: j3.j3_id, j3.j3_string\
\n Filter: j3.j3_string = outer_ref(j2.j2_string)\
\n TableScan: j3";
quick_test(sql, expected);
}

/// Two levels of nested `LATERAL` subqueries where `j1` is defined at both
/// the outermost and the middle level. The expected plan contains a
/// `Subquery` nested inside another `Subquery`, with the innermost filter
/// using `outer_ref(j1.j1_id)`.
#[test]
fn lateral_comma_join_with_shadowing() {
// The j1_id on line 3 references the (closest) j1 definition from line 2.
let sql = "-- Triple nested correlated queries queries\
\nSELECT * FROM j1, LATERAL ( -- line 1\
\n SELECT * FROM j1, LATERAL ( -- line 2\
\n SELECT * FROM j2 WHERE j1.j1_id = j2_id -- line 3\
\n ) as j2\
\n) as j2;";
let expected = "Projection: j1.j1_id, j1.j1_string, j2.j1_id, j2.j1_string, j2.j2_id, j2.j2_string\
\n CrossJoin:\
\n TableScan: j1\
\n SubqueryAlias: j2\
\n Subquery:\
\n Projection: j1.j1_id, j1.j1_string, j2.j2_id, j2.j2_string\
\n CrossJoin:\
\n TableScan: j1\
\n SubqueryAlias: j2\
\n Subquery:\
\n Projection: j2.j2_id, j2.j2_string\
\n Filter: outer_ref(j1.j1_id) = j2.j2_id\
\n TableScan: j2";
quick_test(sql, expected);
}

/// `LEFT JOIN LATERAL ... ON(true)`: the lateral subquery on the right side
/// of an explicit join references `j1_id` from the left side. The expected
/// plan is a `Left Join` over an aliased `Subquery` that filters on
/// `outer_ref(j1.j1_id)`.
#[test]
fn lateral_left_join() {
let sql = "SELECT j1_string, j2_string FROM
j1 \
LEFT JOIN LATERAL (SELECT * FROM j2 WHERE j1_id < j2_id) AS j2 ON(true);";
let expected = "Projection: j1.j1_string, j2.j2_string\
\n Left Join: Filter: Boolean(true)\
\n TableScan: j1\
\n SubqueryAlias: j2\
\n Subquery:\
\n Projection: j2.j2_id, j2.j2_string\
\n Filter: outer_ref(j1.j1_id) < j2.j2_id\
\n TableScan: j2";
quick_test(sql, expected);
}

/// A lateral subquery inside a parenthesized join references both `j1_id`
/// (from the earlier comma-separated FROM item) and `j2_id` (from the left
/// side of its own join). The expected plan resolves both as `outer_ref`
/// columns in the subquery's filter.
#[test]
fn lateral_nested_left_join() {
let sql = "SELECT * FROM
j1, \
(j2 LEFT JOIN LATERAL (SELECT * FROM j3 WHERE j1_id + j2_id = j3_id) AS j3 ON(true))";
let expected = "Projection: j1.j1_id, j1.j1_string, j2.j2_id, j2.j2_string, j3.j3_id, j3.j3_string\
\n CrossJoin:\
\n TableScan: j1\
\n Left Join: Filter: Boolean(true)\
\n TableScan: j2\
\n SubqueryAlias: j3\
\n Subquery:\
\n Projection: j3.j3_id, j3.j3_string\
\n Filter: outer_ref(j1.j1_id) + outer_ref(j2.j2_id) = j3.j3_id\
\n TableScan: j3";
quick_test(sql, expected);
}

#[test]
fn hive_aggregate_with_filter() -> Result<()> {
let dialect = &HiveDialect {};
Expand Down

0 comments on commit ee34089

Please sign in to comment.