From 82978d77796ce8b77178a9dfe37d48cefbbcc5c9 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 26 Apr 2017 08:53:51 +0000 Subject: [PATCH 01/21] Set barrier to prevent re-analysis of analyzed plan. --- .../spark/sql/catalyst/trees/TreeNode.scala | 66 ++++++++++++++++--- .../scala/org/apache/spark/sql/Dataset.scala | 29 +++++--- 2 files changed, 76 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index cc4c0835954ba..c1083b011c0b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -72,6 +72,34 @@ object CurrentOrigin { } } +case class Barrier(node: Option[TreeNode[_]] = None) + +/** + * Provides a barrier for TreeNodes to prevent transformation from specified nodes. + */ +object CurrentBarrier { + private val value = new ThreadLocal[Barrier]() { + override def initialValue: Barrier = Barrier() + } + + def get: Barrier = value.get() + def set(b: Barrier): Unit = value.set(b) + + def reset(): Unit = value.set(Barrier()) + + def hitBarrier(currentNode: TreeNode[_]): Boolean = { + val barrier = value.get() + barrier.node.isDefined && (barrier.node.get fastEquals currentNode) + } + + def withBarrier[A](b: Barrier)(f: => A): A = { + val barrier = get + set(b) + val ret = try f finally { set(barrier) } + ret + } +} + // scalastyle:off abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { // scalastyle:on @@ -115,7 +143,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { */ def foreach(f: BaseType => Unit): Unit = { f(this) - children.foreach(_.foreach(f)) + if (!CurrentBarrier.hitBarrier(this)) { + children.foreach(_.foreach(f)) + } } /** @@ -123,7 +153,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { * @param 
f the function to be applied to each node in the tree. */ def foreachUp(f: BaseType => Unit): Unit = { - children.foreach(_.foreachUp(f)) + if (!CurrentBarrier.hitBarrier(this)) { + children.foreach(_.foreachUp(f)) + } f(this) } @@ -267,11 +299,19 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { rule.applyOrElse(this, identity[BaseType]) } - // Check if unchanged and then possibly return old copy to avoid gc churn. - if (this fastEquals afterRule) { - mapChildren(_.transformDown(rule)) + if (CurrentBarrier.hitBarrier(this)) { + if (this fastEquals afterRule) { + this + } else { + afterRule + } } else { - afterRule.mapChildren(_.transformDown(rule)) + // Check if unchanged and then possibly return old copy to avoid gc churn. + if (this fastEquals afterRule) { + mapChildren(_.transformDown(rule)) + } else { + afterRule.mapChildren(_.transformDown(rule)) + } } } @@ -283,14 +323,20 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { * @param rule the function use to transform this nodes children */ def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = { - val afterRuleOnChildren = mapChildren(_.transformUp(rule)) - if (this fastEquals afterRuleOnChildren) { + if (CurrentBarrier.hitBarrier(this)) { CurrentOrigin.withOrigin(origin) { rule.applyOrElse(this, identity[BaseType]) } } else { - CurrentOrigin.withOrigin(origin) { - rule.applyOrElse(afterRuleOnChildren, identity[BaseType]) + val afterRuleOnChildren = mapChildren(_.transformUp(rule)) + if (this fastEquals afterRuleOnChildren) { + CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(this, identity[BaseType]) + } + } else { + CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(afterRuleOnChildren, identity[BaseType]) + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index c6dcd93bbda66..076dc7e80ab02 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -46,6 +46,7 @@ import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection} +import org.apache.spark.sql.catalyst.trees.{Barrier, CurrentBarrier} import org.apache.spark.sql.catalyst.util.{usePrettyExpression, DateTimeUtils} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command._ @@ -203,7 +204,7 @@ class Dataset[T] private[sql]( * custom objects, e.g. collect. Here we resolve and bind the encoder so that we can call its * `fromRow` method later. */ - private val boundEnc = + private lazy val boundEnc = exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer) private implicit def classTag = exprEnc.clsTag @@ -356,7 +357,11 @@ class Dataset[T] private[sql]( */ // This is declared with parentheses to prevent the Scala compiler from treating // `ds.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. - def toDF(): DataFrame = new Dataset[Row](sparkSession, queryExecution, RowEncoder(schema)) + def toDF(): DataFrame = { + CurrentBarrier.withBarrier(Barrier(Some(logicalPlan))) { + new Dataset[Row](sparkSession, queryExecution, RowEncoder(schema)) + } + } /** * :: Experimental :: @@ -2828,21 +2833,27 @@ class Dataset[T] private[sql]( /** A convenient function to wrap a logical plan and produce a DataFrame. */ @inline private def withPlan(logicalPlan: => LogicalPlan): DataFrame = { - Dataset.ofRows(sparkSession, logicalPlan) + CurrentBarrier.withBarrier(Barrier(Some(this.logicalPlan))) { + Dataset.ofRows(sparkSession, logicalPlan) + } } /** A convenient function to wrap a logical plan and produce a Dataset. 
*/ @inline private def withTypedPlan[U : Encoder](logicalPlan: => LogicalPlan): Dataset[U] = { - Dataset(sparkSession, logicalPlan) + CurrentBarrier.withBarrier(Barrier(Some(this.logicalPlan))) { + Dataset(sparkSession, logicalPlan) + } } /** A convenient function to wrap a set based logical plan and produce a Dataset. */ @inline private def withSetOperator[U : Encoder](logicalPlan: => LogicalPlan): Dataset[U] = { - if (classTag.runtimeClass.isAssignableFrom(classOf[Row])) { - // Set operators widen types (change the schema), so we cannot reuse the row encoder. - Dataset.ofRows(sparkSession, logicalPlan).asInstanceOf[Dataset[U]] - } else { - Dataset(sparkSession, logicalPlan) + CurrentBarrier.withBarrier(Barrier(Some(this.logicalPlan))) { + if (classTag.runtimeClass.isAssignableFrom(classOf[Row])) { + // Set operators widen types (change the schema), so we cannot reuse the row encoder. + Dataset.ofRows(sparkSession, logicalPlan).asInstanceOf[Dataset[U]] + } else { + Dataset(sparkSession, logicalPlan) + } } } } From 24905e39b249e6ee6cbf6bee7cc859bcac712b76 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 27 Apr 2017 08:12:12 +0000 Subject: [PATCH 02/21] Use a logical node to set analysis barrier. 
--- .../sql/catalyst/analysis/Analyzer.scala | 10 ++- .../plans/logical/basicLogicalOperators.scala | 34 ++++++++++ .../spark/sql/catalyst/trees/TreeNode.scala | 66 +++---------------- .../scala/org/apache/spark/sql/Dataset.scala | 51 +++++++------- 4 files changed, 75 insertions(+), 86 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index dcadbbc90f438..e4f14389d6095 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -165,7 +165,8 @@ class Analyzer( Batch("Subquery", Once, UpdateOuterReferences), Batch("Cleanup", fixedPoint, - CleanupAliases) + CleanupAliases, + CleanupBarriers) ) /** @@ -2435,6 +2436,13 @@ object CleanupAliases extends Rule[LogicalPlan] { } } +/** Remove the barrier nodes of analysis */ +object CleanupBarriers extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case AnalysisBarrier(child) => child + } +} + /** * Maps a time column to multiple time windows using the Expand operator. 
Since it's non-trivial to * figure out how many windows a time column can map to, we over-estimate the number of windows and diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 3ad757ebba851..851bff859728b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation._ +import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -906,3 +907,36 @@ case class Deduplicate( override def output: Seq[Attribute] = child.output } + +/** A logical plan for setting a barrier of analysis */ +case class AnalysisBarrier(child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output + + override def analyzed: Boolean = true + + override def transformUp(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { + transformBarrier(rule) + } + + override def transformDown(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { + transformBarrier(rule) + } + + private def transformBarrier(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { + val afterRule = CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(this, identity[LogicalPlan]) + } + val childAfterRule = CurrentOrigin.withOrigin(child.origin) { + rule.applyOrElse(child, identity[LogicalPlan]) + } + + if ((child fastEquals childAfterRule) && (this fastEquals afterRule)) { + 
this + } else if (this fastEquals afterRule) { + AnalysisBarrier(childAfterRule) + } else { + // The only rule that can change barrier node is the rule to remove it. + childAfterRule + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index c1083b011c0b4..cc4c0835954ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -72,34 +72,6 @@ object CurrentOrigin { } } -case class Barrier(node: Option[TreeNode[_]] = None) - -/** - * Provides a barrier for TreeNodes to prevent transformation from specified nodes. - */ -object CurrentBarrier { - private val value = new ThreadLocal[Barrier]() { - override def initialValue: Barrier = Barrier() - } - - def get: Barrier = value.get() - def set(b: Barrier): Unit = value.set(b) - - def reset(): Unit = value.set(Barrier()) - - def hitBarrier(currentNode: TreeNode[_]): Boolean = { - val barrier = value.get() - barrier.node.isDefined && (barrier.node.get fastEquals currentNode) - } - - def withBarrier[A](b: Barrier)(f: => A): A = { - val barrier = get - set(b) - val ret = try f finally { set(barrier) } - ret - } -} - // scalastyle:off abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { // scalastyle:on @@ -143,9 +115,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { */ def foreach(f: BaseType => Unit): Unit = { f(this) - if (!CurrentBarrier.hitBarrier(this)) { - children.foreach(_.foreach(f)) - } + children.foreach(_.foreach(f)) } /** @@ -153,9 +123,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { * @param f the function to be applied to each node in the tree. 
*/ def foreachUp(f: BaseType => Unit): Unit = { - if (!CurrentBarrier.hitBarrier(this)) { - children.foreach(_.foreachUp(f)) - } + children.foreach(_.foreachUp(f)) f(this) } @@ -299,19 +267,11 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { rule.applyOrElse(this, identity[BaseType]) } - if (CurrentBarrier.hitBarrier(this)) { - if (this fastEquals afterRule) { - this - } else { - afterRule - } + // Check if unchanged and then possibly return old copy to avoid gc churn. + if (this fastEquals afterRule) { + mapChildren(_.transformDown(rule)) } else { - // Check if unchanged and then possibly return old copy to avoid gc churn. - if (this fastEquals afterRule) { - mapChildren(_.transformDown(rule)) - } else { - afterRule.mapChildren(_.transformDown(rule)) - } + afterRule.mapChildren(_.transformDown(rule)) } } @@ -323,20 +283,14 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { * @param rule the function use to transform this nodes children */ def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = { - if (CurrentBarrier.hitBarrier(this)) { + val afterRuleOnChildren = mapChildren(_.transformUp(rule)) + if (this fastEquals afterRuleOnChildren) { CurrentOrigin.withOrigin(origin) { rule.applyOrElse(this, identity[BaseType]) } } else { - val afterRuleOnChildren = mapChildren(_.transformUp(rule)) - if (this fastEquals afterRuleOnChildren) { - CurrentOrigin.withOrigin(origin) { - rule.applyOrElse(this, identity[BaseType]) - } - } else { - CurrentOrigin.withOrigin(origin) { - rule.applyOrElse(afterRuleOnChildren, identity[BaseType]) - } + CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(afterRuleOnChildren, identity[BaseType]) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 076dc7e80ab02..0628b37a9675d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -46,7 +46,6 @@ import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection} -import org.apache.spark.sql.catalyst.trees.{Barrier, CurrentBarrier} import org.apache.spark.sql.catalyst.util.{usePrettyExpression, DateTimeUtils} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command._ @@ -204,7 +203,7 @@ class Dataset[T] private[sql]( * custom objects, e.g. collect. Here we resolve and bind the encoder so that we can call its * `fromRow` method later. */ - private lazy val boundEnc = + private val boundEnc = exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer) private implicit def classTag = exprEnc.clsTag @@ -358,9 +357,8 @@ class Dataset[T] private[sql]( // This is declared with parentheses to prevent the Scala compiler from treating // `ds.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. 
def toDF(): DataFrame = { - CurrentBarrier.withBarrier(Barrier(Some(logicalPlan))) { - new Dataset[Row](sparkSession, queryExecution, RowEncoder(schema)) - } + val plan = AnalysisBarrier(logicalPlan) + new Dataset[Row](sparkSession, plan, RowEncoder(schema)) } /** @@ -707,7 +705,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def join(right: Dataset[_]): DataFrame = withPlan { - Join(logicalPlan, right.logicalPlan, joinType = Inner, None) + Join(AnalysisBarrier(logicalPlan), right.logicalPlan, joinType = Inner, None) } /** @@ -790,8 +788,8 @@ class Dataset[T] private[sql]( withPlan { Join( - joined.left, - joined.right, + AnalysisBarrier(joined.left), + AnalysisBarrier(joined.right), UsingJoin(JoinType(joinType), usingColumns), None) } @@ -846,8 +844,9 @@ class Dataset[T] private[sql]( // Trigger analysis so in the case of self-join, the analyzer will clone the plan. // After the cloning, left and right side will have distinct expression ids. val plan = withPlan( - Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr))) - .queryExecution.analyzed.asInstanceOf[Join] + Join(AnalysisBarrier(logicalPlan), right.logicalPlan, JoinType(joinType), + Some(joinExprs.expr))) + .queryExecution.analyzed.asInstanceOf[Join] // If auto self join alias is disabled, return the plan. if (!sparkSession.sessionState.conf.dataFrameSelfJoinAutoResolveAmbiguity) { @@ -855,8 +854,8 @@ class Dataset[T] private[sql]( } // If left/right have no output set intersection, return the plan. 
- val lanalyzed = withPlan(this.logicalPlan).queryExecution.analyzed - val ranalyzed = withPlan(right.logicalPlan).queryExecution.analyzed + val lanalyzed = withPlan(AnalysisBarrier(this.logicalPlan)).queryExecution.analyzed + val ranalyzed = withPlan(AnalysisBarrier(right.logicalPlan)).queryExecution.analyzed if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) { return withPlan(plan) } @@ -888,7 +887,7 @@ class Dataset[T] private[sql]( * @since 2.1.0 */ def crossJoin(right: Dataset[_]): DataFrame = withPlan { - Join(logicalPlan, right.logicalPlan, joinType = Cross, None) + Join(AnalysisBarrier(logicalPlan), right.logicalPlan, joinType = Cross, None) } /** @@ -1139,7 +1138,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def select(cols: Column*): DataFrame = withPlan { - Project(cols.map(_.named), logicalPlan) + Project(cols.map(_.named), AnalysisBarrier(logicalPlan)) } /** @@ -1817,7 +1816,7 @@ class Dataset[T] private[sql]( withPlan { Generate(generator, join = true, outer = false, - qualifier = None, generatorOutput = Nil, logicalPlan) + qualifier = None, generatorOutput = Nil, AnalysisBarrier(logicalPlan)) } } @@ -1858,7 +1857,7 @@ class Dataset[T] private[sql]( withPlan { Generate(generator, join = true, outer = false, - qualifier = None, generatorOutput = Nil, logicalPlan) + qualifier = None, generatorOutput = Nil, AnalysisBarrier(logicalPlan)) } } @@ -2833,27 +2832,21 @@ class Dataset[T] private[sql]( /** A convenient function to wrap a logical plan and produce a DataFrame. */ @inline private def withPlan(logicalPlan: => LogicalPlan): DataFrame = { - CurrentBarrier.withBarrier(Barrier(Some(this.logicalPlan))) { - Dataset.ofRows(sparkSession, logicalPlan) - } + Dataset.ofRows(sparkSession, logicalPlan) } /** A convenient function to wrap a logical plan and produce a Dataset. 
*/ @inline private def withTypedPlan[U : Encoder](logicalPlan: => LogicalPlan): Dataset[U] = { - CurrentBarrier.withBarrier(Barrier(Some(this.logicalPlan))) { - Dataset(sparkSession, logicalPlan) - } + Dataset(sparkSession, logicalPlan) } /** A convenient function to wrap a set based logical plan and produce a Dataset. */ @inline private def withSetOperator[U : Encoder](logicalPlan: => LogicalPlan): Dataset[U] = { - CurrentBarrier.withBarrier(Barrier(Some(this.logicalPlan))) { - if (classTag.runtimeClass.isAssignableFrom(classOf[Row])) { - // Set operators widen types (change the schema), so we cannot reuse the row encoder. - Dataset.ofRows(sparkSession, logicalPlan).asInstanceOf[Dataset[U]] - } else { - Dataset(sparkSession, logicalPlan) - } + if (classTag.runtimeClass.isAssignableFrom(classOf[Row])) { + // Set operators widen types (change the schema), so we cannot reuse the row encoder. + Dataset.ofRows(sparkSession, logicalPlan).asInstanceOf[Dataset[U]] + } else { + Dataset(sparkSession, logicalPlan) } } } From e15b001327786c419e60fa388b1e8ddd42950e31 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 30 Apr 2017 14:50:37 +0000 Subject: [PATCH 03/21] Add test for analysis barrier. 
--- .../spark/sql/catalyst/analysis/AnalysisSuite.scala | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 893bb1b74cea7..6f876654a397e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -441,4 +441,17 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { checkAnalysis(SubqueryAlias("tbl", testRelation).as("tbl2"), testRelation) } + + test("analysis barrier") { + // [[AnalysisBarrier]] will be removed after analysis + checkAnalysis( + Project(Seq(UnresolvedAttribute("tbl.a")), + AnalysisBarrier(SubqueryAlias("tbl", testRelation))), + Project(testRelation.output, testRelation)) + + // Make sure we won't resolve the plans wrapped in an [[AnalysisBarrier]] + val barrier = AnalysisBarrier(Project(Seq(UnresolvedAttribute("tbl.b")), + SubqueryAlias("tbl", testRelation))) + assertAnalysisError(barrier, Seq("cannot resolve '`tbl.b`'")) + } } From a076d83cfc9e87f8234eda639957d663d87eaac4 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 3 May 2017 03:18:55 +0000 Subject: [PATCH 04/21] Let AnalysisBarrier as LeafNode. 
--- .../plans/logical/basicLogicalOperators.scala | 29 +------------------ .../sql/catalyst/analysis/AnalysisSuite.scala | 2 +- 2 files changed, 2 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 851bff859728b..d2d83b65ff21a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -909,34 +909,7 @@ case class Deduplicate( } /** A logical plan for setting a barrier of analysis */ -case class AnalysisBarrier(child: LogicalPlan) extends UnaryNode { +case class AnalysisBarrier(child: LogicalPlan) extends LeafNode { override def output: Seq[Attribute] = child.output - override def analyzed: Boolean = true - - override def transformUp(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { - transformBarrier(rule) - } - - override def transformDown(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { - transformBarrier(rule) - } - - private def transformBarrier(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { - val afterRule = CurrentOrigin.withOrigin(origin) { - rule.applyOrElse(this, identity[LogicalPlan]) - } - val childAfterRule = CurrentOrigin.withOrigin(child.origin) { - rule.applyOrElse(child, identity[LogicalPlan]) - } - - if ((child fastEquals childAfterRule) && (this fastEquals afterRule)) { - this - } else if (this fastEquals afterRule) { - AnalysisBarrier(childAfterRule) - } else { - // The only rule that can change barrier node is the rule to remove it. 
- childAfterRule - } - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 6f876654a397e..1e7f87882061a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -447,7 +447,7 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { checkAnalysis( Project(Seq(UnresolvedAttribute("tbl.a")), AnalysisBarrier(SubqueryAlias("tbl", testRelation))), - Project(testRelation.output, testRelation)) + Project(testRelation.output, SubqueryAlias("tbl", testRelation))) // Make sure we won't resolve the plans wrapped in an [[AnalysisBarrier]] val barrier = AnalysisBarrier(Project(Seq(UnresolvedAttribute("tbl.b")), From b29ded3f806616e43f260db4f133c7bbe3a8fb3b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 5 May 2017 02:23:32 +0000 Subject: [PATCH 05/21] Remove resolveOperators path. 
--- .../sql/catalyst/analysis/Analyzer.scala | 46 ++++++++++--------- .../catalyst/analysis/DecimalPrecision.scala | 2 +- .../ResolveTableValuedFunctions.scala | 2 +- .../sql/catalyst/analysis/TypeCoercion.scala | 4 +- .../spark/sql/catalyst/analysis/view.scala | 2 +- .../sql/catalyst/optimizer/subquery.scala | 2 +- .../catalyst/plans/logical/LogicalPlan.scala | 27 +---------- .../sql/catalyst/plans/LogicalPlanSuite.scala | 26 +++++------ .../scala/org/apache/spark/sql/Dataset.scala | 34 +++++++------- .../sql/execution/datasources/rules.scala | 2 +- .../spark/sql/hive/HiveStrategies.scala | 6 +-- 11 files changed, 63 insertions(+), 90 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e4f14389d6095..faada3cf2841d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -173,7 +173,7 @@ class Analyzer( * Analyze cte definitions and substitute child plan with analyzed cte definitions. */ object CTESubstitution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case With(child, relations) => substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) { case (resolved, (name, relation)) => @@ -201,7 +201,7 @@ class Analyzer( * Substitute child plan with WindowSpecDefinitions. */ object WindowsSubstitution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { // Lookup WindowSpecDefinitions. This rule works with unresolved children. 
case WithWindowDefinition(windowDefinitions, child) => child.transform { @@ -243,7 +243,7 @@ class Analyzer( private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) = exprs.exists(_.find(_.isInstanceOf[UnresolvedAlias]).isDefined) - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) => Aggregate(groups, assignAliases(aggs), child) @@ -615,7 +615,7 @@ class Analyzer( case _ => plan } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved => EliminateSubqueryAliases(lookupTableFromCatalog(u)) match { case v: View => @@ -670,7 +670,9 @@ class Analyzer( * Generate a new logical plan for the right child with different expression IDs * for all conflicting attributes. */ - private def dedupRight (left: LogicalPlan, right: LogicalPlan): LogicalPlan = { + private def dedupRight (left: LogicalPlan, oriRight: LogicalPlan): LogicalPlan = { + // Remove analysis barrier if any. + val right = CleanupBarriers(oriRight) val conflictingAttributes = left.outputSet.intersect(right.outputSet) logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} " + s"between $left and $right") @@ -787,7 +789,7 @@ class Analyzer( } } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case p: LogicalPlan if !p.childrenResolved => p // If the projection list contains Stars, expand it. @@ -962,7 +964,7 @@ class Analyzer( * have no effect on the results. 
*/ object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case p if !p.childrenResolved => p // Replace the index with the related attribute for ORDER BY, // which is a 1-base position of the projection list. @@ -1007,7 +1009,7 @@ class Analyzer( * The HAVING clause could also used a grouping columns that is not presented in the SELECT. */ object ResolveMissingReferences extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions case sa @ Sort(_, _, child: Aggregate) => sa @@ -1131,7 +1133,7 @@ class Analyzer( * Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s. */ object ResolveFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case q: LogicalPlan => q transformExpressions { case u if !u.childrenResolved => u // Skip until children are resolved. @@ -1470,7 +1472,7 @@ class Analyzer( /** * Resolve and rewrite all subqueries in an operator tree.. */ - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { // In case of HAVING (a filter after an aggregate) we use both the aggregate and // its child for resolution. case f @ Filter(_, a: Aggregate) if f.childrenResolved => @@ -1485,7 +1487,7 @@ class Analyzer( * Turns projections that contain aggregate expressions into aggregations. 
*/ object GlobalAggregates extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case Project(projectList, child) if containsAggregates(projectList) => Aggregate(Nil, projectList, child) } @@ -1511,7 +1513,7 @@ class Analyzer( * underlying aggregate operator and then projected away after the original operator. */ object ResolveAggregateFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case filter @ Filter(havingCondition, aggregate @ Aggregate(grouping, originalAggExprs, child)) if aggregate.resolved => @@ -1683,7 +1685,7 @@ class Analyzer( } } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case Project(projectList, _) if projectList.exists(hasNestedGenerator) => val nestedGenerator = projectList.find(hasNestedGenerator).get throw new AnalysisException("Generators are not supported when it's nested in " + @@ -1741,7 +1743,7 @@ class Analyzer( * that wrap the [[Generator]]. */ object ResolveGenerate extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case g: Generate if !g.child.resolved || !g.generator.resolved => g case g: Generate if !g.resolved => g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) @@ -2058,7 +2060,7 @@ class Analyzer( * put them into an inner Project and finally project them away at the outer Project. */ object PullOutNondeterministic extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case p if !p.resolved => p // Skip unresolved nodes. 
case p: Project => p case f: Filter => f @@ -2103,7 +2105,7 @@ class Analyzer( * and we should return null if the input is null. */ object HandleNullInputsForUDF extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case p if !p.resolved => p // Skip unresolved nodes. case p => p transformExpressionsUp { @@ -2168,7 +2170,7 @@ class Analyzer( * Then apply a Project on a normal Join to eliminate natural or using join. */ object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case j @ Join(left, right, UsingJoin(joinType, usingCols), condition) if left.resolved && right.resolved && j.duplicateResolved => commonNaturalJoinProcessing(left, right, joinType, usingCols, None) @@ -2233,7 +2235,7 @@ class Analyzer( * to the given input attributes. */ object ResolveDeserializer extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2319,7 +2321,7 @@ class Analyzer( * constructed is an inner class. 
*/ object ResolveNewInstance extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2353,7 +2355,7 @@ class Analyzer( "type of the field in the target object") } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2407,7 +2409,7 @@ object CleanupAliases extends Rule[LogicalPlan] { case other => trimAliases(other) } - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case Project(projectList, child) => val cleanedProjectList = projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) @@ -2482,7 +2484,7 @@ object TimeWindowing extends Rule[LogicalPlan] { * @return the logical plan that will generate the time windows using the Expand operator, with * the Filter operator for correctness and Project for usability. 
*/ - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case p: LogicalPlan if p.children.size == 1 => val child = p.children.head val windowExpressions = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index 9c38dd2ee4e53..ac72bc4ef4200 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -78,7 +78,7 @@ object DecimalPrecision extends Rule[LogicalPlan] { PromotePrecision(Cast(e, dataType)) } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { // fix decimal precision for expressions case q => q.transformExpressions( decimalAndDecimal.orElse(integralAndDecimalLiteral).orElse(nondecimalAndDecimal)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala index de6de24350f23..6070d9cc25b19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala @@ -103,7 +103,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { }) ) - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) => builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match { case Some(tvf) => diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index e1dd010d37a95..1cc7f32666691 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -206,7 +206,7 @@ object TypeCoercion { * instances higher in the query tree. */ object PropagateTypes extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { // No propagation required for leaf nodes. case q: LogicalPlan if q.children.isEmpty => q @@ -261,7 +261,7 @@ object TypeCoercion { */ object WidenSetOperationTypes extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case p if p.analyzed => p case s @ SetOperation(left, right) if s.childrenResolved && diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala index ea46dd7282401..3bbe41cf8f15e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala @@ -48,7 +48,7 @@ import org.apache.spark.sql.internal.SQLConf * completely resolved during the batch of Resolution. 
*/ case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case v @ View(desc, output, child) if child.resolved && output != child.output => val resolver = conf.resolver val queryColumnNames = desc.viewQueryColumnNames diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 2a3e07aebe709..46d1aac1857d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -236,7 +236,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper /** * Pull up the correlated predicates and rewrite all subqueries in an operator tree.. */ - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case f @ Filter(_, a: Aggregate) => rewriteSubQueries(f, Seq(a, a.child)) // Only a few unary nodes (Project/Filter/Aggregate) can contain subqueries. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 6bdcf490ca5c8..b5002c7ee5eaf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -46,37 +46,12 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { /** Returns true if this subtree contains any streaming data sources. 
*/ def isStreaming: Boolean = children.exists(_.isStreaming == true) - /** - * Returns a copy of this node where `rule` has been recursively applied first to all of its - * children and then itself (post-order). When `rule` does not apply to a given node, it is left - * unchanged. This function is similar to `transformUp`, but skips sub-trees that have already - * been marked as analyzed. - * - * @param rule the function use to transform this nodes children - */ - def resolveOperators(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { - if (!analyzed) { - val afterRuleOnChildren = mapChildren(_.resolveOperators(rule)) - if (this fastEquals afterRuleOnChildren) { - CurrentOrigin.withOrigin(origin) { - rule.applyOrElse(this, identity[LogicalPlan]) - } - } else { - CurrentOrigin.withOrigin(origin) { - rule.applyOrElse(afterRuleOnChildren, identity[LogicalPlan]) - } - } - } else { - this - } - } - /** * Recursively transforms the expressions of a tree, skipping nodes that have already * been analyzed. */ def resolveExpressions(r: PartialFunction[Expression, Expression]): LogicalPlan = { - this resolveOperators { + this transformUp { case p => p.transformExpressions(r) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala index cc86f1f6e2f48..eb5f7d53e2aad 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala @@ -23,8 +23,8 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types.IntegerType /** - * This suite is used to test [[LogicalPlan]]'s `resolveOperators` and make sure it can correctly - * skips sub-trees that have already been marked as analyzed. 
+ * This suite is used to test [[LogicalPlan]]'s `transformUp` plus analysis barrier and make sure + * it can correctly skips sub-trees that have already been marked as analyzed. */ class LogicalPlanSuite extends SparkFunSuite { private var invocationCount = 0 @@ -36,37 +36,35 @@ class LogicalPlanSuite extends SparkFunSuite { private val testRelation = LocalRelation() - test("resolveOperator runs on operators") { + test("transformUp runs on operators") { invocationCount = 0 val plan = Project(Nil, testRelation) - plan resolveOperators function + plan transformUp function assert(invocationCount === 1) } - test("resolveOperator runs on operators recursively") { + test("transformUp runs on operators recursively") { invocationCount = 0 val plan = Project(Nil, Project(Nil, testRelation)) - plan resolveOperators function + plan transformUp function assert(invocationCount === 2) } - test("resolveOperator skips all ready resolved plans") { + test("transformUp skips all ready resolved plans wrapped in analysis barrier") { invocationCount = 0 - val plan = Project(Nil, Project(Nil, testRelation)) - plan.foreach(_.setAnalyzed()) - plan resolveOperators function + val plan = AnalysisBarrier(Project(Nil, Project(Nil, testRelation))) + plan transformUp function assert(invocationCount === 0) } - test("resolveOperator skips partially resolved plans") { + test("transformUp skips partially resolved plans wrapped in analysis barrier") { invocationCount = 0 - val plan1 = Project(Nil, testRelation) + val plan1 = AnalysisBarrier(Project(Nil, testRelation)) val plan2 = Project(Nil, plan1) - plan1.foreach(_.setAnalyzed()) - plan2 resolveOperators function + plan2 transformUp function assert(invocationCount === 1) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 0628b37a9675d..22349818a0e63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -178,7 +178,7 @@ class Dataset[T] private[sql]( @transient private[sql] val logicalPlan: LogicalPlan = { // For various commands (like DDL) and queries with side effects, we force query execution // to happen right away to let these side effects take place eagerly. - queryExecution.analyzed match { + val analyzed = queryExecution.analyzed match { case c: Command => LocalRelation(c.output, queryExecution.executedPlan.executeCollect()) case u @ Union(children) if children.forall(_.isInstanceOf[Command]) => @@ -186,6 +186,8 @@ class Dataset[T] private[sql]( case _ => queryExecution.analyzed } + // Wrap analyzed logical plan with an analysis barrier so we won't traverse/resolve it again. + AnalysisBarrier(analyzed) } /** @@ -356,10 +358,7 @@ class Dataset[T] private[sql]( */ // This is declared with parentheses to prevent the Scala compiler from treating // `ds.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. 
- def toDF(): DataFrame = { - val plan = AnalysisBarrier(logicalPlan) - new Dataset[Row](sparkSession, plan, RowEncoder(schema)) - } + def toDF(): DataFrame = new Dataset[Row](sparkSession, queryExecution, RowEncoder(schema)) /** * :: Experimental :: @@ -474,7 +473,7 @@ class Dataset[T] private[sql]( * @group basic * @since 1.6.0 */ - def isLocal: Boolean = logicalPlan.isInstanceOf[LocalRelation] + def isLocal: Boolean = logicalPlan.asInstanceOf[AnalysisBarrier].child.isInstanceOf[LocalRelation] /** * Returns true if this Dataset contains one or more sources that continuously @@ -705,7 +704,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def join(right: Dataset[_]): DataFrame = withPlan { - Join(AnalysisBarrier(logicalPlan), right.logicalPlan, joinType = Inner, None) + Join(logicalPlan, right.logicalPlan, joinType = Inner, None) } /** @@ -788,8 +787,8 @@ class Dataset[T] private[sql]( withPlan { Join( - AnalysisBarrier(joined.left), - AnalysisBarrier(joined.right), + joined.left, + joined.right, UsingJoin(JoinType(joinType), usingColumns), None) } @@ -844,9 +843,8 @@ class Dataset[T] private[sql]( // Trigger analysis so in the case of self-join, the analyzer will clone the plan. // After the cloning, left and right side will have distinct expression ids. val plan = withPlan( - Join(AnalysisBarrier(logicalPlan), right.logicalPlan, JoinType(joinType), - Some(joinExprs.expr))) - .queryExecution.analyzed.asInstanceOf[Join] + Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr))) + .queryExecution.analyzed.asInstanceOf[Join] // If auto self join alias is disabled, return the plan. if (!sparkSession.sessionState.conf.dataFrameSelfJoinAutoResolveAmbiguity) { @@ -854,8 +852,8 @@ class Dataset[T] private[sql]( } // If left/right have no output set intersection, return the plan. 
- val lanalyzed = withPlan(AnalysisBarrier(this.logicalPlan)).queryExecution.analyzed - val ranalyzed = withPlan(AnalysisBarrier(right.logicalPlan)).queryExecution.analyzed + val lanalyzed = withPlan(this.logicalPlan).queryExecution.analyzed + val ranalyzed = withPlan(right.logicalPlan).queryExecution.analyzed if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) { return withPlan(plan) } @@ -887,7 +885,7 @@ class Dataset[T] private[sql]( * @since 2.1.0 */ def crossJoin(right: Dataset[_]): DataFrame = withPlan { - Join(AnalysisBarrier(logicalPlan), right.logicalPlan, joinType = Cross, None) + Join(logicalPlan, right.logicalPlan, joinType = Cross, None) } /** @@ -1138,7 +1136,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def select(cols: Column*): DataFrame = withPlan { - Project(cols.map(_.named), AnalysisBarrier(logicalPlan)) + Project(cols.map(_.named), logicalPlan) } /** @@ -1816,7 +1814,7 @@ class Dataset[T] private[sql]( withPlan { Generate(generator, join = true, outer = false, - qualifier = None, generatorOutput = Nil, AnalysisBarrier(logicalPlan)) + qualifier = None, generatorOutput = Nil, logicalPlan) } } @@ -1857,7 +1855,7 @@ class Dataset[T] private[sql]( withPlan { Generate(generator, join = true, outer = false, - qualifier = None, generatorOutput = Nil, AnalysisBarrier(logicalPlan)) + qualifier = None, generatorOutput = Nil, logicalPlan) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 3f4a78580f1eb..5f65898f5312e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -38,7 +38,7 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] { sparkSession.sessionState.conf.runSQLonFile && u.tableIdentifier.database.isDefined } - def apply(plan: 
LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case u: UnresolvedRelation if maybeSQLFile(u) => try { val dataSource = DataSource( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 09a5eda6e543f..b40af76755988 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -88,7 +88,7 @@ class ResolveHiveSerdeTable(session: SparkSession) extends Rule[LogicalPlan] { } } - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case c @ CreateTable(t, _, query) if DDLUtils.isHiveTable(t) => // Finds the database name if the name does not exist. val dbName = t.identifier.database.getOrElse(session.catalog.currentDatabase) @@ -115,7 +115,7 @@ class ResolveHiveSerdeTable(session: SparkSession) extends Rule[LogicalPlan] { } class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case relation: CatalogRelation if DDLUtils.isHiveTable(relation.tableMeta) && relation.tableMeta.stats.isEmpty => val table = relation.tableMeta @@ -159,7 +159,7 @@ class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] { * `PreprocessTableInsertion`. 
*/ object HiveAnalysis extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case InsertIntoTable(relation: CatalogRelation, partSpec, query, overwrite, ifNotExists) if DDLUtils.isHiveTable(relation.tableMeta) => InsertIntoHiveTable(relation.tableMeta, partSpec, query, overwrite, ifNotExists) From a855182d8f5037daab718820775cbcf8add01546 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 5 May 2017 02:44:44 +0000 Subject: [PATCH 06/21] Solving merging issue. --- .../scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e799ae458cadc..0e4323da07684 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1006,7 +1006,7 @@ class Analyzer( */ object ResolveAggAliasInGroupBy extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case agg @ Aggregate(groups, aggs, child) if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) && groups.exists(_.isInstanceOf[UnresolvedAttribute]) => From 4ff9610133fca947fab23af6ea67e6c7af50e8d2 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 5 May 2017 03:37:36 +0000 Subject: [PATCH 07/21] Do not change exposed logicalPlan. 
--- .../scala/org/apache/spark/sql/Dataset.scala | 95 ++++++++++--------- 1 file changed, 48 insertions(+), 47 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index d41c5c9b6ca06..13e087414dfb5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -178,7 +178,7 @@ class Dataset[T] private[sql]( @transient private[sql] val logicalPlan: LogicalPlan = { // For various commands (like DDL) and queries with side effects, we force query execution // to happen right away to let these side effects take place eagerly. - val analyzed = queryExecution.analyzed match { + queryExecution.analyzed match { case c: Command => LocalRelation(c.output, queryExecution.executedPlan.executeCollect()) case u @ Union(children) if children.forall(_.isInstanceOf[Command]) => @@ -186,10 +186,11 @@ class Dataset[T] private[sql]( case _ => queryExecution.analyzed } - // Wrap analyzed logical plan with an analysis barrier so we won't traverse/resolve it again. - AnalysisBarrier(analyzed) } + // Wrap analyzed logical plan with an analysis barrier so we won't traverse/resolve it again. + @transient private val planBarrier: AnalysisBarrier = AnalysisBarrier(logicalPlan) + /** * Currently [[ExpressionEncoder]] is the only implementation of [[Encoder]], here we turn the * passed in encoder to [[ExpressionEncoder]] explicitly, and mark it implicit so that we can use @@ -415,7 +416,7 @@ class Dataset[T] private[sql]( */ @Experimental @InterfaceStability.Evolving - def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, logicalPlan) + def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, planBarrier) /** * Converts this strongly typed collection of data to generic `DataFrame` with columns renamed. 
@@ -509,7 +510,7 @@ class Dataset[T] private[sql]( * @group basic * @since 1.6.0 */ - def isLocal: Boolean = logicalPlan.asInstanceOf[AnalysisBarrier].child.isInstanceOf[LocalRelation] + def isLocal: Boolean = logicalPlan.isInstanceOf[LocalRelation] /** * Returns true if this Dataset contains one or more sources that continuously @@ -617,7 +618,7 @@ class Dataset[T] private[sql]( .getOrElse(throw new AnalysisException(s"Unable to parse time delay '$delayThreshold'")) require(parsedDelay.milliseconds >= 0 && parsedDelay.months >= 0, s"delay threshold ($delayThreshold) should not be negative.") - EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan) + EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, planBarrier) } /** @@ -791,7 +792,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def join(right: Dataset[_]): DataFrame = withPlan { - Join(logicalPlan, right.logicalPlan, joinType = Inner, None) + Join(planBarrier, right.planBarrier, joinType = Inner, None) } /** @@ -869,7 +870,7 @@ class Dataset[T] private[sql]( // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right // by creating a new instance for one of the branch. val joined = sparkSession.sessionState.executePlan( - Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None)) + Join(planBarrier, right.planBarrier, joinType = JoinType(joinType), None)) .analyzed.asInstanceOf[Join] withPlan { @@ -930,7 +931,7 @@ class Dataset[T] private[sql]( // Trigger analysis so in the case of self-join, the analyzer will clone the plan. // After the cloning, left and right side will have distinct expression ids. val plan = withPlan( - Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr))) + Join(planBarrier, right.planBarrier, JoinType(joinType), Some(joinExprs.expr))) .queryExecution.analyzed.asInstanceOf[Join] // If auto self join alias is disabled, return the plan. 
@@ -939,8 +940,8 @@ class Dataset[T] private[sql]( } // If left/right have no output set intersection, return the plan. - val lanalyzed = withPlan(this.logicalPlan).queryExecution.analyzed - val ranalyzed = withPlan(right.logicalPlan).queryExecution.analyzed + val lanalyzed = withPlan(this.planBarrier).queryExecution.analyzed + val ranalyzed = withPlan(right.planBarrier).queryExecution.analyzed if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) { return withPlan(plan) } @@ -972,7 +973,7 @@ class Dataset[T] private[sql]( * @since 2.1.0 */ def crossJoin(right: Dataset[_]): DataFrame = withPlan { - Join(logicalPlan, right.logicalPlan, joinType = Cross, None) + Join(planBarrier, right.planBarrier, joinType = Cross, None) } /** @@ -1004,8 +1005,8 @@ class Dataset[T] private[sql]( // etc. val joined = sparkSession.sessionState.executePlan( Join( - this.logicalPlan, - other.logicalPlan, + this.planBarrier, + other.planBarrier, JoinType(joinType), Some(condition.expr))).analyzed.asInstanceOf[Join] @@ -1175,7 +1176,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def hint(name: String, parameters: String*): Dataset[T] = withTypedPlan { - Hint(name, parameters, logicalPlan) + Hint(name, parameters, planBarrier) } /** @@ -1201,7 +1202,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def as(alias: String): Dataset[T] = withTypedPlan { - SubqueryAlias(alias, logicalPlan) + SubqueryAlias(alias, planBarrier) } /** @@ -1239,7 +1240,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def select(cols: Column*): DataFrame = withPlan { - Project(cols.map(_.named), logicalPlan) + Project(cols.map(_.named), planBarrier) } /** @@ -1294,8 +1295,8 @@ class Dataset[T] private[sql]( @InterfaceStability.Evolving def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = { implicit val encoder = c1.encoder - val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, - logicalPlan) + val project = Project(c1.withInputType(exprEnc, 
planBarrier.output).named :: Nil, + planBarrier) if (encoder.flat) { new Dataset[U1](sparkSession, project, encoder) @@ -1313,8 +1314,8 @@ class Dataset[T] private[sql]( protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val encoders = columns.map(_.encoder) val namedColumns = - columns.map(_.withInputType(exprEnc, logicalPlan.output).named) - val execution = new QueryExecution(sparkSession, Project(namedColumns, logicalPlan)) + columns.map(_.withInputType(exprEnc, planBarrier.output).named) + val execution = new QueryExecution(sparkSession, Project(namedColumns, planBarrier)) new Dataset(sparkSession, execution, ExpressionEncoder.tuple(encoders)) } @@ -1390,7 +1391,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def filter(condition: Column): Dataset[T] = withTypedPlan { - Filter(condition.expr, logicalPlan) + Filter(condition.expr, planBarrier) } /** @@ -1567,7 +1568,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { - val inputPlan = logicalPlan + val inputPlan = planBarrier val withGroupingKey = AppendColumns(func, inputPlan) val executed = sparkSession.sessionState.executePlan(withGroupingKey) @@ -1713,7 +1714,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def limit(n: Int): Dataset[T] = withTypedPlan { - Limit(Literal(n), logicalPlan) + Limit(Literal(n), planBarrier) } /** @@ -1742,7 +1743,7 @@ class Dataset[T] private[sql]( def union(other: Dataset[T]): Dataset[T] = withSetOperator { // This breaks caching, but it's usually ok because it addresses a very specific use case: // using union to union many files or partitions. 
- CombineUnions(Union(logicalPlan, other.logicalPlan)) + CombineUnions(Union(planBarrier, other.planBarrier)) } /** @@ -1756,7 +1757,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def intersect(other: Dataset[T]): Dataset[T] = withSetOperator { - Intersect(logicalPlan, other.logicalPlan) + Intersect(planBarrier, other.planBarrier) } /** @@ -1770,7 +1771,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def except(other: Dataset[T]): Dataset[T] = withSetOperator { - Except(logicalPlan, other.logicalPlan) + Except(planBarrier, other.planBarrier) } /** @@ -1791,7 +1792,7 @@ class Dataset[T] private[sql]( s"Fraction must be nonnegative, but got ${fraction}") withTypedPlan { - Sample(0.0, fraction, withReplacement, seed, logicalPlan)() + Sample(0.0, fraction, withReplacement, seed, planBarrier)() } } @@ -1833,15 +1834,15 @@ class Dataset[T] private[sql]( // overlapping splits. To prevent this, we explicitly sort each input partition to make the // ordering deterministic. Note that MapTypes cannot be sorted and are explicitly pruned out // from the sort order. 
- val sortOrder = logicalPlan.output + val sortOrder = planBarrier.output .filter(attr => RowOrdering.isOrderable(attr.dataType)) .map(SortOrder(_, Ascending)) val plan = if (sortOrder.nonEmpty) { - Sort(sortOrder, global = false, logicalPlan) + Sort(sortOrder, global = false, planBarrier) } else { // SPARK-12662: If sort order is empty, we materialize the dataset to guarantee determinism cache() - logicalPlan + planBarrier } val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) @@ -1925,7 +1926,7 @@ class Dataset[T] private[sql]( withPlan { Generate(generator, join = true, outer = false, - qualifier = None, generatorOutput = Nil, logicalPlan) + qualifier = None, generatorOutput = Nil, planBarrier) } } @@ -1966,7 +1967,7 @@ class Dataset[T] private[sql]( withPlan { Generate(generator, join = true, outer = false, - qualifier = None, generatorOutput = Nil, logicalPlan) + qualifier = None, generatorOutput = Nil, planBarrier) } } @@ -2129,7 +2130,7 @@ class Dataset[T] private[sql]( } cols } - Deduplicate(groupCols, logicalPlan, isStreaming) + Deduplicate(groupCols, planBarrier, isStreaming) } /** @@ -2278,7 +2279,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def filter(func: T => Boolean): Dataset[T] = { - withTypedPlan(TypedFilter(func, logicalPlan)) + withTypedPlan(TypedFilter(func, planBarrier)) } /** @@ -2292,7 +2293,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def filter(func: FilterFunction[T]): Dataset[T] = { - withTypedPlan(TypedFilter(func, logicalPlan)) + withTypedPlan(TypedFilter(func, planBarrier)) } /** @@ -2306,7 +2307,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def map[U : Encoder](func: T => U): Dataset[U] = withTypedPlan { - MapElements[T, U](func, logicalPlan) + MapElements[T, U](func, planBarrier) } /** @@ -2321,7 +2322,7 @@ class Dataset[T] private[sql]( @InterfaceStability.Evolving def map[U](func: 
MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { implicit val uEnc = encoder - withTypedPlan(MapElements[T, U](func, logicalPlan)) + withTypedPlan(MapElements[T, U](func, planBarrier)) } /** @@ -2337,7 +2338,7 @@ class Dataset[T] private[sql]( def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { new Dataset[U]( sparkSession, - MapPartitions[T, U](func, logicalPlan), + MapPartitions[T, U](func, planBarrier), implicitly[Encoder[U]]) } @@ -2368,7 +2369,7 @@ class Dataset[T] private[sql]( val rowEncoder = encoder.asInstanceOf[ExpressionEncoder[Row]] Dataset.ofRows( sparkSession, - MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder, logicalPlan)) + MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder, planBarrier)) } /** @@ -2523,7 +2524,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def repartition(numPartitions: Int): Dataset[T] = withTypedPlan { - Repartition(numPartitions, shuffle = true, logicalPlan) + Repartition(numPartitions, shuffle = true, planBarrier) } /** @@ -2537,7 +2538,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = withTypedPlan { - RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions) + RepartitionByExpression(partitionExprs.map(_.expr), planBarrier, numPartitions) } /** @@ -2553,7 +2554,7 @@ class Dataset[T] private[sql]( @scala.annotation.varargs def repartition(partitionExprs: Column*): Dataset[T] = withTypedPlan { RepartitionByExpression( - partitionExprs.map(_.expr), logicalPlan, sparkSession.sessionState.conf.numShufflePartitions) + partitionExprs.map(_.expr), planBarrier, sparkSession.sessionState.conf.numShufflePartitions) } /** @@ -2574,7 +2575,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def coalesce(numPartitions: Int): Dataset[T] = withTypedPlan { - Repartition(numPartitions, shuffle = false, logicalPlan) + Repartition(numPartitions, 
shuffle = false, planBarrier) } /** @@ -2663,7 +2664,7 @@ class Dataset[T] private[sql]( */ lazy val rdd: RDD[T] = { val objectType = exprEnc.deserializer.dataType - val deserialized = CatalystSerde.deserialize[T](logicalPlan) + val deserialized = CatalystSerde.deserialize[T](planBarrier) sparkSession.sessionState.executePlan(deserialized).toRdd.mapPartitions { rows => rows.map(_.get(0, objectType).asInstanceOf[T]) } @@ -2762,7 +2763,7 @@ class Dataset[T] private[sql]( comment = None, properties = Map.empty, originalText = None, - child = logicalPlan, + child = planBarrier, allowExisting = false, replace = replace, viewType = viewType) @@ -2935,7 +2936,7 @@ class Dataset[T] private[sql]( } } withTypedPlan { - Sort(sortOrder, global = global, logicalPlan) + Sort(sortOrder, global = global, planBarrier) } } From d0a94f417bbe22f081772b2518315b367093b81d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 6 May 2017 08:13:37 +0000 Subject: [PATCH 08/21] Fix test. --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 8 ++++++++ .../src/main/scala/org/apache/spark/sql/Dataset.scala | 4 ++-- .../org/apache/spark/sql/execution/PlannerSuite.scala | 2 +- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 0e4323da07684..1fc3ef162a3c1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1032,6 +1032,7 @@ class Analyzer( object ResolveMissingReferences extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { // Skip sort with aggregate. 
This will be handled in ResolveAggregateFunctions + case sa @ Sort(_, _, AnalysisBarrier(child: Aggregate)) => sa case sa @ Sort(_, _, child: Aggregate) => sa case s @ Sort(order, _, child) if child.resolved => @@ -1107,6 +1108,8 @@ class Analyzer( throw new AnalysisException(s"Can't add $missingAttrs to $d") case u: UnaryNode => u.withNewChildren(addMissingAttr(u.child, missingAttrs) :: Nil) + case AnalysisBarrier(subPlan) => + AnalysisBarrier(addMissingAttr(subPlan, missingAttrs)) case other => throw new AnalysisException(s"Can't add $missingAttrs to $other") } @@ -1125,6 +1128,7 @@ class Analyzer( plan match { case u: UnaryNode if !u.isInstanceOf[SubqueryAlias] => resolveExpressionRecursively(resolved, u.child) + case AnalysisBarrier(subPlan) => resolveExpressionRecursively(resolved, subPlan) case other => resolved } } @@ -1535,6 +1539,8 @@ class Analyzer( */ object ResolveAggregateFunctions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + case filter @ Filter(havingCondition, AnalysisBarrier(aggregate: Aggregate)) => + apply(Filter(havingCondition, aggregate)).mapChildren(AnalysisBarrier(_)) case filter @ Filter(havingCondition, aggregate @ Aggregate(grouping, originalAggExprs, child)) if aggregate.resolved => @@ -1594,6 +1600,8 @@ class Analyzer( case ae: AnalysisException => filter } + case sort @ Sort(sortOrder, global, AnalysisBarrier(aggregate: Aggregate)) => + apply(Sort(sortOrder, global, aggregate)).mapChildren(AnalysisBarrier(_)) case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved => // Try resolving the ordering as though it is in the aggregate clause. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 13e087414dfb5..ab4ffe84f1ef6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -189,7 +189,7 @@ class Dataset[T] private[sql]( } // Wrap analyzed logical plan with an analysis barrier so we won't traverse/resolve it again. - @transient private val planBarrier: AnalysisBarrier = AnalysisBarrier(logicalPlan) + @transient private val planBarrier: LogicalPlan = AnalysisBarrier(logicalPlan) /** * Currently [[ExpressionEncoder]] is the only implementation of [[Encoder]], here we turn the @@ -1743,7 +1743,7 @@ class Dataset[T] private[sql]( def union(other: Dataset[T]): Dataset[T] = withSetOperator { // This breaks caching, but it's usually ok because it addresses a very specific use case: // using union to union many files or partitions. - CombineUnions(Union(planBarrier, other.planBarrier)) + CombineUnions(Union(logicalPlan, other.logicalPlan)).mapChildren(AnalysisBarrier(_)) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 4d155d538d637..d02c8ffe33f0f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -241,7 +241,7 @@ class PlannerSuite extends SharedSQLContext { test("collapse adjacent repartitions") { val doubleRepartitioned = testData.repartition(10).repartition(20).coalesce(5) def countRepartitions(plan: LogicalPlan): Int = plan.collect { case r: Repartition => r }.length - assert(countRepartitions(doubleRepartitioned.queryExecution.logical) === 3) + assert(countRepartitions(doubleRepartitioned.queryExecution.analyzed) === 3) assert(countRepartitions(doubleRepartitioned.queryExecution.optimizedPlan) 
=== 2) doubleRepartitioned.queryExecution.optimizedPlan match { case Repartition (numPartitions, shuffle, Repartition(_, shuffleChild, _)) => From 02e11f9ead8ce3b6cafbcd59c042d12daaabe78a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 9 May 2017 09:48:28 +0000 Subject: [PATCH 09/21] Address comments. --- .../sql/catalyst/analysis/Analyzer.scala | 4 +- .../scala/org/apache/spark/sql/Dataset.scala | 87 ++++++++++--------- 2 files changed, 46 insertions(+), 45 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 1fc3ef162a3c1..e881b5117cff2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -716,7 +716,7 @@ class Analyzer( * that this rule cannot handle. When that is the case, there must be another rule * that resolves these conflicts. Otherwise, the analysis will fail. */ - right + oriRight case Some((oldRelation, newRelation)) => val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) val newRight = right transformUp { @@ -729,7 +729,7 @@ class Analyzer( s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites)) } } - newRight + AnalysisBarrier(newRight) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index ab4ffe84f1ef6..c1193c1c935ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -189,7 +189,7 @@ class Dataset[T] private[sql]( } // Wrap analyzed logical plan with an analysis barrier so we won't traverse/resolve it again. 
- @transient private val planBarrier: LogicalPlan = AnalysisBarrier(logicalPlan) + @transient private val planWithBarrier: LogicalPlan = AnalysisBarrier(logicalPlan) /** * Currently [[ExpressionEncoder]] is the only implementation of [[Encoder]], here we turn the @@ -416,7 +416,7 @@ class Dataset[T] private[sql]( */ @Experimental @InterfaceStability.Evolving - def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, planBarrier) + def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, planWithBarrier) /** * Converts this strongly typed collection of data to generic `DataFrame` with columns renamed. @@ -618,7 +618,7 @@ class Dataset[T] private[sql]( .getOrElse(throw new AnalysisException(s"Unable to parse time delay '$delayThreshold'")) require(parsedDelay.milliseconds >= 0 && parsedDelay.months >= 0, s"delay threshold ($delayThreshold) should not be negative.") - EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, planBarrier) + EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, planWithBarrier) } /** @@ -792,7 +792,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def join(right: Dataset[_]): DataFrame = withPlan { - Join(planBarrier, right.planBarrier, joinType = Inner, None) + Join(planWithBarrier, right.planWithBarrier, joinType = Inner, None) } /** @@ -870,7 +870,7 @@ class Dataset[T] private[sql]( // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right // by creating a new instance for one of the branch. val joined = sparkSession.sessionState.executePlan( - Join(planBarrier, right.planBarrier, joinType = JoinType(joinType), None)) + Join(planWithBarrier, right.planWithBarrier, joinType = JoinType(joinType), None)) .analyzed.asInstanceOf[Join] withPlan { @@ -931,7 +931,7 @@ class Dataset[T] private[sql]( // Trigger analysis so in the case of self-join, the analyzer will clone the plan. // After the cloning, left and right side will have distinct expression ids. 
val plan = withPlan( - Join(planBarrier, right.planBarrier, JoinType(joinType), Some(joinExprs.expr))) + Join(planWithBarrier, right.planWithBarrier, JoinType(joinType), Some(joinExprs.expr))) .queryExecution.analyzed.asInstanceOf[Join] // If auto self join alias is disabled, return the plan. @@ -940,8 +940,8 @@ class Dataset[T] private[sql]( } // If left/right have no output set intersection, return the plan. - val lanalyzed = withPlan(this.planBarrier).queryExecution.analyzed - val ranalyzed = withPlan(right.planBarrier).queryExecution.analyzed + val lanalyzed = withPlan(this.planWithBarrier).queryExecution.analyzed + val ranalyzed = withPlan(right.planWithBarrier).queryExecution.analyzed if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) { return withPlan(plan) } @@ -973,7 +973,7 @@ class Dataset[T] private[sql]( * @since 2.1.0 */ def crossJoin(right: Dataset[_]): DataFrame = withPlan { - Join(planBarrier, right.planBarrier, joinType = Cross, None) + Join(planWithBarrier, right.planWithBarrier, joinType = Cross, None) } /** @@ -1005,8 +1005,8 @@ class Dataset[T] private[sql]( // etc. 
val joined = sparkSession.sessionState.executePlan( Join( - this.planBarrier, - other.planBarrier, + this.planWithBarrier, + other.planWithBarrier, JoinType(joinType), Some(condition.expr))).analyzed.asInstanceOf[Join] @@ -1176,7 +1176,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def hint(name: String, parameters: String*): Dataset[T] = withTypedPlan { - Hint(name, parameters, planBarrier) + Hint(name, parameters, planWithBarrier) } /** @@ -1202,7 +1202,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def as(alias: String): Dataset[T] = withTypedPlan { - SubqueryAlias(alias, planBarrier) + SubqueryAlias(alias, planWithBarrier) } /** @@ -1240,7 +1240,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def select(cols: Column*): DataFrame = withPlan { - Project(cols.map(_.named), planBarrier) + Project(cols.map(_.named), planWithBarrier) } /** @@ -1295,8 +1295,8 @@ class Dataset[T] private[sql]( @InterfaceStability.Evolving def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = { implicit val encoder = c1.encoder - val project = Project(c1.withInputType(exprEnc, planBarrier.output).named :: Nil, - planBarrier) + val project = Project(c1.withInputType(exprEnc, planWithBarrier.output).named :: Nil, + planWithBarrier) if (encoder.flat) { new Dataset[U1](sparkSession, project, encoder) @@ -1314,8 +1314,8 @@ class Dataset[T] private[sql]( protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val encoders = columns.map(_.encoder) val namedColumns = - columns.map(_.withInputType(exprEnc, planBarrier.output).named) - val execution = new QueryExecution(sparkSession, Project(namedColumns, planBarrier)) + columns.map(_.withInputType(exprEnc, planWithBarrier.output).named) + val execution = new QueryExecution(sparkSession, Project(namedColumns, planWithBarrier)) new Dataset(sparkSession, execution, ExpressionEncoder.tuple(encoders)) } @@ -1391,7 +1391,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def 
filter(condition: Column): Dataset[T] = withTypedPlan { - Filter(condition.expr, planBarrier) + Filter(condition.expr, planWithBarrier) } /** @@ -1568,7 +1568,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { - val inputPlan = planBarrier + val inputPlan = planWithBarrier val withGroupingKey = AppendColumns(func, inputPlan) val executed = sparkSession.sessionState.executePlan(withGroupingKey) @@ -1714,7 +1714,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def limit(n: Int): Dataset[T] = withTypedPlan { - Limit(Literal(n), planBarrier) + Limit(Literal(n), planWithBarrier) } /** @@ -1757,7 +1757,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def intersect(other: Dataset[T]): Dataset[T] = withSetOperator { - Intersect(planBarrier, other.planBarrier) + Intersect(planWithBarrier, other.planWithBarrier) } /** @@ -1771,7 +1771,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def except(other: Dataset[T]): Dataset[T] = withSetOperator { - Except(planBarrier, other.planBarrier) + Except(planWithBarrier, other.planWithBarrier) } /** @@ -1792,7 +1792,7 @@ class Dataset[T] private[sql]( s"Fraction must be nonnegative, but got ${fraction}") withTypedPlan { - Sample(0.0, fraction, withReplacement, seed, planBarrier)() + Sample(0.0, fraction, withReplacement, seed, planWithBarrier)() } } @@ -1834,15 +1834,15 @@ class Dataset[T] private[sql]( // overlapping splits. To prevent this, we explicitly sort each input partition to make the // ordering deterministic. Note that MapTypes cannot be sorted and are explicitly pruned out // from the sort order. 
- val sortOrder = planBarrier.output + val sortOrder = planWithBarrier.output .filter(attr => RowOrdering.isOrderable(attr.dataType)) .map(SortOrder(_, Ascending)) val plan = if (sortOrder.nonEmpty) { - Sort(sortOrder, global = false, planBarrier) + Sort(sortOrder, global = false, planWithBarrier) } else { // SPARK-12662: If sort order is empty, we materialize the dataset to guarantee determinism cache() - planBarrier + planWithBarrier } val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) @@ -1926,7 +1926,7 @@ class Dataset[T] private[sql]( withPlan { Generate(generator, join = true, outer = false, - qualifier = None, generatorOutput = Nil, planBarrier) + qualifier = None, generatorOutput = Nil, planWithBarrier) } } @@ -1967,7 +1967,7 @@ class Dataset[T] private[sql]( withPlan { Generate(generator, join = true, outer = false, - qualifier = None, generatorOutput = Nil, planBarrier) + qualifier = None, generatorOutput = Nil, planWithBarrier) } } @@ -2130,7 +2130,7 @@ class Dataset[T] private[sql]( } cols } - Deduplicate(groupCols, planBarrier, isStreaming) + Deduplicate(groupCols, planWithBarrier, isStreaming) } /** @@ -2279,7 +2279,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def filter(func: T => Boolean): Dataset[T] = { - withTypedPlan(TypedFilter(func, planBarrier)) + withTypedPlan(TypedFilter(func, planWithBarrier)) } /** @@ -2293,7 +2293,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def filter(func: FilterFunction[T]): Dataset[T] = { - withTypedPlan(TypedFilter(func, planBarrier)) + withTypedPlan(TypedFilter(func, planWithBarrier)) } /** @@ -2307,7 +2307,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def map[U : Encoder](func: T => U): Dataset[U] = withTypedPlan { - MapElements[T, U](func, planBarrier) + MapElements[T, U](func, planWithBarrier) } /** @@ -2322,7 +2322,7 @@ class Dataset[T] private[sql]( 
@InterfaceStability.Evolving def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { implicit val uEnc = encoder - withTypedPlan(MapElements[T, U](func, planBarrier)) + withTypedPlan(MapElements[T, U](func, planWithBarrier)) } /** @@ -2338,7 +2338,7 @@ class Dataset[T] private[sql]( def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { new Dataset[U]( sparkSession, - MapPartitions[T, U](func, planBarrier), + MapPartitions[T, U](func, planWithBarrier), implicitly[Encoder[U]]) } @@ -2369,7 +2369,7 @@ class Dataset[T] private[sql]( val rowEncoder = encoder.asInstanceOf[ExpressionEncoder[Row]] Dataset.ofRows( sparkSession, - MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder, planBarrier)) + MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder, planWithBarrier)) } /** @@ -2524,7 +2524,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def repartition(numPartitions: Int): Dataset[T] = withTypedPlan { - Repartition(numPartitions, shuffle = true, planBarrier) + Repartition(numPartitions, shuffle = true, planWithBarrier) } /** @@ -2538,7 +2538,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = withTypedPlan { - RepartitionByExpression(partitionExprs.map(_.expr), planBarrier, numPartitions) + RepartitionByExpression(partitionExprs.map(_.expr), planWithBarrier, numPartitions) } /** @@ -2554,7 +2554,8 @@ class Dataset[T] private[sql]( @scala.annotation.varargs def repartition(partitionExprs: Column*): Dataset[T] = withTypedPlan { RepartitionByExpression( - partitionExprs.map(_.expr), planBarrier, sparkSession.sessionState.conf.numShufflePartitions) + partitionExprs.map(_.expr), planWithBarrier, + sparkSession.sessionState.conf.numShufflePartitions) } /** @@ -2575,7 +2576,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def coalesce(numPartitions: Int): Dataset[T] = withTypedPlan { - 
Repartition(numPartitions, shuffle = false, planBarrier) + Repartition(numPartitions, shuffle = false, planWithBarrier) } /** @@ -2664,7 +2665,7 @@ class Dataset[T] private[sql]( */ lazy val rdd: RDD[T] = { val objectType = exprEnc.deserializer.dataType - val deserialized = CatalystSerde.deserialize[T](planBarrier) + val deserialized = CatalystSerde.deserialize[T](planWithBarrier) sparkSession.sessionState.executePlan(deserialized).toRdd.mapPartitions { rows => rows.map(_.get(0, objectType).asInstanceOf[T]) } @@ -2763,7 +2764,7 @@ class Dataset[T] private[sql]( comment = None, properties = Map.empty, originalText = None, - child = planBarrier, + child = planWithBarrier, allowExisting = false, replace = replace, viewType = viewType) @@ -2936,7 +2937,7 @@ class Dataset[T] private[sql]( } } withTypedPlan { - Sort(sortOrder, global = global, planBarrier) + Sort(sortOrder, global = global, planWithBarrier) } } From c313e353104fc93ba72a2152a7044a6ea8c06311 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 10 May 2017 06:08:22 +0000 Subject: [PATCH 10/21] Correctly set isStreaming for barrier. 
--- .../spark/sql/catalyst/plans/logical/basicLogicalOperators.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 76a5d24a89f05..8c99e82f72b38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -912,4 +912,5 @@ case class Deduplicate( case class AnalysisBarrier(child: LogicalPlan) extends LeafNode { override def output: Seq[Attribute] = child.output override def analyzed: Boolean = true + override def isStreaming: Boolean = child.isStreaming } From 7e9dfac854a2a698773b267a5dd31b77fbd4ab30 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 11 May 2017 07:14:06 +0000 Subject: [PATCH 11/21] Address comments. 
--- .../spark/sql/catalyst/analysis/Analyzer.scala | 15 +++++++-------- .../sql/catalyst/analysis/TypeCoercion.scala | 18 +++++++++--------- .../catalyst/analysis/timeZoneAnalysis.scala | 2 +- .../catalyst/plans/logical/LogicalPlan.scala | 10 ---------- .../sql/catalyst/analysis/AnalysisSuite.scala | 3 ++- 5 files changed, 19 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index f6c5d0adb712e..58929295d75fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1035,7 +1035,8 @@ class Analyzer( case sa @ Sort(_, _, AnalysisBarrier(child: Aggregate)) => sa case sa @ Sort(_, _, child: Aggregate) => sa - case s @ Sort(order, _, child) if child.resolved => + case s @ Sort(order, _, orgChild) if orgChild.resolved => + val child = CleanupBarriers(orgChild) try { val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder]) val requiredAttrs = AttributeSet(newOrder).filter(_.resolved) @@ -1043,7 +1044,7 @@ class Analyzer( if (missingAttrs.nonEmpty) { // Add missing attributes and then project them away after the sort. 
Project(child.output, - Sort(newOrder, s.global, addMissingAttr(child, missingAttrs))) + Sort(newOrder, s.global, AnalysisBarrier(addMissingAttr(child, missingAttrs)))) } else if (newOrder != order) { s.copy(order = newOrder) } else { @@ -1056,7 +1057,8 @@ class Analyzer( case ae: AnalysisException => s } - case f @ Filter(cond, child) if child.resolved => + case f @ Filter(cond, orgChild) if orgChild.resolved => + val child = CleanupBarriers(orgChild) try { val newCond = resolveExpressionRecursively(cond, child) val requiredAttrs = newCond.references.filter(_.resolved) @@ -1064,7 +1066,7 @@ class Analyzer( if (missingAttrs.nonEmpty) { // Add missing attributes and then project them away. Project(child.output, - Filter(newCond, addMissingAttr(child, missingAttrs))) + Filter(newCond, AnalysisBarrier(addMissingAttr(child, missingAttrs)))) } else if (newCond != cond) { f.copy(condition = newCond) } else { @@ -1108,8 +1110,6 @@ class Analyzer( throw new AnalysisException(s"Can't add $missingAttrs to $d") case u: UnaryNode => u.withNewChildren(addMissingAttr(u.child, missingAttrs) :: Nil) - case AnalysisBarrier(subPlan) => - AnalysisBarrier(addMissingAttr(subPlan, missingAttrs)) case other => throw new AnalysisException(s"Can't add $missingAttrs to $other") } @@ -1128,7 +1128,6 @@ class Analyzer( plan match { case u: UnaryNode if !u.isInstanceOf[SubqueryAlias] => resolveExpressionRecursively(resolved, u.child) - case AnalysisBarrier(subPlan) => resolveExpressionRecursively(resolved, subPlan) case other => resolved } } @@ -2469,7 +2468,7 @@ object CleanupAliases extends Rule[LogicalPlan] { /** Remove the barrier nodes of analysis */ object CleanupBarriers extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { case AnalysisBarrier(child) => child } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 1cc7f32666691..c3645170589c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -335,7 +335,7 @@ object TypeCoercion { } } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -391,7 +391,7 @@ object TypeCoercion { } } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -449,7 +449,7 @@ object TypeCoercion { private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, Decimal.ONE) private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, Decimal.ZERO) - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -490,7 +490,7 @@ object TypeCoercion { * This ensure that the types for various functions are as expected. */ object FunctionArgumentConversion extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -580,7 +580,7 @@ object TypeCoercion { * converted to fractional types. 
*/ object Division extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who has not been resolved yet, // as this is an extra rule which should be applied at last. case e if !e.childrenResolved => e @@ -602,7 +602,7 @@ object TypeCoercion { * Coerces the type of different branches of a CASE WHEN statement to a common type. */ object CaseWhenCoercion extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual => val maybeCommonType = findWiderCommonType(c.valueTypes) maybeCommonType.map { commonType => @@ -632,7 +632,7 @@ object TypeCoercion { * Coerces the type of different branches of If statement to a common type. */ object IfCoercion extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case e if !e.childrenResolved => e // Find tightest common type for If, if the true value and false value have different types. case i @ If(pred, left, right) if left.dataType != right.dataType => @@ -656,7 +656,7 @@ object TypeCoercion { private val acceptedTypes = Seq(DateType, TimestampType, StringType) - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -673,7 +673,7 @@ object TypeCoercion { * Casts types according to the expected input types for [[Expression]]s. 
*/ object ImplicitTypeCasts extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala index a27aa845bf0ae..af1f9165b0044 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala @@ -38,7 +38,7 @@ case class ResolveTimeZone(conf: SQLConf) extends Rule[LogicalPlan] { } override def apply(plan: LogicalPlan): LogicalPlan = - plan.resolveExpressions(transformTimeZoneExprs) + plan.transformAllExpressions(transformTimeZoneExprs) def resolveTimeZones(e: Expression): Expression = e.transform(transformTimeZoneExprs) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index b5002c7ee5eaf..3c9fbf8710f49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -46,16 +46,6 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { /** Returns true if this subtree contains any streaming data sources. */ def isStreaming: Boolean = children.exists(_.isStreaming == true) - /** - * Recursively transforms the expressions of a tree, skipping nodes that have already - * been analyzed. 
- */ - def resolveExpressions(r: PartialFunction[Expression, Expression]): LogicalPlan = { - this transformUp { - case p => p.transformExpressions(r) - } - } - /** A cache for the estimated statistics, such that it will only be computed once. */ private var statsCache: Option[Statistics] = None diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 1e7f87882061a..fabe16d99fea3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -449,7 +449,8 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { AnalysisBarrier(SubqueryAlias("tbl", testRelation))), Project(testRelation.output, SubqueryAlias("tbl", testRelation))) - // Make sure we won't resolve the plans wrapped in an [[AnalysisBarrier]] + // Verify we won't go through a plan wrapped in a barrier. + // Since we wrap an unresolved plan and analyzer won't go through it. It remains unresolved. val barrier = AnalysisBarrier(Project(Seq(UnresolvedAttribute("tbl.b")), SubqueryAlias("tbl", testRelation))) assertAnalysisError(barrier, Seq("cannot resolve '`tbl.b`'")) From b9d03cd8ea97e14ebaef2c2c2b6886e183f3fba8 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 17 May 2017 07:47:45 +0000 Subject: [PATCH 12/21] Fix test. 
--- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 0d321cdc4e162..25dbc15a2e2d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1058,7 +1058,7 @@ class Analyzer( if (missingAttrs.nonEmpty) { // Add missing attributes and then project them away after the sort. Project(child.output, - Sort(newOrder, s.global, AnalysisBarrier(addMissingAttr(child, missingAttrs)))) + Sort(newOrder, s.global, addMissingAttr(child, missingAttrs))) } else if (newOrder != order) { s.copy(order = newOrder) } else { @@ -1080,7 +1080,7 @@ class Analyzer( if (missingAttrs.nonEmpty) { // Add missing attributes and then project them away. Project(child.output, - Filter(newCond, AnalysisBarrier(addMissingAttr(child, missingAttrs)))) + Filter(newCond, addMissingAttr(child, missingAttrs))) } else if (newCond != cond) { f.copy(condition = newCond) } else { @@ -1099,7 +1099,7 @@ class Analyzer( */ private def addMissingAttr(plan: LogicalPlan, missingAttrs: AttributeSet): LogicalPlan = { if (missingAttrs.isEmpty) { - return plan + return AnalysisBarrier(plan) } plan match { case p: Project => From 6a7204c0fc00dbe2e43d6d65e722b3b13c3b35d0 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 19 May 2017 05:23:38 +0000 Subject: [PATCH 13/21] Address comments. 
--- .../spark/sql/catalyst/analysis/Analyzer.scala | 14 +++++++------- .../sql/catalyst/plans/LogicalPlanSuite.scala | 2 +- .../main/scala/org/apache/spark/sql/Dataset.scala | 4 ++-- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 25dbc15a2e2d2..39055d4006c86 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -167,7 +167,7 @@ class Analyzer( UpdateOuterReferences), Batch("Cleanup", fixedPoint, CleanupAliases, - CleanupBarriers) + EliminateBarriers) ) /** @@ -673,7 +673,7 @@ class Analyzer( */ private def dedupRight (left: LogicalPlan, oriRight: LogicalPlan): LogicalPlan = { // Remove analysis barrier if any. - val right = CleanupBarriers(oriRight) + val right = EliminateBarriers(oriRight) val conflictingAttributes = left.outputSet.intersect(right.outputSet) logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} " + s"between $left and $right") @@ -1050,7 +1050,7 @@ class Analyzer( case sa @ Sort(_, _, child: Aggregate) => sa case s @ Sort(order, _, orgChild) if !s.resolved && orgChild.resolved => - val child = CleanupBarriers(orgChild) + val child = EliminateBarriers(orgChild) try { val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder]) val requiredAttrs = AttributeSet(newOrder).filter(_.resolved) @@ -1072,7 +1072,7 @@ class Analyzer( } case f @ Filter(cond, orgChild) if !f.resolved && orgChild.resolved => - val child = CleanupBarriers(orgChild) + val child = EliminateBarriers(orgChild) try { val newCond = resolveExpressionRecursively(cond, child) val requiredAttrs = newCond.references.filter(_.resolved) @@ -1553,7 +1553,7 @@ class Analyzer( object ResolveAggregateFunctions extends 
Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case filter @ Filter(havingCondition, AnalysisBarrier(aggregate: Aggregate)) => - apply(Filter(havingCondition, aggregate)).mapChildren(AnalysisBarrier(_)) + apply(Filter(havingCondition, aggregate)).mapChildren(AnalysisBarrier) case filter @ Filter(havingCondition, aggregate @ Aggregate(grouping, originalAggExprs, child)) if aggregate.resolved => @@ -1614,7 +1614,7 @@ class Analyzer( } case sort @ Sort(sortOrder, global, AnalysisBarrier(aggregate: Aggregate)) => - apply(Sort(sortOrder, global, aggregate)).mapChildren(AnalysisBarrier(_)) + apply(Sort(sortOrder, global, aggregate)).mapChildren(AnalysisBarrier) case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved => // Try resolving the ordering as though it is in the aggregate clause. @@ -2481,7 +2481,7 @@ object CleanupAliases extends Rule[LogicalPlan] { } /** Remove the barrier nodes of analysis */ -object CleanupBarriers extends Rule[LogicalPlan] { +object EliminateBarriers extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { case AnalysisBarrier(child) => child } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala index eb5f7d53e2aad..215db848383eb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.types.IntegerType /** * This suite is used to test [[LogicalPlan]]'s `transformUp` plus analysis barrier and make sure - * it can correctly skips sub-trees that have already been marked as analyzed. + * it can correctly skip sub-trees that have already been marked as analyzed. 
*/ class LogicalPlanSuite extends SparkFunSuite { private var invocationCount = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 5dd45e07aa991..e5a8763e3ca81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -187,7 +187,7 @@ class Dataset[T] private[sql]( } } - // Wrap analyzed logical plan with an analysis barrier so we won't traverse/resolve it again. + // Wraps analyzed logical plans with an analysis barrier so we won't traverse/resolve it again. @transient private val planWithBarrier: LogicalPlan = AnalysisBarrier(logicalPlan) /** @@ -1744,7 +1744,7 @@ class Dataset[T] private[sql]( def union(other: Dataset[T]): Dataset[T] = withSetOperator { // This breaks caching, but it's usually ok because it addresses a very specific use case: // using union to union many files or partitions. - CombineUnions(Union(logicalPlan, other.logicalPlan)).mapChildren(AnalysisBarrier(_)) + CombineUnions(Union(logicalPlan, other.logicalPlan)).mapChildren(AnalysisBarrier) } /** From 3437ae01a2db9575f49f1ed56e3f0d8990b32243 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 22 May 2017 07:35:11 +0000 Subject: [PATCH 14/21] Wrap AnalysisBarrier on df.logicalPlan. 
--- .../plans/logical/basicLogicalOperators.scala | 7 ++ .../scala/org/apache/spark/sql/Dataset.scala | 95 +++++++++---------- 2 files changed, 54 insertions(+), 48 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index ba97beb724c36..14111f804972f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -919,4 +919,11 @@ case class AnalysisBarrier(child: LogicalPlan) extends LeafNode { override def output: Seq[Attribute] = child.output override def analyzed: Boolean = true override def isStreaming: Boolean = child.isStreaming + override lazy val canonicalized: LogicalPlan = child.canonicalized + + override def find(f: LogicalPlan => Boolean): Option[LogicalPlan] = if (f(this)) { + Some(this) + } else { + child.find(f) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index e5a8763e3ca81..93aae3928f30a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -177,7 +177,7 @@ class Dataset[T] private[sql]( @transient private[sql] val logicalPlan: LogicalPlan = { // For various commands (like DDL) and queries with side effects, we force query execution // to happen right away to let these side effects take place eagerly. 
- queryExecution.analyzed match { + val analyzed = queryExecution.analyzed match { case c: Command => LocalRelation(c.output, queryExecution.executedPlan.executeCollect()) case u @ Union(children) if children.forall(_.isInstanceOf[Command]) => @@ -185,11 +185,10 @@ class Dataset[T] private[sql]( case _ => queryExecution.analyzed } + // Wraps analyzed logical plans with an analysis barrier so we won't traverse/resolve it again. + AnalysisBarrier(analyzed) } - // Wraps analyzed logical plans with an analysis barrier so we won't traverse/resolve it again. - @transient private val planWithBarrier: LogicalPlan = AnalysisBarrier(logicalPlan) - /** * Currently [[ExpressionEncoder]] is the only implementation of [[Encoder]], here we turn the * passed in encoder to [[ExpressionEncoder]] explicitly, and mark it implicit so that we can use @@ -416,7 +415,7 @@ class Dataset[T] private[sql]( */ @Experimental @InterfaceStability.Evolving - def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, planWithBarrier) + def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, logicalPlan) /** * Converts this strongly typed collection of data to generic `DataFrame` with columns renamed. @@ -619,7 +618,7 @@ class Dataset[T] private[sql]( require(parsedDelay.milliseconds >= 0 && parsedDelay.months >= 0, s"delay threshold ($delayThreshold) should not be negative.") EliminateEventTimeWatermark( - EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, planWithBarrier)) + EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan)) } /** @@ -793,7 +792,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def join(right: Dataset[_]): DataFrame = withPlan { - Join(planWithBarrier, right.planWithBarrier, joinType = Inner, None) + Join(logicalPlan, right.logicalPlan, joinType = Inner, None) } /** @@ -871,7 +870,7 @@ class Dataset[T] private[sql]( // Analyze the self join. 
The assumption is that the analyzer will disambiguate left vs right // by creating a new instance for one of the branch. val joined = sparkSession.sessionState.executePlan( - Join(planWithBarrier, right.planWithBarrier, joinType = JoinType(joinType), None)) + Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None)) .analyzed.asInstanceOf[Join] withPlan { @@ -932,7 +931,7 @@ class Dataset[T] private[sql]( // Trigger analysis so in the case of self-join, the analyzer will clone the plan. // After the cloning, left and right side will have distinct expression ids. val plan = withPlan( - Join(planWithBarrier, right.planWithBarrier, JoinType(joinType), Some(joinExprs.expr))) + Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr))) .queryExecution.analyzed.asInstanceOf[Join] // If auto self join alias is disabled, return the plan. @@ -941,8 +940,8 @@ class Dataset[T] private[sql]( } // If left/right have no output set intersection, return the plan. - val lanalyzed = withPlan(this.planWithBarrier).queryExecution.analyzed - val ranalyzed = withPlan(right.planWithBarrier).queryExecution.analyzed + val lanalyzed = withPlan(this.logicalPlan).queryExecution.analyzed + val ranalyzed = withPlan(right.logicalPlan).queryExecution.analyzed if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) { return withPlan(plan) } @@ -974,7 +973,7 @@ class Dataset[T] private[sql]( * @since 2.1.0 */ def crossJoin(right: Dataset[_]): DataFrame = withPlan { - Join(planWithBarrier, right.planWithBarrier, joinType = Cross, None) + Join(logicalPlan, right.logicalPlan, joinType = Cross, None) } /** @@ -1006,8 +1005,8 @@ class Dataset[T] private[sql]( // etc. 
val joined = sparkSession.sessionState.executePlan( Join( - this.planWithBarrier, - other.planWithBarrier, + this.logicalPlan, + other.logicalPlan, JoinType(joinType), Some(condition.expr))).analyzed.asInstanceOf[Join] @@ -1177,7 +1176,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def hint(name: String, parameters: String*): Dataset[T] = withTypedPlan { - Hint(name, parameters, planWithBarrier) + Hint(name, parameters, logicalPlan) } /** @@ -1203,7 +1202,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def as(alias: String): Dataset[T] = withTypedPlan { - SubqueryAlias(alias, planWithBarrier) + SubqueryAlias(alias, logicalPlan) } /** @@ -1241,7 +1240,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def select(cols: Column*): DataFrame = withPlan { - Project(cols.map(_.named), planWithBarrier) + Project(cols.map(_.named), logicalPlan) } /** @@ -1296,8 +1295,8 @@ class Dataset[T] private[sql]( @InterfaceStability.Evolving def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = { implicit val encoder = c1.encoder - val project = Project(c1.withInputType(exprEnc, planWithBarrier.output).named :: Nil, - planWithBarrier) + val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, + logicalPlan) if (encoder.flat) { new Dataset[U1](sparkSession, project, encoder) @@ -1315,8 +1314,8 @@ class Dataset[T] private[sql]( protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val encoders = columns.map(_.encoder) val namedColumns = - columns.map(_.withInputType(exprEnc, planWithBarrier.output).named) - val execution = new QueryExecution(sparkSession, Project(namedColumns, planWithBarrier)) + columns.map(_.withInputType(exprEnc, logicalPlan.output).named) + val execution = new QueryExecution(sparkSession, Project(namedColumns, logicalPlan)) new Dataset(sparkSession, execution, ExpressionEncoder.tuple(encoders)) } @@ -1392,7 +1391,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def 
filter(condition: Column): Dataset[T] = withTypedPlan { - Filter(condition.expr, planWithBarrier) + Filter(condition.expr, logicalPlan) } /** @@ -1569,7 +1568,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { - val inputPlan = planWithBarrier + val inputPlan = logicalPlan val withGroupingKey = AppendColumns(func, inputPlan) val executed = sparkSession.sessionState.executePlan(withGroupingKey) @@ -1715,7 +1714,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def limit(n: Int): Dataset[T] = withTypedPlan { - Limit(Literal(n), planWithBarrier) + Limit(Literal(n), logicalPlan) } /** @@ -1744,7 +1743,8 @@ class Dataset[T] private[sql]( def union(other: Dataset[T]): Dataset[T] = withSetOperator { // This breaks caching, but it's usually ok because it addresses a very specific use case: // using union to union many files or partitions. - CombineUnions(Union(logicalPlan, other.logicalPlan)).mapChildren(AnalysisBarrier) + CombineUnions(Union(EliminateBarriers(logicalPlan), EliminateBarriers(other.logicalPlan))) + .mapChildren(AnalysisBarrier) } /** @@ -1758,7 +1758,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def intersect(other: Dataset[T]): Dataset[T] = withSetOperator { - Intersect(planWithBarrier, other.planWithBarrier) + Intersect(logicalPlan, other.logicalPlan) } /** @@ -1772,7 +1772,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def except(other: Dataset[T]): Dataset[T] = withSetOperator { - Except(planWithBarrier, other.planWithBarrier) + Except(logicalPlan, other.logicalPlan) } /** @@ -1793,7 +1793,7 @@ class Dataset[T] private[sql]( s"Fraction must be nonnegative, but got ${fraction}") withTypedPlan { - Sample(0.0, fraction, withReplacement, seed, planWithBarrier)() + Sample(0.0, fraction, withReplacement, seed, logicalPlan)() } } @@ -1835,15 +1835,15 @@ class Dataset[T] private[sql]( // overlapping splits. 
To prevent this, we explicitly sort each input partition to make the // ordering deterministic. Note that MapTypes cannot be sorted and are explicitly pruned out // from the sort order. - val sortOrder = planWithBarrier.output + val sortOrder = logicalPlan.output .filter(attr => RowOrdering.isOrderable(attr.dataType)) .map(SortOrder(_, Ascending)) val plan = if (sortOrder.nonEmpty) { - Sort(sortOrder, global = false, planWithBarrier) + Sort(sortOrder, global = false, logicalPlan) } else { // SPARK-12662: If sort order is empty, we materialize the dataset to guarantee determinism cache() - planWithBarrier + logicalPlan } val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) @@ -1927,7 +1927,7 @@ class Dataset[T] private[sql]( withPlan { Generate(generator, join = true, outer = false, - qualifier = None, generatorOutput = Nil, planWithBarrier) + qualifier = None, generatorOutput = Nil, logicalPlan) } } @@ -1968,7 +1968,7 @@ class Dataset[T] private[sql]( withPlan { Generate(generator, join = true, outer = false, - qualifier = None, generatorOutput = Nil, planWithBarrier) + qualifier = None, generatorOutput = Nil, logicalPlan) } } @@ -2131,7 +2131,7 @@ class Dataset[T] private[sql]( } cols } - Deduplicate(groupCols, planWithBarrier, isStreaming) + Deduplicate(groupCols, logicalPlan, isStreaming) } /** @@ -2280,7 +2280,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def filter(func: T => Boolean): Dataset[T] = { - withTypedPlan(TypedFilter(func, planWithBarrier)) + withTypedPlan(TypedFilter(func, logicalPlan)) } /** @@ -2294,7 +2294,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def filter(func: FilterFunction[T]): Dataset[T] = { - withTypedPlan(TypedFilter(func, planWithBarrier)) + withTypedPlan(TypedFilter(func, logicalPlan)) } /** @@ -2308,7 +2308,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def map[U : Encoder](func: T => U): 
Dataset[U] = withTypedPlan { - MapElements[T, U](func, planWithBarrier) + MapElements[T, U](func, logicalPlan) } /** @@ -2323,7 +2323,7 @@ class Dataset[T] private[sql]( @InterfaceStability.Evolving def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { implicit val uEnc = encoder - withTypedPlan(MapElements[T, U](func, planWithBarrier)) + withTypedPlan(MapElements[T, U](func, logicalPlan)) } /** @@ -2339,7 +2339,7 @@ class Dataset[T] private[sql]( def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { new Dataset[U]( sparkSession, - MapPartitions[T, U](func, planWithBarrier), + MapPartitions[T, U](func, logicalPlan), implicitly[Encoder[U]]) } @@ -2370,7 +2370,7 @@ class Dataset[T] private[sql]( val rowEncoder = encoder.asInstanceOf[ExpressionEncoder[Row]] Dataset.ofRows( sparkSession, - MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder, planWithBarrier)) + MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder, logicalPlan)) } /** @@ -2525,7 +2525,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def repartition(numPartitions: Int): Dataset[T] = withTypedPlan { - Repartition(numPartitions, shuffle = true, planWithBarrier) + Repartition(numPartitions, shuffle = true, logicalPlan) } /** @@ -2539,7 +2539,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = withTypedPlan { - RepartitionByExpression(partitionExprs.map(_.expr), planWithBarrier, numPartitions) + RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions) } /** @@ -2555,8 +2555,7 @@ class Dataset[T] private[sql]( @scala.annotation.varargs def repartition(partitionExprs: Column*): Dataset[T] = withTypedPlan { RepartitionByExpression( - partitionExprs.map(_.expr), planWithBarrier, - sparkSession.sessionState.conf.numShufflePartitions) + partitionExprs.map(_.expr), logicalPlan, 
sparkSession.sessionState.conf.numShufflePartitions) } /** @@ -2577,7 +2576,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def coalesce(numPartitions: Int): Dataset[T] = withTypedPlan { - Repartition(numPartitions, shuffle = false, planWithBarrier) + Repartition(numPartitions, shuffle = false, logicalPlan) } /** @@ -2666,7 +2665,7 @@ class Dataset[T] private[sql]( */ lazy val rdd: RDD[T] = { val objectType = exprEnc.deserializer.dataType - val deserialized = CatalystSerde.deserialize[T](planWithBarrier) + val deserialized = CatalystSerde.deserialize[T](logicalPlan) sparkSession.sessionState.executePlan(deserialized).toRdd.mapPartitions { rows => rows.map(_.get(0, objectType).asInstanceOf[T]) } @@ -2765,7 +2764,7 @@ class Dataset[T] private[sql]( comment = None, properties = Map.empty, originalText = None, - child = planWithBarrier, + child = logicalPlan, allowExisting = false, replace = replace, viewType = viewType) @@ -2936,7 +2935,7 @@ class Dataset[T] private[sql]( } } withTypedPlan { - Sort(sortOrder, global = global, planWithBarrier) + Sort(sortOrder, global = global, logicalPlan) } } From 555fa8e6e19fee63efd3fc6795c1f9bd6ca8c6a6 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 23 May 2017 10:44:23 +0000 Subject: [PATCH 15/21] Fix test. 
--- .../scala/org/apache/spark/sql/DataFrameWriter.scala | 9 +++++---- .../src/main/scala/org/apache/spark/sql/Dataset.scala | 4 ++-- .../org/apache/spark/sql/execution/datasources/ddl.scala | 2 +- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 1732a8e08b73f..408a50bf580c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -232,7 +232,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { runCommand(df.sparkSession, "save") { SaveIntoDataSourceCommand( - query = df.logicalPlan, + query = df.queryExecution.analyzed, provider = source, partitionColumns = partitioningColumns.getOrElse(Nil), options = extraOptions.toMap, @@ -284,7 +284,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { InsertIntoTable( table = UnresolvedRelation(tableIdent), partition = Map.empty[String, Option[String]], - query = df.logicalPlan, + query = df.queryExecution.analyzed, overwrite = mode == SaveMode.Overwrite, ifNotExists = false) } @@ -370,7 +370,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { case (true, SaveMode.Overwrite) => // Get all input data source or hive relations of the query. 
- val srcRelations = df.logicalPlan.collect { + val srcRelations = df.queryExecution.analyzed.collect { case LogicalRelation(src: BaseRelation, _, _) => src case relation: CatalogRelation if DDLUtils.isHiveTable(relation.tableMeta) => relation.tableMeta.identifier @@ -417,7 +417,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { partitionColumnNames = partitioningColumns.getOrElse(Nil), bucketSpec = getBucketSpec) - runCommand(df.sparkSession, "saveAsTable")(CreateTable(tableDesc, mode, Some(df.logicalPlan))) + runCommand(df.sparkSession, "saveAsTable")(CreateTable(tableDesc, mode, + Some(df.queryExecution.analyzed))) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 93aae3928f30a..f91ab1d4c8b4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1743,7 +1743,7 @@ class Dataset[T] private[sql]( def union(other: Dataset[T]): Dataset[T] = withSetOperator { // This breaks caching, but it's usually ok because it addresses a very specific use case: // using union to union many files or partitions. 
- CombineUnions(Union(EliminateBarriers(logicalPlan), EliminateBarriers(other.logicalPlan))) + CombineUnions(Union(queryExecution.analyzed, other.queryExecution.analyzed)) .mapChildren(AnalysisBarrier) } @@ -2764,7 +2764,7 @@ class Dataset[T] private[sql]( comment = None, properties = Map.empty, originalText = None, - child = logicalPlan, + child = queryExecution.analyzed, allowExisting = false, replace = replace, viewType = viewType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index f8d4a9bb5b81a..ade43a2863fb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -90,7 +90,7 @@ case class CreateTempViewUsing( val catalog = sparkSession.sessionState.catalog val viewDefinition = Dataset.ofRows( - sparkSession, LogicalRelation(dataSource.resolveRelation())).logicalPlan + sparkSession, LogicalRelation(dataSource.resolveRelation())).queryExecution.analyzed if (global) { catalog.createGlobalTempView(tableIdent.table, viewDefinition, replace) From f3e4208eb23bee5cfc0e8a33134d58fac5526dbb Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 24 May 2017 02:41:57 +0000 Subject: [PATCH 16/21] Fix test. 
--- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 2 +- .../spark/sql/execution/streaming/ProgressReporter.scala | 2 +- .../org/apache/spark/sql/hive/execution/HiveUDFSuite.scala | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index f91ab1d4c8b4d..a59bc55d71b33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2764,7 +2764,7 @@ class Dataset[T] private[sql]( comment = None, properties = Map.empty, originalText = None, - child = queryExecution.analyzed, + child = logicalPlan, allowExisting = false, replace = replace, viewType = viewType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index a4e4ca821374c..17040406c8cd9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -234,7 +234,7 @@ trait ProgressReporter extends Logging { // 3. For each source, we sum the metrics of the associated execution plan leaves. 
// val logicalPlanLeafToSource = newData.flatMap { case (source, df) => - df.logicalPlan.collectLeaves().map { leaf => leaf -> source } + df.queryExecution.analyzed.collectLeaves().map { leaf => leaf -> source } } val allLogicalPlanLeaves = lastExecution.logical.collectLeaves() // includes non-streaming val allExecPlanLeaves = lastExecution.executedPlan.collectLeaves() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 4446af2e75e00..4f91fc871d63b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -393,12 +393,12 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { // HiveSimpleUDF sql(s"CREATE TEMPORARY FUNCTION testStringStringUDF AS '${classOf[UDFStringString].getName}'") val df1 = sql("SELECT testStringStringUDF(rand(), \"hello\")") - assert(!df1.logicalPlan.asInstanceOf[Project].projectList.forall(_.deterministic)) + assert(!df1.queryExecution.analyzed.asInstanceOf[Project].projectList.forall(_.deterministic)) // HiveGenericUDF sql(s"CREATE TEMPORARY FUNCTION testGenericUDFHash AS '${classOf[GenericUDFHash].getName}'") val df2 = sql("SELECT testGenericUDFHash(rand())") - assert(!df2.logicalPlan.asInstanceOf[Project].projectList.forall(_.deterministic)) + assert(!df2.queryExecution.analyzed.asInstanceOf[Project].projectList.forall(_.deterministic)) } } From c0bee014eaa268014f5e156498d8cc7d90533ac7 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 24 May 2017 03:25:34 +0000 Subject: [PATCH 17/21] Avoid overriding find in AnalysisBarrier. 
--- .../sql/catalyst/plans/logical/basicLogicalOperators.scala | 6 ------ .../scala/org/apache/spark/sql/execution/CacheManager.scala | 6 +++--- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index ba3021d4fe3cc..d1f42c1caeb60 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -923,10 +923,4 @@ case class AnalysisBarrier(child: LogicalPlan) extends LeafNode { override def analyzed: Boolean = true override def isStreaming: Boolean = child.isStreaming override lazy val canonicalized: LogicalPlan = child.canonicalized - - override def find(f: LogicalPlan => Boolean): Option[LogicalPlan] = if (f(this)) { - Some(this) - } else { - child.find(f) - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 0ea806d6cb50b..1e89a8f8a8dc1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -90,7 +90,7 @@ class CacheManager extends Logging { query: Dataset[_], tableName: Option[String] = None, storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock { - val planToCache = query.logicalPlan + val planToCache = query.queryExecution.analyzed if (lookupCachedData(planToCache).nonEmpty) { logWarning("Asked to cache already cached data.") } else { @@ -110,7 +110,7 @@ class CacheManager extends Logging { * Un-cache all the cache entries that refer to the given plan. 
*/ def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Unit = writeLock { - uncacheQuery(query.sparkSession, query.logicalPlan, blocking) + uncacheQuery(query.sparkSession, query.queryExecution.analyzed, blocking) } /** @@ -159,7 +159,7 @@ class CacheManager extends Logging { /** Optionally returns cached data for the given [[Dataset]] */ def lookupCachedData(query: Dataset[_]): Option[CachedData] = readLock { - lookupCachedData(query.logicalPlan) + lookupCachedData(query.queryExecution.analyzed) } /** Optionally returns cached data for the given [[LogicalPlan]]. */ From 1c1cc9d1597d16deab14afb3f7001b13bc705321 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 24 May 2017 05:23:51 +0000 Subject: [PATCH 18/21] Fix test. --- .../main/scala/org/apache/spark/sql/Dataset.scala | 2 +- .../spark/sql/RelationalGroupedDataset.scala | 15 ++++++++------- .../execution/command/AnalyzeColumnCommand.scala | 2 +- .../sql/execution/datasources/DataSource.scala | 8 ++++---- .../sql/execution/streaming/FileStreamSink.scala | 2 +- .../sql/execution/streaming/ForeachSink.scala | 2 +- .../sql/execution/streaming/StreamExecution.scala | 2 +- .../scala/org/apache/spark/sql/functions.scala | 2 +- .../spark/sql/streaming/DataStreamWriter.scala | 2 +- .../org/apache/spark/sql/CachedTableSuite.scala | 2 +- .../scala/org/apache/spark/sql/DatasetSuite.scala | 6 +++--- .../scala/org/apache/spark/sql/QueryTest.scala | 2 +- .../apache/spark/sql/execution/PlannerSuite.scala | 6 +++--- .../columnar/InMemoryColumnarQuerySuite.scala | 11 ++++++----- .../sql/execution/joins/ExistenceJoinSuite.scala | 3 ++- .../sql/execution/joins/InnerJoinSuite.scala | 3 ++- .../sql/execution/joins/OuterJoinSuite.scala | 3 ++- .../sql/streaming/EventTimeWatermarkSuite.scala | 2 +- .../streaming/FlatMapGroupsWithStateSuite.scala | 2 +- .../apache/spark/sql/streaming/StreamTest.scala | 2 +- .../spark/sql/streaming/StreamingQuerySuite.scala | 2 +- .../apache/spark/sql/hive/ListTablesSuite.scala | 2 +- 
.../hive/execution/AggregationQuerySuite.scala | 2 +- 23 files changed, 45 insertions(+), 40 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index a59bc55d71b33..ef739cd4581cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -174,7 +174,7 @@ class Dataset[T] private[sql]( this(sqlContext.sparkSession, logicalPlan, encoder) } - @transient private[sql] val logicalPlan: LogicalPlan = { + @transient private val logicalPlan: LogicalPlan = { // For various commands (like DDL) and queries with side effects, we force query execution // to happen right away to let these side effects take place eagerly. val analyzed = queryExecution.analyzed match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 64755434784a0..a378914ca304e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -58,21 +58,22 @@ class RelationalGroupedDataset protected[sql]( } val aliasedAgg = aggregates.map(alias) + val logicalPlan = df.queryExecution.analyzed groupType match { case RelationalGroupedDataset.GroupByType => Dataset.ofRows( - df.sparkSession, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) + df.sparkSession, Aggregate(groupingExprs, aliasedAgg, logicalPlan)) case RelationalGroupedDataset.RollupType => Dataset.ofRows( - df.sparkSession, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.logicalPlan)) + df.sparkSession, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, logicalPlan)) case RelationalGroupedDataset.CubeType => Dataset.ofRows( - df.sparkSession, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.logicalPlan)) + df.sparkSession, 
Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, logicalPlan)) case RelationalGroupedDataset.PivotType(pivotCol, values) => val aliasedGrps = groupingExprs.map(alias) Dataset.ofRows( - df.sparkSession, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan)) + df.sparkSession, Pivot(aliasedGrps, pivotCol, values, aggExprs, logicalPlan)) } } @@ -223,7 +224,7 @@ class RelationalGroupedDataset protected[sql]( def agg(expr: Column, exprs: Column*): DataFrame = { toDF((expr +: exprs).map { case typed: TypedColumn[_, _] => - typed.withInputType(df.exprEnc, df.logicalPlan.output).expr + typed.withInputType(df.exprEnc, df.queryExecution.analyzed.output).expr case c => c.expr }) } @@ -428,8 +429,8 @@ class RelationalGroupedDataset protected[sql]( df.exprEnc.deserializer, df.exprEnc.schema, groupingAttributes, - df.logicalPlan.output, - df.logicalPlan)) + df.queryExecution.analyzed.output, + df.queryExecution.analyzed)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 0d8db2ff5d5a0..a47d58aa8c4f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -70,7 +70,7 @@ case class AnalyzeColumnCommand( tableIdent: TableIdentifier, columnNames: Seq[String]): (Long, Map[String, ColumnStat]) = { - val relation = sparkSession.table(tableIdent).logicalPlan + val relation = sparkSession.table(tableIdent).queryExecution.analyzed // Resolve the column names and dedup using AttributeSet val resolver = sparkSession.sessionState.conf.resolver val attributesToAnalyze = columnNames.map { col => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 
14c40605ea31c..ac0b590e0eee6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -412,7 +412,7 @@ case class DataSource( // not need to have the query as child, to avoid to analyze an optimized query, // because InsertIntoHadoopFsRelationCommand will be optimized first. val partitionAttributes = partitionColumns.map { name => - val plan = data.logicalPlan + val plan = data.queryExecution.analyzed plan.resolve(name :: Nil, data.sparkSession.sessionState.analyzer.resolver).getOrElse { throw new AnalysisException( s"Unable to resolve $name given [${plan.output.map(_.name).mkString(", ")}]") @@ -424,8 +424,8 @@ case class DataSource( }.head } // For partitioned relation r, r.schema's column ordering can be different from the column - // ordering of data.logicalPlan (partition columns are all moved after data column). This - // will be adjusted within InsertIntoHadoopFsRelation. + // ordering of data.queryExecution.analyzed (partition columns are all moved after data column). + // This will be adjusted within InsertIntoHadoopFsRelation. val plan = InsertIntoHadoopFsRelationCommand( outputPath = outputPath, @@ -435,7 +435,7 @@ case class DataSource( bucketSpec = bucketSpec, fileFormat = format, options = options, - query = data.logicalPlan, + query = data.queryExecution.analyzed, mode = mode, catalogTable = catalogTable, fileIndex = fileIndex) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index 6885d0bf67ccb..e83ed66aadd73 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -115,7 +115,7 @@ class FileStreamSink( // the given columns names. 
val partitionColumns: Seq[Attribute] = partitionColumnNames.map { col => val nameEquality = data.sparkSession.sessionState.conf.resolver - data.logicalPlan.output.find(f => nameEquality(f.name, col)).getOrElse { + data.queryExecution.analyzed.output.find(f => nameEquality(f.name, col)).getOrElse { throw new RuntimeException(s"Partition column $col not found in schema ${data.schema}") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala index de09fb568d2a6..019422d4ec940 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala @@ -44,7 +44,7 @@ class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Seria // // Hence, we need to manually convert internal rows to objects using encoder. val encoder = encoderFor[T].resolveAndBind( - data.logicalPlan.output, + data.queryExecution.analyzed.output, data.sparkSession.sessionState.analyzer) data.queryExecution.toRdd.foreachPartition { iter => if (writer.open(TaskContext.getPartitionId(), batchId)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index b6ddf7437ea13..3a6ad80376555 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -617,7 +617,7 @@ class StreamExecution( val withNewSources = logicalPlan transform { case StreamingExecutionRelation(source, output) => newData.get(source).map { data => - val newPlan = data.logicalPlan + val newPlan = data.queryExecution.analyzed assert(output.size == newPlan.output.size, s"Invalid batch: ${Utils.truncatedString(output, ",")} 
!= " + s"${Utils.truncatedString(newPlan.output, ",")}") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5edf03666ac22..43b7e3e8ae233 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1019,7 +1019,7 @@ object functions { * @since 1.5.0 */ def broadcast[T](df: Dataset[T]): Dataset[T] = { - Dataset[T](df.sparkSession, BroadcastHint(df.logicalPlan))(df.exprEnc) + Dataset[T](df.sparkSession, BroadcastHint(df.queryExecution.analyzed))(df.exprEnc) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 0d2611f9bbcce..14525249c2edc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -359,7 +359,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { * need to care about case sensitivity afterwards. 
*/ private def normalize(columnName: String, columnType: String): String = { - val validColumnNames = df.logicalPlan.output.map(_.name) + val validColumnNames = df.queryExecution.analyzed.output.map(_.name) validColumnNames.find(df.sparkSession.sessionState.analyzer.resolver(_, columnName)) .getOrElse(throw new AnalysisException(s"$columnType column $columnName not found in " + s"existing columns (${validColumnNames.mkString(", ")})")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 4114f7a19c7ba..be493baf4fc0b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -189,7 +189,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext } test("SPARK-1669: cacheTable should be idempotent") { - assume(!spark.table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) + assume(!spark.table("testData").queryExecution.analyzed.isInstanceOf[InMemoryRelation]) spark.catalog.cacheTable("testData") assertCached(spark.table("testData")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 8eb381b91f46d..5b7624511c008 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1043,10 +1043,10 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds = spark.range(10).repartition('id % 2).filter('id > 5).orderBy('id.desc) val cp = ds.checkpoint(eager) - val logicalRDD = cp.logicalPlan match { + val logicalRDD = cp.queryExecution.analyzed match { case plan: LogicalRDD => plan case _ => - val treeString = cp.logicalPlan.treeString(verbose = true) + val treeString = cp.queryExecution.analyzed.treeString(verbose = true) fail(s"Expecting a 
LogicalRDD, but got\n$treeString") } @@ -1118,7 +1118,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { // instead of Int for avoiding possible overflow. val ds = (0 to 10000).map( i => (i, Seq((i, Seq((i, "This is really not that long of a string")))))).toDS() - val sizeInBytes = ds.logicalPlan.stats(sqlConf).sizeInBytes + val sizeInBytes = ds.queryExecution.analyzed.stats(sqlConf).sizeInBytes // sizeInBytes is 2404280404, before the fix, it overflows to a negative number assert(sizeInBytes > 0) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index f9808834df4a5..2eb907bbb95b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -253,7 +253,7 @@ object QueryTest { df: DataFrame, expectedAnswer: Seq[Row], checkToRDD: Boolean = true): Option[String] = { - val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty + val isSorted = df.queryExecution.analyzed.collect { case s: logical.Sort => s }.nonEmpty if (checkToRDD) { df.rdd.count() // Also attempt to deserialize as an RDD [SPARK-15791] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index d02c8ffe33f0f..5c3921c7bb448 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -168,21 +168,21 @@ class PlannerSuite extends SharedSQLContext { val query = testData.select('key, 'value).sort('key).limit(2) val planned = query.queryExecution.executedPlan assert(planned.isInstanceOf[execution.TakeOrderedAndProjectExec]) - assert(planned.output === testData.select('key, 'value).logicalPlan.output) + assert(planned.output === testData.select('key, 'value).queryExecution.analyzed.output) } 
test("terminal limit -> project -> sort should use TakeOrderedAndProject") { val query = testData.select('key, 'value).sort('key).select('value, 'key).limit(2) val planned = query.queryExecution.executedPlan assert(planned.isInstanceOf[execution.TakeOrderedAndProjectExec]) - assert(planned.output === testData.select('value, 'key).logicalPlan.output) + assert(planned.output === testData.select('value, 'key).queryExecution.analyzed.output) } test("terminal limits that are not handled by TakeOrderedAndProject should use CollectLimit") { val query = testData.select('value).limit(2) val planned = query.queryExecution.sparkPlan assert(planned.isInstanceOf[CollectLimitExec]) - assert(planned.output === testData.select('value).logicalPlan.output) + assert(planned.output === testData.select('value).queryExecution.analyzed.output) } test("TakeOrderedAndProject can appear in the middle of plans") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 109b1d9db60d2..9a627f3412303 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -38,7 +38,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { private def cachePrimitiveTest(data: DataFrame, dataType: String) { data.createOrReplaceTempView(s"testData$dataType") val storageLevel = MEMORY_ONLY - val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan + val plan = spark.sessionState.executePlan(data.queryExecution.analyzed).sparkPlan val inMemoryRelation = InMemoryRelation(useCompression = true, 5, storageLevel, plan, None) assert(inMemoryRelation.cachedColumnBuffers.getStorageLevel == storageLevel) @@ -114,7 +114,7 @@ class InMemoryColumnarQuerySuite extends QueryTest 
with SharedSQLContext { } test("simple columnar query") { - val plan = spark.sessionState.executePlan(testData.logicalPlan).sparkPlan + val plan = spark.sessionState.executePlan(testData.queryExecution.analyzed).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().toSeq) @@ -131,7 +131,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { } test("projection") { - val plan = spark.sessionState.executePlan(testData.select('value, 'key).logicalPlan).sparkPlan + val plan = spark.sessionState.executePlan(testData.select('value, 'key).queryExecution.analyzed) + .sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().map { @@ -147,7 +148,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") { - val plan = spark.sessionState.executePlan(testData.logicalPlan).sparkPlan + val plan = spark.sessionState.executePlan(testData.queryExecution.analyzed).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().toSeq) @@ -321,7 +322,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-17549: cached table size should be correctly calculated") { val data = spark.sparkContext.parallelize(1 to 10, 5).toDF() - val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan + val plan = spark.sessionState.executePlan(data.queryExecution.analyzed).sparkPlan val cached = InMemoryRelation(true, 5, MEMORY_ONLY, plan, None) // Materialize the data. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala index 38377164c10e6..e7febd2e3a56b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala @@ -85,7 +85,8 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { expectedAnswer: Seq[Row]): Unit = { def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { - val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) + val join = Join(leftRows.queryExecution.analyzed, rightRows.queryExecution.analyzed, Inner, + Some(condition)) ExtractEquiJoinKeys.unapply(join) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 4408ece112258..dabb02a1e2d79 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -80,7 +80,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { expectedAnswer: Seq[Product]): Unit = { def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { - val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition())) + val join = Join(leftRows.queryExecution.analyzed, rightRows.queryExecution.analyzed, Inner, + Some(condition())) ExtractEquiJoinKeys.unapply(join) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 001feb0f2b399..bb3fb558d5485 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -72,7 +72,8 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { expectedAnswer: Seq[Product]): Unit = { def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { - val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) + val join = Join(leftRows.queryExecution.analyzed, rightRows.queryExecution.analyzed, Inner, + Some(condition)) ExtractEquiJoinKeys.unapply(join) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index 1b60a06ec402f..8603d4855a13c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -338,7 +338,7 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Loggin .withWatermark("first", "1 minute") .withWatermark("second", "2 minutes") - val eventTimeColumns = df.logicalPlan.output + val eventTimeColumns = df.queryExecution.analyzed.output .filter(_.metadata.contains(EventTimeWatermark.delayKey)) assert(eventTimeColumns.size === 1) assert(eventTimeColumns(0).name === "second") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 6bb9408ce99ed..1c17d692b36fe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -1034,7 +1034,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf .toDS .groupByKey(x => x) .flatMapGroupsWithState[Int, Int](Append, timeoutConf = timeoutType)(func) - .logicalPlan.collectFirst 
{ + .queryExecution.analyzed.collectFirst { case FlatMapGroupsWithState(f, k, v, g, d, o, s, m, _, t, _) => FlatMapGroupsWithStateExec( f, k, v, g, d, o, None, s, m, t, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 5bc36dd30f6d1..16e21e2e7f6b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -516,7 +516,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { queryToUse.flatMap { query => findSourceIndex(query.logicalPlan) }.orElse { - findSourceIndex(stream.logicalPlan) + findSourceIndex(stream.queryExecution.analyzed) }.getOrElse { throw new IllegalArgumentException( "Could find index of the source to which data was added") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index b69536ed37463..42665d9341fd0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -441,7 +441,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi test("input row calculation with trigger input DF having multiple leaves") { val streamingTriggerDF = spark.createDataset(1 to 5).toDF.union(spark.createDataset(6 to 10).toDF) - require(streamingTriggerDF.logicalPlan.collectLeaves().size > 1) + require(streamingTriggerDF.queryExecution.analyzed.collectLeaves().size > 1) val streamingInputDF = createSingleTriggerStreamingDF(streamingTriggerDF) // After the first trigger, the calculated input rows should be 10 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala index 15ba61646d03f..2c0f7fe6d2d93 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala @@ -34,7 +34,7 @@ class ListTablesSuite extends QueryTest with TestHiveSingleton with BeforeAndAft super.beforeAll() // The catalog in HiveContext is a case insensitive one. sessionState.catalog.createTempView( - "ListTablesSuiteTable", df.logicalPlan, overrideIfExists = true) + "ListTablesSuiteTable", df.queryExecution.analyzed, overrideIfExists = true) sql("CREATE TABLE HiveListTablesSuiteTable (key int, value string)") sql("CREATE DATABASE IF NOT EXISTS ListTablesSuiteDB") sql("CREATE TABLE ListTablesSuiteDB.HiveInDBListTablesSuiteTable (key int, value string)") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 84f915977bd88..73f6a3e6ac05b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -1020,7 +1020,7 @@ class HashAggregationQueryWithControlledFallbackSuite extends AggregationQuerySu // Create a new df to make sure its physical operator picks up // spark.sql.TungstenAggregate.testFallbackStartsAt. // todo: remove it? - val newActual = Dataset.ofRows(spark, actual.logicalPlan) + val newActual = Dataset.ofRows(spark, actual.queryExecution.analyzed) QueryTest.checkAnswer(newActual, expectedAnswer) match { case Some(errorMessage) => From cba784b8d540358fb9ab2a60808842d70170d101 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 24 May 2017 06:15:36 +0000 Subject: [PATCH 19/21] fix test. 
--- .../org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala index 0d5dc7af5f522..75625f7922823 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -59,7 +59,7 @@ private[hive] class SparkSQLDriver(val context: SQLContext = SparkSQLEnv.sqlCont // TODO unify the error code try { context.sparkContext.setJobDescription(command) - val execution = context.sessionState.executePlan(context.sql(command).logicalPlan) + val execution = context.sessionState.executePlan(context.sql(command).queryExecution.analyzed) hiveResponse = execution.hiveResultString() tableSchema = getResultSetSchema(execution) new CommandProcessorResponse(0) From 8314cc310d9cf5d807a7e9b9de3c962dc37bf3e8 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 25 May 2017 04:04:24 +0000 Subject: [PATCH 20/21] Create a new field in Dataset for the plan with barrier. 
--- .../apache/spark/sql/DataFrameWriter.scala | 9 +- .../scala/org/apache/spark/sql/Dataset.scala | 107 +++++++++--------- .../spark/sql/RelationalGroupedDataset.scala | 15 ++- .../spark/sql/execution/CacheManager.scala | 6 +- .../command/AnalyzeColumnCommand.scala | 2 +- .../execution/datasources/DataSource.scala | 8 +- .../spark/sql/execution/datasources/ddl.scala | 2 +- .../execution/streaming/FileStreamSink.scala | 2 +- .../sql/execution/streaming/ForeachSink.scala | 2 +- .../streaming/ProgressReporter.scala | 2 +- .../execution/streaming/StreamExecution.scala | 2 +- .../sql/streaming/DataStreamWriter.scala | 2 +- .../apache/spark/sql/CachedTableSuite.scala | 2 +- .../org/apache/spark/sql/DatasetSuite.scala | 6 +- .../org/apache/spark/sql/QueryTest.scala | 2 +- .../spark/sql/execution/PlannerSuite.scala | 6 +- .../columnar/InMemoryColumnarQuerySuite.scala | 11 +- .../execution/joins/ExistenceJoinSuite.scala | 3 +- .../sql/execution/joins/InnerJoinSuite.scala | 3 +- .../sql/execution/joins/OuterJoinSuite.scala | 3 +- .../streaming/EventTimeWatermarkSuite.scala | 2 +- .../FlatMapGroupsWithStateSuite.scala | 2 +- .../spark/sql/streaming/StreamTest.scala | 2 +- .../sql/streaming/StreamingQuerySuite.scala | 2 +- .../hive/thriftserver/SparkSQLDriver.scala | 2 +- .../spark/sql/hive/ListTablesSuite.scala | 2 +- .../execution/AggregationQuerySuite.scala | 2 +- .../sql/hive/execution/HiveUDFSuite.scala | 4 +- 28 files changed, 104 insertions(+), 109 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index b50e602f31aa9..b71c5eb843eec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -232,7 +232,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { runCommand(df.sparkSession, "save") { SaveIntoDataSourceCommand( - query = 
df.queryExecution.analyzed, + query = df.logicalPlan, provider = source, partitionColumns = partitioningColumns.getOrElse(Nil), options = extraOptions.toMap, @@ -284,7 +284,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { InsertIntoTable( table = UnresolvedRelation(tableIdent), partition = Map.empty[String, Option[String]], - query = df.queryExecution.analyzed, + query = df.logicalPlan, overwrite = mode == SaveMode.Overwrite, ifPartitionNotExists = false) } @@ -370,7 +370,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { case (true, SaveMode.Overwrite) => // Get all input data source or hive relations of the query. - val srcRelations = df.queryExecution.analyzed.collect { + val srcRelations = df.logicalPlan.collect { case LogicalRelation(src: BaseRelation, _, _) => src case relation: CatalogRelation if DDLUtils.isHiveTable(relation.tableMeta) => relation.tableMeta.identifier @@ -417,8 +417,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { partitionColumnNames = partitioningColumns.getOrElse(Nil), bucketSpec = getBucketSpec) - runCommand(df.sparkSession, "saveAsTable")(CreateTable(tableDesc, mode, - Some(df.queryExecution.analyzed))) + runCommand(df.sparkSession, "saveAsTable")(CreateTable(tableDesc, mode, Some(df.logicalPlan))) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index d837aa2359d2c..c3bedc65e1754 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -174,10 +174,10 @@ class Dataset[T] private[sql]( this(sqlContext.sparkSession, logicalPlan, encoder) } - @transient private val logicalPlan: LogicalPlan = { + @transient private[sql] val logicalPlan: LogicalPlan = { // For various commands (like DDL) and queries with side effects, we force query execution // to happen right away to let these side effects take place eagerly. 
- val analyzed = queryExecution.analyzed match { + queryExecution.analyzed match { case c: Command => LocalRelation(c.output, queryExecution.executedPlan.executeCollect()) case u @ Union(children) if children.forall(_.isInstanceOf[Command]) => @@ -185,10 +185,11 @@ class Dataset[T] private[sql]( case _ => queryExecution.analyzed } - // Wraps analyzed logical plans with an analysis barrier so we won't traverse/resolve it again. - AnalysisBarrier(analyzed) } + // Wraps analyzed logical plans with an analysis barrier so we won't traverse/resolve it again. + @transient private val planWithBarrier = AnalysisBarrier(logicalPlan) + /** * Currently [[ExpressionEncoder]] is the only implementation of [[Encoder]], here we turn the * passed in encoder to [[ExpressionEncoder]] explicitly, and mark it implicit so that we can use @@ -205,7 +206,7 @@ class Dataset[T] private[sql]( * `fromRow` method later. */ private val boundEnc = - exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer) + exprEnc.resolveAndBind(planWithBarrier.output, sparkSession.sessionState.analyzer) private implicit def classTag = exprEnc.clsTag @@ -415,7 +416,7 @@ class Dataset[T] private[sql]( */ @Experimental @InterfaceStability.Evolving - def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, logicalPlan) + def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, planWithBarrier) /** * Converts this strongly typed collection of data to generic `DataFrame` with columns renamed. 
@@ -437,7 +438,7 @@ class Dataset[T] private[sql]( s"Old column names (${schema.size}): " + schema.fields.map(_.name).mkString(", ") + "\n" + s"New column names (${colNames.size}): " + colNames.mkString(", ")) - val newCols = logicalPlan.output.zip(colNames).map { case (oldAttribute, newName) => + val newCols = planWithBarrier.output.zip(colNames).map { case (oldAttribute, newName) => Column(oldAttribute).as(newName) } select(newCols : _*) @@ -524,7 +525,7 @@ class Dataset[T] private[sql]( */ @Experimental @InterfaceStability.Evolving - def isStreaming: Boolean = logicalPlan.isStreaming + def isStreaming: Boolean = planWithBarrier.isStreaming /** * Eagerly checkpoint a Dataset and return the new Dataset. Checkpointing can be used to truncate @@ -575,7 +576,7 @@ class Dataset[T] private[sql]( Dataset.ofRows( sparkSession, LogicalRDD( - logicalPlan.output, + planWithBarrier.output, internalRdd, outputPartitioning, physicalPlan.outputOrdering @@ -618,7 +619,7 @@ class Dataset[T] private[sql]( require(parsedDelay.milliseconds >= 0 && parsedDelay.months >= 0, s"delay threshold ($delayThreshold) should not be negative.") EliminateEventTimeWatermark( - EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan)) + EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, planWithBarrier)) } /** @@ -792,7 +793,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def join(right: Dataset[_]): DataFrame = withPlan { - Join(logicalPlan, right.logicalPlan, joinType = Inner, None) + Join(planWithBarrier, right.planWithBarrier, joinType = Inner, None) } /** @@ -870,7 +871,7 @@ class Dataset[T] private[sql]( // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right // by creating a new instance for one of the branch. 
val joined = sparkSession.sessionState.executePlan( - Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None)) + Join(planWithBarrier, right.planWithBarrier, joinType = JoinType(joinType), None)) .analyzed.asInstanceOf[Join] withPlan { @@ -931,7 +932,7 @@ class Dataset[T] private[sql]( // Trigger analysis so in the case of self-join, the analyzer will clone the plan. // After the cloning, left and right side will have distinct expression ids. val plan = withPlan( - Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr))) + Join(planWithBarrier, right.planWithBarrier, JoinType(joinType), Some(joinExprs.expr))) .queryExecution.analyzed.asInstanceOf[Join] // If auto self join alias is disabled, return the plan. @@ -940,8 +941,8 @@ class Dataset[T] private[sql]( } // If left/right have no output set intersection, return the plan. - val lanalyzed = withPlan(this.logicalPlan).queryExecution.analyzed - val ranalyzed = withPlan(right.logicalPlan).queryExecution.analyzed + val lanalyzed = withPlan(this.planWithBarrier).queryExecution.analyzed + val ranalyzed = withPlan(right.planWithBarrier).queryExecution.analyzed if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) { return withPlan(plan) } @@ -973,7 +974,7 @@ class Dataset[T] private[sql]( * @since 2.1.0 */ def crossJoin(right: Dataset[_]): DataFrame = withPlan { - Join(logicalPlan, right.logicalPlan, joinType = Cross, None) + Join(planWithBarrier, right.planWithBarrier, joinType = Cross, None) } /** @@ -1005,8 +1006,8 @@ class Dataset[T] private[sql]( // etc. 
val joined = sparkSession.sessionState.executePlan( Join( - this.logicalPlan, - other.logicalPlan, + this.planWithBarrier, + other.planWithBarrier, JoinType(joinType), Some(condition.expr))).analyzed.asInstanceOf[Join] @@ -1176,7 +1177,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def hint(name: String, parameters: String*): Dataset[T] = withTypedPlan { - UnresolvedHint(name, parameters, logicalPlan) + UnresolvedHint(name, parameters, planWithBarrier) } /** @@ -1202,7 +1203,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def as(alias: String): Dataset[T] = withTypedPlan { - SubqueryAlias(alias, logicalPlan) + SubqueryAlias(alias, planWithBarrier) } /** @@ -1240,7 +1241,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def select(cols: Column*): DataFrame = withPlan { - Project(cols.map(_.named), logicalPlan) + Project(cols.map(_.named), planWithBarrier) } /** @@ -1295,8 +1296,8 @@ class Dataset[T] private[sql]( @InterfaceStability.Evolving def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = { implicit val encoder = c1.encoder - val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, - logicalPlan) + val project = Project(c1.withInputType(exprEnc, planWithBarrier.output).named :: Nil, + planWithBarrier) if (encoder.flat) { new Dataset[U1](sparkSession, project, encoder) @@ -1314,8 +1315,8 @@ class Dataset[T] private[sql]( protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val encoders = columns.map(_.encoder) val namedColumns = - columns.map(_.withInputType(exprEnc, logicalPlan.output).named) - val execution = new QueryExecution(sparkSession, Project(namedColumns, logicalPlan)) + columns.map(_.withInputType(exprEnc, planWithBarrier.output).named) + val execution = new QueryExecution(sparkSession, Project(namedColumns, planWithBarrier)) new Dataset(sparkSession, execution, ExpressionEncoder.tuple(encoders)) } @@ -1391,7 +1392,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ 
def filter(condition: Column): Dataset[T] = withTypedPlan { - Filter(condition.expr, logicalPlan) + Filter(condition.expr, planWithBarrier) } /** @@ -1568,7 +1569,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { - val inputPlan = logicalPlan + val inputPlan = planWithBarrier val withGroupingKey = AppendColumns(func, inputPlan) val executed = sparkSession.sessionState.executePlan(withGroupingKey) @@ -1714,7 +1715,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def limit(n: Int): Dataset[T] = withTypedPlan { - Limit(Literal(n), logicalPlan) + Limit(Literal(n), planWithBarrier) } /** @@ -1743,8 +1744,7 @@ class Dataset[T] private[sql]( def union(other: Dataset[T]): Dataset[T] = withSetOperator { // This breaks caching, but it's usually ok because it addresses a very specific use case: // using union to union many files or partitions. - CombineUnions(Union(queryExecution.analyzed, other.queryExecution.analyzed)) - .mapChildren(AnalysisBarrier) + CombineUnions(Union(logicalPlan, other.logicalPlan)).mapChildren(AnalysisBarrier) } /** @@ -1758,7 +1758,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def intersect(other: Dataset[T]): Dataset[T] = withSetOperator { - Intersect(logicalPlan, other.logicalPlan) + Intersect(planWithBarrier, other.planWithBarrier) } /** @@ -1772,7 +1772,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def except(other: Dataset[T]): Dataset[T] = withSetOperator { - Except(logicalPlan, other.logicalPlan) + Except(planWithBarrier, other.planWithBarrier) } /** @@ -1793,7 +1793,7 @@ class Dataset[T] private[sql]( s"Fraction must be nonnegative, but got ${fraction}") withTypedPlan { - Sample(0.0, fraction, withReplacement, seed, logicalPlan)() + Sample(0.0, fraction, withReplacement, seed, planWithBarrier)() } } @@ -1835,15 +1835,15 @@ class Dataset[T] private[sql]( // overlapping splits. 
To prevent this, we explicitly sort each input partition to make the // ordering deterministic. Note that MapTypes cannot be sorted and are explicitly pruned out // from the sort order. - val sortOrder = logicalPlan.output + val sortOrder = planWithBarrier.output .filter(attr => RowOrdering.isOrderable(attr.dataType)) .map(SortOrder(_, Ascending)) val plan = if (sortOrder.nonEmpty) { - Sort(sortOrder, global = false, logicalPlan) + Sort(sortOrder, global = false, planWithBarrier) } else { // SPARK-12662: If sort order is empty, we materialize the dataset to guarantee determinism cache() - logicalPlan + planWithBarrier } val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) @@ -1927,7 +1927,7 @@ class Dataset[T] private[sql]( withPlan { Generate(generator, join = true, outer = false, - qualifier = None, generatorOutput = Nil, logicalPlan) + qualifier = None, generatorOutput = Nil, planWithBarrier) } } @@ -1968,7 +1968,7 @@ class Dataset[T] private[sql]( withPlan { Generate(generator, join = true, outer = false, - qualifier = None, generatorOutput = Nil, logicalPlan) + qualifier = None, generatorOutput = Nil, planWithBarrier) } } @@ -2083,7 +2083,7 @@ class Dataset[T] private[sql]( u.name, sparkSession.sessionState.analyzer.resolver).getOrElse(u) case Column(expr: Expression) => expr } - val attrs = this.logicalPlan.output + val attrs = this.planWithBarrier.output val colsAfterDrop = attrs.filter { attr => attr != expression }.map(attr => Column(attr)) @@ -2131,7 +2131,7 @@ class Dataset[T] private[sql]( } cols } - Deduplicate(groupCols, logicalPlan, isStreaming) + Deduplicate(groupCols, planWithBarrier, isStreaming) } /** @@ -2280,7 +2280,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def filter(func: T => Boolean): Dataset[T] = { - withTypedPlan(TypedFilter(func, logicalPlan)) + withTypedPlan(TypedFilter(func, planWithBarrier)) } /** @@ -2294,7 +2294,7 @@ class Dataset[T] private[sql]( 
@Experimental @InterfaceStability.Evolving def filter(func: FilterFunction[T]): Dataset[T] = { - withTypedPlan(TypedFilter(func, logicalPlan)) + withTypedPlan(TypedFilter(func, planWithBarrier)) } /** @@ -2308,7 +2308,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def map[U : Encoder](func: T => U): Dataset[U] = withTypedPlan { - MapElements[T, U](func, logicalPlan) + MapElements[T, U](func, planWithBarrier) } /** @@ -2323,7 +2323,7 @@ class Dataset[T] private[sql]( @InterfaceStability.Evolving def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { implicit val uEnc = encoder - withTypedPlan(MapElements[T, U](func, logicalPlan)) + withTypedPlan(MapElements[T, U](func, planWithBarrier)) } /** @@ -2339,7 +2339,7 @@ class Dataset[T] private[sql]( def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { new Dataset[U]( sparkSession, - MapPartitions[T, U](func, logicalPlan), + MapPartitions[T, U](func, planWithBarrier), implicitly[Encoder[U]]) } @@ -2370,7 +2370,7 @@ class Dataset[T] private[sql]( val rowEncoder = encoder.asInstanceOf[ExpressionEncoder[Row]] Dataset.ofRows( sparkSession, - MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder, logicalPlan)) + MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder, planWithBarrier)) } /** @@ -2525,7 +2525,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def repartition(numPartitions: Int): Dataset[T] = withTypedPlan { - Repartition(numPartitions, shuffle = true, logicalPlan) + Repartition(numPartitions, shuffle = true, planWithBarrier) } /** @@ -2539,7 +2539,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = withTypedPlan { - RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions) + RepartitionByExpression(partitionExprs.map(_.expr), planWithBarrier, numPartitions) } /** @@ -2555,7 +2555,8 @@ class 
Dataset[T] private[sql]( @scala.annotation.varargs def repartition(partitionExprs: Column*): Dataset[T] = withTypedPlan { RepartitionByExpression( - partitionExprs.map(_.expr), logicalPlan, sparkSession.sessionState.conf.numShufflePartitions) + partitionExprs.map(_.expr), planWithBarrier, + sparkSession.sessionState.conf.numShufflePartitions) } /** @@ -2576,7 +2577,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def coalesce(numPartitions: Int): Dataset[T] = withTypedPlan { - Repartition(numPartitions, shuffle = false, logicalPlan) + Repartition(numPartitions, shuffle = false, planWithBarrier) } /** @@ -2665,7 +2666,7 @@ class Dataset[T] private[sql]( */ lazy val rdd: RDD[T] = { val objectType = exprEnc.deserializer.dataType - val deserialized = CatalystSerde.deserialize[T](logicalPlan) + val deserialized = CatalystSerde.deserialize[T](planWithBarrier) sparkSession.sessionState.executePlan(deserialized).toRdd.mapPartitions { rows => rows.map(_.get(0, objectType).asInstanceOf[T]) } @@ -2764,7 +2765,7 @@ class Dataset[T] private[sql]( comment = None, properties = Map.empty, originalText = None, - child = logicalPlan, + child = planWithBarrier, allowExisting = false, replace = replace, viewType = viewType) @@ -2935,7 +2936,7 @@ class Dataset[T] private[sql]( } } withTypedPlan { - Sort(sortOrder, global = global, logicalPlan) + Sort(sortOrder, global = global, planWithBarrier) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index a378914ca304e..64755434784a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -58,22 +58,21 @@ class RelationalGroupedDataset protected[sql]( } val aliasedAgg = aggregates.map(alias) - val logicalPlan = df.queryExecution.analyzed groupType match { case RelationalGroupedDataset.GroupByType => 
Dataset.ofRows( - df.sparkSession, Aggregate(groupingExprs, aliasedAgg, logicalPlan)) + df.sparkSession, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) case RelationalGroupedDataset.RollupType => Dataset.ofRows( - df.sparkSession, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, logicalPlan)) + df.sparkSession, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.logicalPlan)) case RelationalGroupedDataset.CubeType => Dataset.ofRows( - df.sparkSession, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, logicalPlan)) + df.sparkSession, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.logicalPlan)) case RelationalGroupedDataset.PivotType(pivotCol, values) => val aliasedGrps = groupingExprs.map(alias) Dataset.ofRows( - df.sparkSession, Pivot(aliasedGrps, pivotCol, values, aggExprs, logicalPlan)) + df.sparkSession, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan)) } } @@ -224,7 +223,7 @@ class RelationalGroupedDataset protected[sql]( def agg(expr: Column, exprs: Column*): DataFrame = { toDF((expr +: exprs).map { case typed: TypedColumn[_, _] => - typed.withInputType(df.exprEnc, df.queryExecution.analyzed.output).expr + typed.withInputType(df.exprEnc, df.logicalPlan.output).expr case c => c.expr }) } @@ -429,8 +428,8 @@ class RelationalGroupedDataset protected[sql]( df.exprEnc.deserializer, df.exprEnc.schema, groupingAttributes, - df.queryExecution.analyzed.output, - df.queryExecution.analyzed)) + df.logicalPlan.output, + df.logicalPlan)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 1e89a8f8a8dc1..0ea806d6cb50b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -90,7 +90,7 @@ class CacheManager extends Logging { query: Dataset[_], tableName: Option[String] = None, storageLevel: StorageLevel = 
MEMORY_AND_DISK): Unit = writeLock { - val planToCache = query.queryExecution.analyzed + val planToCache = query.logicalPlan if (lookupCachedData(planToCache).nonEmpty) { logWarning("Asked to cache already cached data.") } else { @@ -110,7 +110,7 @@ class CacheManager extends Logging { * Un-cache all the cache entries that refer to the given plan. */ def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Unit = writeLock { - uncacheQuery(query.sparkSession, query.queryExecution.analyzed, blocking) + uncacheQuery(query.sparkSession, query.logicalPlan, blocking) } /** @@ -159,7 +159,7 @@ class CacheManager extends Logging { /** Optionally returns cached data for the given [[Dataset]] */ def lookupCachedData(query: Dataset[_]): Option[CachedData] = readLock { - lookupCachedData(query.queryExecution.analyzed) + lookupCachedData(query.logicalPlan) } /** Optionally returns cached data for the given [[LogicalPlan]]. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index a47d58aa8c4f8..0d8db2ff5d5a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -70,7 +70,7 @@ case class AnalyzeColumnCommand( tableIdent: TableIdentifier, columnNames: Seq[String]): (Long, Map[String, ColumnStat]) = { - val relation = sparkSession.table(tableIdent).queryExecution.analyzed + val relation = sparkSession.table(tableIdent).logicalPlan // Resolve the column names and dedup using AttributeSet val resolver = sparkSession.sessionState.conf.resolver val attributesToAnalyze = columnNames.map { col => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 
ac0b590e0eee6..9fce29b06b9d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -412,7 +412,7 @@ case class DataSource( // not need to have the query as child, to avoid to analyze an optimized query, // because InsertIntoHadoopFsRelationCommand will be optimized first. val partitionAttributes = partitionColumns.map { name => - val plan = data.queryExecution.analyzed + val plan = data.logicalPlan plan.resolve(name :: Nil, data.sparkSession.sessionState.analyzer.resolver).getOrElse { throw new AnalysisException( s"Unable to resolve $name given [${plan.output.map(_.name).mkString(", ")}]") @@ -424,8 +424,8 @@ case class DataSource( }.head } // For partitioned relation r, r.schema's column ordering can be different from the column - // ordering of data.queryExecution.analyzed (partition columns are all moved after data column). - // This will be adjusted within InsertIntoHadoopFsRelation. + // ordering of data.logicalPlan (partition columns are all moved after data column). This + // will be adjusted within InsertIntoHadoopFsRelation. 
val plan = InsertIntoHadoopFsRelationCommand( outputPath = outputPath, @@ -435,7 +435,7 @@ case class DataSource( bucketSpec = bucketSpec, fileFormat = format, options = options, - query = data.queryExecution.analyzed, + query = data.logicalPlan, mode = mode, catalogTable = catalogTable, fileIndex = fileIndex) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index ade43a2863fb6..f8d4a9bb5b81a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -90,7 +90,7 @@ case class CreateTempViewUsing( val catalog = sparkSession.sessionState.catalog val viewDefinition = Dataset.ofRows( - sparkSession, LogicalRelation(dataSource.resolveRelation())).queryExecution.analyzed + sparkSession, LogicalRelation(dataSource.resolveRelation())).logicalPlan if (global) { catalog.createGlobalTempView(tableIdent.table, viewDefinition, replace) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index e83ed66aadd73..6885d0bf67ccb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -115,7 +115,7 @@ class FileStreamSink( // the given columns names. 
val partitionColumns: Seq[Attribute] = partitionColumnNames.map { col => val nameEquality = data.sparkSession.sessionState.conf.resolver - data.queryExecution.analyzed.output.find(f => nameEquality(f.name, col)).getOrElse { + data.logicalPlan.output.find(f => nameEquality(f.name, col)).getOrElse { throw new RuntimeException(s"Partition column $col not found in schema ${data.schema}") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala index 019422d4ec940..de09fb568d2a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala @@ -44,7 +44,7 @@ class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Seria // // Hence, we need to manually convert internal rows to objects using encoder. val encoder = encoderFor[T].resolveAndBind( - data.queryExecution.analyzed.output, + data.logicalPlan.output, data.sparkSession.sessionState.analyzer) data.queryExecution.toRdd.foreachPartition { iter => if (writer.open(TaskContext.getPartitionId(), batchId)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index 17040406c8cd9..a4e4ca821374c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -234,7 +234,7 @@ trait ProgressReporter extends Logging { // 3. For each source, we sum the metrics of the associated execution plan leaves. 
// val logicalPlanLeafToSource = newData.flatMap { case (source, df) => - df.queryExecution.analyzed.collectLeaves().map { leaf => leaf -> source } + df.logicalPlan.collectLeaves().map { leaf => leaf -> source } } val allLogicalPlanLeaves = lastExecution.logical.collectLeaves() // includes non-streaming val allExecPlanLeaves = lastExecution.executedPlan.collectLeaves() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 3a6ad80376555..b6ddf7437ea13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -617,7 +617,7 @@ class StreamExecution( val withNewSources = logicalPlan transform { case StreamingExecutionRelation(source, output) => newData.get(source).map { data => - val newPlan = data.queryExecution.analyzed + val newPlan = data.logicalPlan assert(output.size == newPlan.output.size, s"Invalid batch: ${Utils.truncatedString(output, ",")} != " + s"${Utils.truncatedString(newPlan.output, ",")}") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 14525249c2edc..0d2611f9bbcce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -359,7 +359,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { * need to care about case sensitivity afterwards. 
*/ private def normalize(columnName: String, columnType: String): String = { - val validColumnNames = df.queryExecution.analyzed.output.map(_.name) + val validColumnNames = df.logicalPlan.output.map(_.name) validColumnNames.find(df.sparkSession.sessionState.analyzer.resolver(_, columnName)) .getOrElse(throw new AnalysisException(s"$columnType column $columnName not found in " + s"existing columns (${validColumnNames.mkString(", ")})")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index be493baf4fc0b..4114f7a19c7ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -189,7 +189,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext } test("SPARK-1669: cacheTable should be idempotent") { - assume(!spark.table("testData").queryExecution.analyzed.isInstanceOf[InMemoryRelation]) + assume(!spark.table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) spark.catalog.cacheTable("testData") assertCached(spark.table("testData")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 5b7624511c008..8eb381b91f46d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1043,10 +1043,10 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds = spark.range(10).repartition('id % 2).filter('id > 5).orderBy('id.desc) val cp = ds.checkpoint(eager) - val logicalRDD = cp.queryExecution.analyzed match { + val logicalRDD = cp.logicalPlan match { case plan: LogicalRDD => plan case _ => - val treeString = cp.queryExecution.analyzed.treeString(verbose = true) + val treeString = cp.logicalPlan.treeString(verbose = true) fail(s"Expecting a 
LogicalRDD, but got\n$treeString") } @@ -1118,7 +1118,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { // instead of Int for avoiding possible overflow. val ds = (0 to 10000).map( i => (i, Seq((i, Seq((i, "This is really not that long of a string")))))).toDS() - val sizeInBytes = ds.queryExecution.analyzed.stats(sqlConf).sizeInBytes + val sizeInBytes = ds.logicalPlan.stats(sqlConf).sizeInBytes // sizeInBytes is 2404280404, before the fix, it overflows to a negative number assert(sizeInBytes > 0) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 2eb907bbb95b5..f9808834df4a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -253,7 +253,7 @@ object QueryTest { df: DataFrame, expectedAnswer: Seq[Row], checkToRDD: Boolean = true): Option[String] = { - val isSorted = df.queryExecution.analyzed.collect { case s: logical.Sort => s }.nonEmpty + val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty if (checkToRDD) { df.rdd.count() // Also attempt to deserialize as an RDD [SPARK-15791] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 5c3921c7bb448..d02c8ffe33f0f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -168,21 +168,21 @@ class PlannerSuite extends SharedSQLContext { val query = testData.select('key, 'value).sort('key).limit(2) val planned = query.queryExecution.executedPlan assert(planned.isInstanceOf[execution.TakeOrderedAndProjectExec]) - assert(planned.output === testData.select('key, 'value).queryExecution.analyzed.output) + assert(planned.output === testData.select('key, 'value).logicalPlan.output) } 
test("terminal limit -> project -> sort should use TakeOrderedAndProject") { val query = testData.select('key, 'value).sort('key).select('value, 'key).limit(2) val planned = query.queryExecution.executedPlan assert(planned.isInstanceOf[execution.TakeOrderedAndProjectExec]) - assert(planned.output === testData.select('value, 'key).queryExecution.analyzed.output) + assert(planned.output === testData.select('value, 'key).logicalPlan.output) } test("terminal limits that are not handled by TakeOrderedAndProject should use CollectLimit") { val query = testData.select('value).limit(2) val planned = query.queryExecution.sparkPlan assert(planned.isInstanceOf[CollectLimitExec]) - assert(planned.output === testData.select('value).queryExecution.analyzed.output) + assert(planned.output === testData.select('value).logicalPlan.output) } test("TakeOrderedAndProject can appear in the middle of plans") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 9a627f3412303..109b1d9db60d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -38,7 +38,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { private def cachePrimitiveTest(data: DataFrame, dataType: String) { data.createOrReplaceTempView(s"testData$dataType") val storageLevel = MEMORY_ONLY - val plan = spark.sessionState.executePlan(data.queryExecution.analyzed).sparkPlan + val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan val inMemoryRelation = InMemoryRelation(useCompression = true, 5, storageLevel, plan, None) assert(inMemoryRelation.cachedColumnBuffers.getStorageLevel == storageLevel) @@ -114,7 +114,7 @@ class InMemoryColumnarQuerySuite extends QueryTest 
with SharedSQLContext { } test("simple columnar query") { - val plan = spark.sessionState.executePlan(testData.queryExecution.analyzed).sparkPlan + val plan = spark.sessionState.executePlan(testData.logicalPlan).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().toSeq) @@ -131,8 +131,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { } test("projection") { - val plan = spark.sessionState.executePlan(testData.select('value, 'key).queryExecution.analyzed) - .sparkPlan + val plan = spark.sessionState.executePlan(testData.select('value, 'key).logicalPlan).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().map { @@ -148,7 +147,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") { - val plan = spark.sessionState.executePlan(testData.queryExecution.analyzed).sparkPlan + val plan = spark.sessionState.executePlan(testData.logicalPlan).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().toSeq) @@ -322,7 +321,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-17549: cached table size should be correctly calculated") { val data = spark.sparkContext.parallelize(1 to 10, 5).toDF() - val plan = spark.sessionState.executePlan(data.queryExecution.analyzed).sparkPlan + val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan val cached = InMemoryRelation(true, 5, MEMORY_ONLY, plan, None) // Materialize the data. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala index e7febd2e3a56b..38377164c10e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala @@ -85,8 +85,7 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { expectedAnswer: Seq[Row]): Unit = { def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { - val join = Join(leftRows.queryExecution.analyzed, rightRows.queryExecution.analyzed, Inner, - Some(condition)) + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) ExtractEquiJoinKeys.unapply(join) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index dabb02a1e2d79..4408ece112258 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -80,8 +80,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { expectedAnswer: Seq[Product]): Unit = { def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { - val join = Join(leftRows.queryExecution.analyzed, rightRows.queryExecution.analyzed, Inner, - Some(condition())) + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition())) ExtractEquiJoinKeys.unapply(join) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index bb3fb558d5485..001feb0f2b399 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -72,8 +72,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { expectedAnswer: Seq[Product]): Unit = { def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { - val join = Join(leftRows.queryExecution.analyzed, rightRows.queryExecution.analyzed, Inner, - Some(condition)) + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) ExtractEquiJoinKeys.unapply(join) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index 8603d4855a13c..1b60a06ec402f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -338,7 +338,7 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Loggin .withWatermark("first", "1 minute") .withWatermark("second", "2 minutes") - val eventTimeColumns = df.queryExecution.analyzed.output + val eventTimeColumns = df.logicalPlan.output .filter(_.metadata.contains(EventTimeWatermark.delayKey)) assert(eventTimeColumns.size === 1) assert(eventTimeColumns(0).name === "second") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 1c17d692b36fe..6bb9408ce99ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -1034,7 +1034,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf .toDS .groupByKey(x => x) .flatMapGroupsWithState[Int, Int](Append, timeoutConf = timeoutType)(func) - 
.queryExecution.analyzed.collectFirst { + .logicalPlan.collectFirst { case FlatMapGroupsWithState(f, k, v, g, d, o, s, m, _, t, _) => FlatMapGroupsWithStateExec( f, k, v, g, d, o, None, s, m, t, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 16e21e2e7f6b9..5bc36dd30f6d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -516,7 +516,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { queryToUse.flatMap { query => findSourceIndex(query.logicalPlan) }.orElse { - findSourceIndex(stream.queryExecution.analyzed) + findSourceIndex(stream.logicalPlan) }.getOrElse { throw new IllegalArgumentException( "Could find index of the source to which data was added") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 42665d9341fd0..b69536ed37463 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -441,7 +441,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi test("input row calculation with trigger input DF having multiple leaves") { val streamingTriggerDF = spark.createDataset(1 to 5).toDF.union(spark.createDataset(6 to 10).toDF) - require(streamingTriggerDF.queryExecution.analyzed.collectLeaves().size > 1) + require(streamingTriggerDF.logicalPlan.collectLeaves().size > 1) val streamingInputDF = createSingleTriggerStreamingDF(streamingTriggerDF) // After the first trigger, the calculated input rows should be 10 diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala 
b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala index 75625f7922823..0d5dc7af5f522 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -59,7 +59,7 @@ private[hive] class SparkSQLDriver(val context: SQLContext = SparkSQLEnv.sqlCont // TODO unify the error code try { context.sparkContext.setJobDescription(command) - val execution = context.sessionState.executePlan(context.sql(command).queryExecution.analyzed) + val execution = context.sessionState.executePlan(context.sql(command).logicalPlan) hiveResponse = execution.hiveResultString() tableSchema = getResultSetSchema(execution) new CommandProcessorResponse(0) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala index 2c0f7fe6d2d93..15ba61646d03f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala @@ -34,7 +34,7 @@ class ListTablesSuite extends QueryTest with TestHiveSingleton with BeforeAndAft super.beforeAll() // The catalog in HiveContext is a case insensitive one. 
sessionState.catalog.createTempView( - "ListTablesSuiteTable", df.queryExecution.analyzed, overrideIfExists = true) + "ListTablesSuiteTable", df.logicalPlan, overrideIfExists = true) sql("CREATE TABLE HiveListTablesSuiteTable (key int, value string)") sql("CREATE DATABASE IF NOT EXISTS ListTablesSuiteDB") sql("CREATE TABLE ListTablesSuiteDB.HiveInDBListTablesSuiteTable (key int, value string)") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 73f6a3e6ac05b..84f915977bd88 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -1020,7 +1020,7 @@ class HashAggregationQueryWithControlledFallbackSuite extends AggregationQuerySu // Create a new df to make sure its physical operator picks up // spark.sql.TungstenAggregate.testFallbackStartsAt. // todo: remove it? 
- val newActual = Dataset.ofRows(spark, actual.queryExecution.analyzed) + val newActual = Dataset.ofRows(spark, actual.logicalPlan) QueryTest.checkAnswer(newActual, expectedAnswer) match { case Some(errorMessage) => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 4f91fc871d63b..4446af2e75e00 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -393,12 +393,12 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { // HiveSimpleUDF sql(s"CREATE TEMPORARY FUNCTION testStringStringUDF AS '${classOf[UDFStringString].getName}'") val df1 = sql("SELECT testStringStringUDF(rand(), \"hello\")") - assert(!df1.queryExecution.analyzed.asInstanceOf[Project].projectList.forall(_.deterministic)) + assert(!df1.logicalPlan.asInstanceOf[Project].projectList.forall(_.deterministic)) // HiveGenericUDF sql(s"CREATE TEMPORARY FUNCTION testGenericUDFHash AS '${classOf[GenericUDFHash].getName}'") val df2 = sql("SELECT testGenericUDFHash(rand())") - assert(!df2.queryExecution.analyzed.asInstanceOf[Project].projectList.forall(_.deterministic)) + assert(!df2.logicalPlan.asInstanceOf[Project].projectList.forall(_.deterministic)) } } From 6add9ec4aad2266f08fb77ad137442fa00552560 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 26 May 2017 02:52:17 +0000 Subject: [PATCH 21/21] Address comments. 
--- .../src/main/scala/org/apache/spark/sql/Dataset.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index c3bedc65e1754..f9bd8f3d278ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -206,7 +206,7 @@ class Dataset[T] private[sql]( * `fromRow` method later. */ private val boundEnc = - exprEnc.resolveAndBind(planWithBarrier.output, sparkSession.sessionState.analyzer) + exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer) private implicit def classTag = exprEnc.clsTag @@ -438,7 +438,7 @@ class Dataset[T] private[sql]( s"Old column names (${schema.size}): " + schema.fields.map(_.name).mkString(", ") + "\n" + s"New column names (${colNames.size}): " + colNames.mkString(", ")) - val newCols = planWithBarrier.output.zip(colNames).map { case (oldAttribute, newName) => + val newCols = logicalPlan.output.zip(colNames).map { case (oldAttribute, newName) => Column(oldAttribute).as(newName) } select(newCols : _*) @@ -525,7 +525,7 @@ class Dataset[T] private[sql]( */ @Experimental @InterfaceStability.Evolving - def isStreaming: Boolean = planWithBarrier.isStreaming + def isStreaming: Boolean = logicalPlan.isStreaming /** * Eagerly checkpoint a Dataset and return the new Dataset. Checkpointing can be used to truncate @@ -576,7 +576,7 @@ class Dataset[T] private[sql]( Dataset.ofRows( sparkSession, LogicalRDD( - planWithBarrier.output, + logicalPlan.output, internalRdd, outputPartitioning, physicalPlan.outputOrdering