diff --git a/docs/sql-ref-syntax-qry-select.md b/docs/sql-ref-syntax-qry-select.md index ea5c4a69d9ab6..22c4d78605b44 100644 --- a/docs/sql-ref-syntax-qry-select.md +++ b/docs/sql-ref-syntax-qry-select.md @@ -83,6 +83,8 @@ SELECT [ hints , ... ] [ ALL | DISTINCT ] { [ [ named_expression | regex_column_ Specifies a source of input for the query. It can be one of the following: * Table relation * [Join relation](sql-ref-syntax-qry-select-join.html) + * [Pivot relation](sql-ref-syntax-qry-select-pivot.md) + * [Unpivot relation](sql-ref-syntax-qry-select-unpivot.md) * [Table-value function](sql-ref-syntax-qry-select-tvf.html) * [Inline table](sql-ref-syntax-qry-select-inline-table.html) * [ [LATERAL](sql-ref-syntax-qry-select-lateral-subquery.html) ] ( Subquery ) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index a3c5f4a7b0709..21747a0a021f2 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -697,7 +697,13 @@ setQuantifier ; relation - : LATERAL? relationPrimary joinRelation* + : LATERAL? relationPrimary relationExtension* + ; + +relationExtension + : joinRelation + | pivotClause + | unpivotClause ; joinRelation diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 4adb70bc3909f..d56ef28bcc32b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -929,7 +929,7 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit Join(left, right, Inner, None, JoinHint.NONE) } } - if (conf.ansiRelationPrecedence) join else withJoinRelations(join, relation) + if (conf.ansiRelationPrecedence) join else withRelationExtensions(relation, join) } if (ctx.pivotClause() != null) { if (ctx.unpivotClause() != null) { @@ -1263,60 +1263,71 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit * }}} */ override def visitRelation(ctx: RelationContext): LogicalPlan = withOrigin(ctx) { - withJoinRelations(plan(ctx.relationPrimary), ctx) + withRelationExtensions(ctx, plan(ctx.relationPrimary)) + } + + private def withRelationExtensions(ctx: RelationContext, query: LogicalPlan): LogicalPlan = { + ctx.relationExtension().asScala.foldLeft(query) { (left, extension) => + if (extension.joinRelation() != null) { + withJoinRelation(extension.joinRelation(), left) + } else if (extension.pivotClause() != null) { + withPivot(extension.pivotClause(), left) + } else { + assert(extension.unpivotClause() != null) + withUnpivot(extension.unpivotClause(), left) + } + } } /** - * Join one more [[LogicalPlan]]s to the current logical plan. + * Join one more [[LogicalPlan]] to the current logical plan. */ - private def withJoinRelations(base: LogicalPlan, ctx: RelationContext): LogicalPlan = { - ctx.joinRelation.asScala.foldLeft(base) { (left, join) => - withOrigin(join) { - val baseJoinType = join.joinType match { - case null => Inner - case jt if jt.CROSS != null => Cross - case jt if jt.FULL != null => FullOuter - case jt if jt.SEMI != null => LeftSemi - case jt if jt.ANTI != null => LeftAnti - case jt if jt.LEFT != null => LeftOuter - case jt if jt.RIGHT != null => RightOuter - case _ => Inner - } + private def withJoinRelation(ctx: JoinRelationContext, base: LogicalPlan): LogicalPlan = { + withOrigin(ctx) { + val baseJoinType = ctx.joinType match { + case null => Inner + case jt if jt.CROSS != null => Cross + case jt if jt.FULL != null => FullOuter + case jt if jt.SEMI != null => LeftSemi + case jt if jt.ANTI != null => LeftAnti + case jt if jt.LEFT != null => LeftOuter + case jt if jt.RIGHT != null => RightOuter + case _ => Inner + } - if (join.LATERAL != null && !join.right.isInstanceOf[AliasedQueryContext]) { - throw QueryParsingErrors.invalidLateralJoinRelationError(join.right) - } + if (ctx.LATERAL != null && !ctx.right.isInstanceOf[AliasedQueryContext]) { + throw QueryParsingErrors.invalidLateralJoinRelationError(ctx.right) + } - // Resolve the join type and join condition - val (joinType, condition) = Option(join.joinCriteria) match { - case Some(c) if c.USING != null => - if (join.LATERAL != null) { - throw QueryParsingErrors.lateralJoinWithUsingJoinUnsupportedError(ctx) - } - (UsingJoin(baseJoinType, visitIdentifierList(c.identifierList)), None) - case Some(c) if c.booleanExpression != null => - (baseJoinType, Option(expression(c.booleanExpression))) - case Some(c) => - throw new IllegalStateException(s"Unimplemented joinCriteria: $c") - case None if join.NATURAL != null => - if (join.LATERAL != null) { - throw QueryParsingErrors.lateralJoinWithNaturalJoinUnsupportedError(ctx) - } - if (baseJoinType == Cross) { - throw QueryParsingErrors.naturalCrossJoinUnsupportedError(ctx) - } - (NaturalJoin(baseJoinType), None) - case None => - (baseJoinType, None) - } - if (join.LATERAL != null) { - if (!Seq(Inner, Cross, LeftOuter).contains(joinType)) { - throw QueryParsingErrors.unsupportedLateralJoinTypeError(ctx, joinType.sql) + // Resolve the join type and join condition + val (joinType, condition) = Option(ctx.joinCriteria) match { + case Some(c) if c.USING != null => + if (ctx.LATERAL != null) { + throw QueryParsingErrors.lateralJoinWithUsingJoinUnsupportedError(ctx) } - LateralJoin(left, LateralSubquery(plan(join.right)), joinType, condition) - } else { - Join(left, plan(join.right), joinType, condition, JoinHint.NONE) + (UsingJoin(baseJoinType, visitIdentifierList(c.identifierList)), None) + case Some(c) if c.booleanExpression != null => + (baseJoinType, Option(expression(c.booleanExpression))) + case Some(c) => + throw new IllegalStateException(s"Unimplemented joinCriteria: $c") + case None if ctx.NATURAL != null => + if (ctx.LATERAL != null) { + throw QueryParsingErrors.lateralJoinWithNaturalJoinUnsupportedError(ctx) + } + if (baseJoinType == Cross) { + throw QueryParsingErrors.naturalCrossJoinUnsupportedError(ctx) + } + (NaturalJoin(baseJoinType), None) + case None => + (baseJoinType, None) + } + if (ctx.LATERAL != null) { + if (!Seq(Inner, Cross, LeftOuter).contains(joinType)) { + throw QueryParsingErrors.unsupportedLateralJoinTypeError(ctx, joinType.sql) } + LateralJoin(base, LateralSubquery(plan(ctx.right)), joinType, condition) + } else { + Join(base, plan(ctx.right), joinType, condition, JoinHint.NONE) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala index 8608d4ff306ef..9624a06d80a64 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala @@ -99,15 +99,15 @@ private[sql] object QueryParsingErrors extends QueryErrorsBase { ctx) } - def unpivotWithPivotInFromClauseNotAllowedError(ctx: FromClauseContext): Throwable = { + def unpivotWithPivotInFromClauseNotAllowedError(ctx: ParserRuleContext): Throwable = { new ParseException("UNPIVOT cannot be used together with PIVOT in FROM clause", ctx) } - def lateralWithPivotInFromClauseNotAllowedError(ctx: FromClauseContext): Throwable = { + def lateralWithPivotInFromClauseNotAllowedError(ctx: ParserRuleContext): Throwable = { new ParseException(errorClass = "_LEGACY_ERROR_TEMP_0013", ctx) } - def lateralWithUnpivotInFromClauseNotAllowedError(ctx: FromClauseContext): Throwable = { + def lateralWithUnpivotInFromClauseNotAllowedError(ctx: ParserRuleContext): Throwable = { new ParseException("LATERAL cannot be used together with UNPIVOT in FROM clause", ctx) } @@ -164,7 +164,7 @@ private[sql] object QueryParsingErrors extends QueryErrorsBase { ctx) } - def naturalCrossJoinUnsupportedError(ctx: RelationContext): Throwable = { + def naturalCrossJoinUnsupportedError(ctx: ParserRuleContext): Throwable = { new ParseException( errorClass = "UNSUPPORTED_FEATURE.NATURAL_CROSS_JOIN", messageParameters = Map.empty, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/UnpivotParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/UnpivotParserSuite.scala index dd7e4ec4916fc..c680e08c1c832 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/UnpivotParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/UnpivotParserSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Unpivot} +import org.apache.spark.sql.internal.SQLConf class UnpivotParserSuite extends AnalysisTest { @@ -192,4 +193,151 @@ class UnpivotParserSuite extends AnalysisTest { ) } + test("unpivot - with joins") { + // unpivot the left table + assertEqual( + "SELECT * FROM t1 UNPIVOT (val FOR col in (a, b)) JOIN t2", + Unpivot( + None, + Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + None, + "col", + Seq("val"), + table("t1") + ).where(coalesce($"val").isNotNull).join(table("t2")).select(star())) + + // unpivot the join result + assertEqual( + "SELECT * FROM t1 JOIN t2 UNPIVOT (val FOR col in (a, b))", + Unpivot( + None, + Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + None, + "col", + Seq("val"), + table("t1").join(table("t2")) + ).where(coalesce($"val").isNotNull).select(star())) + + // unpivot the right table + assertEqual( + "SELECT * FROM t1 JOIN (t2 UNPIVOT (val FOR col in (a, b)))", + table("t1").join( + Unpivot( + None, + Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + None, + "col", + Seq("val"), + table("t2") + ).where(coalesce($"val").isNotNull) + ).select(star())) + } + + test("unpivot - with implicit joins") { + // unpivot the left table + assertEqual( + "SELECT * FROM t1 UNPIVOT (val FOR col in (a, b)), t2", + Unpivot( + None, + Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + None, + "col", + Seq("val"), + table("t1") + ).where(coalesce($"val").isNotNull).join(table("t2")).select(star())) + + // unpivot the join result + assertEqual( + "SELECT * FROM t1, t2 UNPIVOT (val FOR col in (a, b))", + Unpivot( + None, + Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + None, + "col", + Seq("val"), + table("t1").join(table("t2")) + ).where(coalesce($"val").isNotNull).select(star())) + + // unpivot the right table - same SQL as above but with ANSI mode + withSQLConf( + SQLConf.ANSI_ENABLED.key -> "true", + SQLConf.ANSI_RELATION_PRECEDENCE.key -> "true") { + assertEqual( + "SELECT * FROM t1, t2 UNPIVOT (val FOR col in (a, b))", + table("t1").join( + Unpivot( + None, + Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + None, + "col", + Seq("val"), + table("t2") + ).where(coalesce($"val").isNotNull) + ).select(star())) + } + + // unpivot the right table + assertEqual( + "SELECT * FROM t1, (t2 UNPIVOT (val FOR col in (a, b)))", + table("t1").join( + Unpivot( + None, + Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + None, + "col", + Seq("val"), + table("t2") + ).where(coalesce($"val").isNotNull) + ).select(star())) + + // mixed with explicit joins + assertEqual( + // unpivot the join result of t1, t2 and t3 + "SELECT * FROM t1, t2 JOIN t3 UNPIVOT (val FOR col in (a, b))", + Unpivot( + None, + Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + None, + "col", + Seq("val"), + table("t1").join(table("t2")).join(table("t3")) + ).where(coalesce($"val").isNotNull).select(star())) + withSQLConf( + SQLConf.ANSI_ENABLED.key -> "true", + SQLConf.ANSI_RELATION_PRECEDENCE.key -> "true") { + assertEqual( + // unpivot the join result of t2 and t3 + "SELECT * FROM t1, t2 JOIN t3 UNPIVOT (val FOR col in (a, b))", + table("t1").join( + Unpivot( + None, + Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + None, + "col", + Seq("val"), + table("t2").join(table("t3")) + ).where(coalesce($"val").isNotNull) + ).select(star())) + } + } + + test("unpivot - nested unpivot") { + assertEqual( + "SELECT * FROM t1 UNPIVOT (val FOR col in (a, b)) UNPIVOT (val FOR col in (a, b))", + Unpivot( + None, + Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + None, + "col", + Seq("val"), + Unpivot( + None, + Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + None, + "col", + Seq("val"), + table("t1") + ).where(coalesce($"val").isNotNull) + ).where(coalesce($"val").isNotNull).select(star())) + } }