Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/sql-ref-syntax-qry-select.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ SELECT [ hints , ... ] [ ALL | DISTINCT ] { [ [ named_expression | regex_column_
Specifies a source of input for the query. It can be one of the following:
* Table relation
* [Join relation](sql-ref-syntax-qry-select-join.html)
* [Pivot relation](sql-ref-syntax-qry-select-pivot.html)
* [Unpivot relation](sql-ref-syntax-qry-select-unpivot.html)
* [Table-value function](sql-ref-syntax-qry-select-tvf.html)
* [Inline table](sql-ref-syntax-qry-select-inline-table.html)
* [ [LATERAL](sql-ref-syntax-qry-select-lateral-subquery.html) ] ( Subquery )
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,13 @@ setQuantifier
;

relation
: LATERAL? relationPrimary joinRelation*
: LATERAL? relationPrimary relationExtension*
;

// One clause that may follow a relationPrimary inside `relation`:
// a join, a PIVOT clause, or an UNPIVOT clause.
relationExtension
: joinRelation
| pivotClause
| unpivotClause
;

joinRelation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -929,7 +929,7 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
Join(left, right, Inner, None, JoinHint.NONE)
}
}
if (conf.ansiRelationPrecedence) join else withJoinRelations(join, relation)
if (conf.ansiRelationPrecedence) join else withRelationExtensions(relation, join)
}
if (ctx.pivotClause() != null) {
if (ctx.unpivotClause() != null) {
Expand Down Expand Up @@ -1263,60 +1263,71 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
* }}}
*/
override def visitRelation(ctx: RelationContext): LogicalPlan = withOrigin(ctx) {
withJoinRelations(plan(ctx.relationPrimary), ctx)
withRelationExtensions(ctx, plan(ctx.relationPrimary))
}

/**
 * Applies every `relationExtension` of `ctx` on top of `query`, in the order the parser
 * returns them. A join extension is handled by `withJoinRelation`, a pivot clause by
 * `withPivot` and an unpivot clause by `withUnpivot`; the plan produced for one extension
 * is the input of the next one.
 *
 * @param ctx the relation whose extensions are applied
 * @param query the plan the first extension is applied to
 * @return the plan obtained after the last extension, or `query` if there are none
 */
private def withRelationExtensions(ctx: RelationContext, query: LogicalPlan): LogicalPlan = {
  val extensions = ctx.relationExtension().asScala
  extensions.foldLeft(query) { (current, ext) =>
    ext match {
      case e if e.joinRelation() != null =>
        withJoinRelation(e.joinRelation(), current)
      case e if e.pivotClause() != null =>
        withPivot(e.pivotClause(), current)
      case e =>
        // Only the unpivot alternative is left once the two checks above have failed.
        assert(e.unpivotClause() != null)
        withUnpivot(e.unpivotClause(), current)
    }
  }
}

/**
* Join one more [[LogicalPlan]]s to the current logical plan.
* Join one more [[LogicalPlan]] to the current logical plan.
*/
private def withJoinRelations(base: LogicalPlan, ctx: RelationContext): LogicalPlan = {
ctx.joinRelation.asScala.foldLeft(base) { (left, join) =>
Copy link
Contributor Author

@cloud-fan cloud-fan Nov 21, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The actual code change is very small, just remove this loop and rename a few variables.

withOrigin(join) {
val baseJoinType = join.joinType match {
case null => Inner
case jt if jt.CROSS != null => Cross
case jt if jt.FULL != null => FullOuter
case jt if jt.SEMI != null => LeftSemi
case jt if jt.ANTI != null => LeftAnti
case jt if jt.LEFT != null => LeftOuter
case jt if jt.RIGHT != null => RightOuter
case _ => Inner
}
private def withJoinRelation(ctx: JoinRelationContext, base: LogicalPlan): LogicalPlan = {
withOrigin(ctx) {
val baseJoinType = ctx.joinType match {
case null => Inner
case jt if jt.CROSS != null => Cross
case jt if jt.FULL != null => FullOuter
case jt if jt.SEMI != null => LeftSemi
case jt if jt.ANTI != null => LeftAnti
case jt if jt.LEFT != null => LeftOuter
case jt if jt.RIGHT != null => RightOuter
case _ => Inner
}

if (join.LATERAL != null && !join.right.isInstanceOf[AliasedQueryContext]) {
throw QueryParsingErrors.invalidLateralJoinRelationError(join.right)
}
if (ctx.LATERAL != null && !ctx.right.isInstanceOf[AliasedQueryContext]) {
throw QueryParsingErrors.invalidLateralJoinRelationError(ctx.right)
}

// Resolve the join type and join condition
val (joinType, condition) = Option(join.joinCriteria) match {
case Some(c) if c.USING != null =>
if (join.LATERAL != null) {
throw QueryParsingErrors.lateralJoinWithUsingJoinUnsupportedError(ctx)
}
(UsingJoin(baseJoinType, visitIdentifierList(c.identifierList)), None)
case Some(c) if c.booleanExpression != null =>
(baseJoinType, Option(expression(c.booleanExpression)))
case Some(c) =>
throw new IllegalStateException(s"Unimplemented joinCriteria: $c")
case None if join.NATURAL != null =>
if (join.LATERAL != null) {
throw QueryParsingErrors.lateralJoinWithNaturalJoinUnsupportedError(ctx)
}
if (baseJoinType == Cross) {
throw QueryParsingErrors.naturalCrossJoinUnsupportedError(ctx)
}
(NaturalJoin(baseJoinType), None)
case None =>
(baseJoinType, None)
}
if (join.LATERAL != null) {
if (!Seq(Inner, Cross, LeftOuter).contains(joinType)) {
throw QueryParsingErrors.unsupportedLateralJoinTypeError(ctx, joinType.sql)
// Resolve the join type and join condition
val (joinType, condition) = Option(ctx.joinCriteria) match {
case Some(c) if c.USING != null =>
if (ctx.LATERAL != null) {
throw QueryParsingErrors.lateralJoinWithUsingJoinUnsupportedError(ctx)
}
LateralJoin(left, LateralSubquery(plan(join.right)), joinType, condition)
} else {
Join(left, plan(join.right), joinType, condition, JoinHint.NONE)
(UsingJoin(baseJoinType, visitIdentifierList(c.identifierList)), None)
case Some(c) if c.booleanExpression != null =>
(baseJoinType, Option(expression(c.booleanExpression)))
case Some(c) =>
throw new IllegalStateException(s"Unimplemented joinCriteria: $c")
case None if ctx.NATURAL != null =>
if (ctx.LATERAL != null) {
throw QueryParsingErrors.lateralJoinWithNaturalJoinUnsupportedError(ctx)
}
if (baseJoinType == Cross) {
throw QueryParsingErrors.naturalCrossJoinUnsupportedError(ctx)
}
(NaturalJoin(baseJoinType), None)
case None =>
(baseJoinType, None)
}
if (ctx.LATERAL != null) {
if (!Seq(Inner, Cross, LeftOuter).contains(joinType)) {
throw QueryParsingErrors.unsupportedLateralJoinTypeError(ctx, joinType.sql)
}
LateralJoin(base, LateralSubquery(plan(ctx.right)), joinType, condition)
} else {
Join(base, plan(ctx.right), joinType, condition, JoinHint.NONE)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,15 @@ private[sql] object QueryParsingErrors extends QueryErrorsBase {
ctx)
}

def unpivotWithPivotInFromClauseNotAllowedError(ctx: FromClauseContext): Throwable = {
def unpivotWithPivotInFromClauseNotAllowedError(ctx: ParserRuleContext): Throwable = {
new ParseException("UNPIVOT cannot be used together with PIVOT in FROM clause", ctx)
}

def lateralWithPivotInFromClauseNotAllowedError(ctx: FromClauseContext): Throwable = {
def lateralWithPivotInFromClauseNotAllowedError(ctx: ParserRuleContext): Throwable = {
new ParseException(errorClass = "_LEGACY_ERROR_TEMP_0013", ctx)
}

def lateralWithUnpivotInFromClauseNotAllowedError(ctx: FromClauseContext): Throwable = {
def lateralWithUnpivotInFromClauseNotAllowedError(ctx: ParserRuleContext): Throwable = {
new ParseException("LATERAL cannot be used together with UNPIVOT in FROM clause", ctx)
}

Expand Down Expand Up @@ -164,7 +164,7 @@ private[sql] object QueryParsingErrors extends QueryErrorsBase {
ctx)
}

def naturalCrossJoinUnsupportedError(ctx: RelationContext): Throwable = {
def naturalCrossJoinUnsupportedError(ctx: ParserRuleContext): Throwable = {
new ParseException(
errorClass = "UNSUPPORTED_FEATURE.NATURAL_CROSS_JOIN",
messageParameters = Map.empty,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.parser

import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Unpivot}
import org.apache.spark.sql.internal.SQLConf

class UnpivotParserSuite extends AnalysisTest {

Expand Down Expand Up @@ -192,4 +193,151 @@ class UnpivotParserSuite extends AnalysisTest {
)
}

test("unpivot - with joins") {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't add tests for pivot because:

  1. there is no pivot parser suite
  2. pivot/unpivot syntax is exactly the same regarding joins, no need to test both

// unpivot the left table
assertEqual(
"SELECT * FROM t1 UNPIVOT (val FOR col in (a, b)) JOIN t2",
Unpivot(
None,
Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))),
None,
"col",
Seq("val"),
table("t1")
).where(coalesce($"val").isNotNull).join(table("t2")).select(star()))

// unpivot the join result
assertEqual(
"SELECT * FROM t1 JOIN t2 UNPIVOT (val FOR col in (a, b))",
Unpivot(
None,
Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))),
None,
"col",
Seq("val"),
table("t1").join(table("t2"))
).where(coalesce($"val").isNotNull).select(star()))

// unpivot the right table
assertEqual(
"SELECT * FROM t1 JOIN (t2 UNPIVOT (val FOR col in (a, b)))",
table("t1").join(
Unpivot(
None,
Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))),
None,
"col",
Seq("val"),
table("t2")
).where(coalesce($"val").isNotNull)
).select(star()))
}

test("unpivot - with implicit joins") {
  // Expected plan for `UNPIVOT (val FOR col in (a, b))` applied to `child`, including the
  // not-null filter on `val` that every expectation in this test carries.
  def unpivoted(child: LogicalPlan): LogicalPlan =
    Unpivot(
      None,
      Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))),
      None,
      "col",
      Seq("val"),
      child
    ).where(coalesce($"val").isNotNull)

  // unpivot the left table
  assertEqual(
    "SELECT * FROM t1 UNPIVOT (val FOR col in (a, b)), t2",
    unpivoted(table("t1")).join(table("t2")).select(star()))

  // unpivot the join result
  assertEqual(
    "SELECT * FROM t1, t2 UNPIVOT (val FOR col in (a, b))",
    unpivoted(table("t1").join(table("t2"))).select(star()))

  // unpivot the right table - same SQL as above but with ANSI mode
  withSQLConf(
    SQLConf.ANSI_ENABLED.key -> "true",
    SQLConf.ANSI_RELATION_PRECEDENCE.key -> "true") {
    assertEqual(
      "SELECT * FROM t1, t2 UNPIVOT (val FOR col in (a, b))",
      table("t1").join(unpivoted(table("t2"))).select(star()))
  }

  // unpivot the right table
  assertEqual(
    "SELECT * FROM t1, (t2 UNPIVOT (val FOR col in (a, b)))",
    table("t1").join(unpivoted(table("t2"))).select(star()))

  // mixed with explicit joins: by default the join result of t1, t2 and t3 is unpivoted
  assertEqual(
    "SELECT * FROM t1, t2 JOIN t3 UNPIVOT (val FOR col in (a, b))",
    unpivoted(table("t1").join(table("t2")).join(table("t3"))).select(star()))

  // mixed with explicit joins: with ANSI relation precedence only the join result of
  // t2 and t3 is unpivoted
  withSQLConf(
    SQLConf.ANSI_ENABLED.key -> "true",
    SQLConf.ANSI_RELATION_PRECEDENCE.key -> "true") {
    assertEqual(
      "SELECT * FROM t1, t2 JOIN t3 UNPIVOT (val FOR col in (a, b))",
      table("t1").join(unpivoted(table("t2").join(table("t3")))).select(star()))
  }
}

test("unpivot - nested unpivot") {
  // Expected plan for one `UNPIVOT (val FOR col in (a, b))` step over `child`, together with
  // its not-null filter on `val`.
  def unpivoted(child: LogicalPlan): LogicalPlan =
    Unpivot(
      None,
      Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))),
      None,
      "col",
      Seq("val"),
      child
    ).where(coalesce($"val").isNotNull)

  // The second UNPIVOT clause is expected to wrap the result of the first one.
  val innerUnpivot = unpivoted(table("t1"))
  val outerUnpivot = unpivoted(innerUnpivot)

  assertEqual(
    "SELECT * FROM t1 UNPIVOT (val FOR col in (a, b)) UNPIVOT (val FOR col in (a, b))",
    outerUnpivot.select(star()))
}
}