From 86fee3635d5c2dc6c7381ae8739509e50993d4e7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 22 Jul 2015 23:36:04 +0800 Subject: [PATCH 1/6] add initialization phase for nondeterministic expression --- .../sql/catalyst/analysis/CheckAnalysis.scala | 195 +++++++++--------- .../sql/catalyst/expressions/Expression.scala | 21 +- .../sql/catalyst/expressions/random.scala | 12 +- .../expressions/ExpressionEvalHelper.scala | 4 + .../spark/sql/execution/SparkPlan.scala | 12 ++ .../MonotonicallyIncreasingID.scala | 13 +- .../expressions/SparkPartitionID.scala | 8 +- 7 files changed, 152 insertions(+), 113 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 81d473c1130f7..05fd0a246fe6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -38,114 +38,105 @@ trait CheckAnalysis { throw new AnalysisException(msg) } - def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = { + protected def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = { exprs.flatMap(_.collect { - case e: Generator => true - }).nonEmpty + case e: Generator => e + }).length > 1 } def checkAnalysis(plan: LogicalPlan): Unit = { // We transform up and order the rules so as to catch the first possible failure instead // of the result of cascading resolution failures. - plan.foreachUp { - - case operator: LogicalPlan => - operator transformExpressionsUp { - case a: Attribute if !a.resolved => - val from = operator.inputSet.map(_.name).mkString(", ") - a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from") - - case e: Expression if e.checkInputDataTypes().isFailure => - e.checkInputDataTypes() match { - case TypeCheckResult.TypeCheckFailure(message) => - e.failAnalysis( - s"cannot resolve '${e.prettyString}' due to data type mismatch: $message") - } - - case c: Cast if !c.resolved => - failAnalysis( - s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}") - - case WindowExpression(UnresolvedWindowFunction(name, _), _) => - failAnalysis( - s"Could not resolve window function '$name'. " + - "Note that, using window functions currently requires a HiveContext") - - case w @ WindowExpression(windowFunction, windowSpec) if windowSpec.validate.nonEmpty => - // The window spec is not valid. - val reason = windowSpec.validate.get - failAnalysis(s"Window specification $windowSpec is not valid because $reason") - } - - operator match { - case f: Filter if f.condition.dataType != BooleanType => - failAnalysis( - s"filter expression '${f.condition.prettyString}' " + - s"of type ${f.condition.dataType.simpleString} is not a boolean.") - - case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType => - failAnalysis( - s"join condition '${condition.prettyString}' " + - s"of type ${condition.dataType.simpleString} is not a boolean.") - - case Aggregate(groupingExprs, aggregateExprs, child) => - def checkValidAggregateExpression(expr: Expression): Unit = expr match { - case _: AggregateExpression => // OK - case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) => - failAnalysis( - s"expression '${e.prettyString}' is neither present in the group by, " + - s"nor is it an aggregate function. " + - "Add to group by or wrap in first() if you don't care which value you get.") - case e if groupingExprs.exists(_.semanticEquals(e)) => // OK - case e if e.references.isEmpty => // OK - case e => e.children.foreach(checkValidAggregateExpression) - } - - aggregateExprs.foreach(checkValidAggregateExpression) - - case Sort(orders, _, _) => - orders.foreach { order => - order.dataType match { - case t: AtomicType => // OK - case NullType => // OK - case t => - failAnalysis(s"Sorting is not supported for columns of type ${t.simpleString}") - } - } - - case _ => // Fallbacks to the following checks - } - - operator match { - case o if o.children.nonEmpty && o.missingInput.nonEmpty => - val missingAttributes = o.missingInput.mkString(",") - val input = o.inputSet.mkString(",") - - failAnalysis( - s"resolved attribute(s) $missingAttributes missing from $input " + - s"in operator ${operator.simpleString}") - - case p @ Project(exprs, _) if containsMultipleGenerators(exprs) => - failAnalysis( - s"""Only a single table generating function is allowed in a SELECT clause, found: - | ${exprs.map(_.prettyString).mkString(",")}""".stripMargin) - - // Special handling for cases when self-join introduce duplicate expression ids. - case j @ Join(left, right, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty => - val conflictingAttributes = left.outputSet.intersect(right.outputSet) - failAnalysis( - s""" - |Failure when resolving conflicting references in Join: - |$plan - |Conflicting attributes: ${conflictingAttributes.mkString(",")} - |""".stripMargin) - - case o if !o.resolved => - failAnalysis( - s"unresolved operator ${operator.simpleString}") - - case _ => // Analysis successful! - } + plan.foreachUp { operator => + operator transformExpressionsUp { + case a: Attribute if !a.resolved => + val from = operator.inputSet.map(_.name).mkString(", ") + a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from") + + case e: Expression if e.checkInputDataTypes().isFailure => + e.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckFailure(message) => + e.failAnalysis( + s"cannot resolve '${e.prettyString}' due to data type mismatch: $message") + } + + case c: Cast if !c.resolved => + failAnalysis( + s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}") + + case WindowExpression(UnresolvedWindowFunction(name, _), _) => + failAnalysis( + s"Could not resolve window function '$name'. " + + "Note that, using window functions currently requires a HiveContext") + + case w @ WindowExpression(windowFunction, windowSpec) if windowSpec.validate.nonEmpty => + // The window spec is not valid. + val reason = windowSpec.validate.get + failAnalysis(s"Window specification $windowSpec is not valid because $reason") + } + + operator match { + case f: Filter if f.condition.dataType != BooleanType => + failAnalysis( + s"filter expression '${f.condition.prettyString}' " + + s"of type ${f.condition.dataType.simpleString} is not a boolean.") + + case Aggregate(groupingExprs, aggregateExprs, child) => + def checkValidAggregateExpression(expr: Expression): Unit = expr match { + case _: AggregateExpression => // OK + case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) => + failAnalysis( + s"expression '${e.prettyString}' is neither present in the group by, " + + s"nor is it an aggregate function. " + + "Add to group by or wrap in first() if you don't care which value you get.") + case e if groupingExprs.exists(_.semanticEquals(e)) => // OK + case e if e.references.isEmpty => // OK + case e => e.children.foreach(checkValidAggregateExpression) + } + + aggregateExprs.foreach(checkValidAggregateExpression) + + case _ => // Fallbacks to the following checks + } + + operator match { + case o if o.children.nonEmpty && o.missingInput.nonEmpty => + val missingAttributes = o.missingInput.mkString(",") + val input = o.inputSet.mkString(",") + + failAnalysis( + s"resolved attribute(s) $missingAttributes missing from $input " + + s"in operator ${operator.simpleString}") + + case p @ Project(exprs, _) if containsMultipleGenerators(exprs) => + failAnalysis( + s"""Only a single table generating function is allowed in a SELECT clause, found: + | ${exprs.map(_.prettyString).mkString(",")}""".stripMargin) + + // Special handling for cases when self-join introduce duplicate expression ids. + case j @ Join(left, right, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty => + val conflictingAttributes = left.outputSet.intersect(right.outputSet) + failAnalysis( + s""" + |Failure when resolving conflicting references in Join: + |$plan + |Conflicting attributes: ${conflictingAttributes.mkString(",")} + |""".stripMargin) + + case o if !o.resolved => + failAnalysis( + s"unresolved operator ${operator.simpleString}") + + case o if o.expressions.exists(!_.deterministic) && + !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] => + failAnalysis( + s"""nondeterministic expressions are only allowed in Project or Filter, found: + | ${o.expressions.map(_.prettyString).mkString(",")} + |in operator ${operator.simpleString} + """.stripMargin) + + case _ => // Analysis successful! + } } extendedCheckRules.foreach(_(plan)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 3f72e6e184db1..cb4c3f24b2721 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -196,7 +196,26 @@ trait Unevaluable extends Expression { * An expression that is nondeterministic. */ trait Nondeterministic extends Expression { - override def deterministic: Boolean = false + final override def deterministic: Boolean = false + final override def foldable: Boolean = false + + private[this] var initialized = false + + final def initialize(): Unit = { + if (!initialized) { + initInternal() + initialized = true + } + } + + protected def initInternal(): Unit + + final override def eval(input: InternalRow = null): Any = { + require(initialized, "nondeterministic expression should be initialized before evaluate") + evalInternal(input) + } + + protected def evalInternal(input: InternalRow): Any } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala index aef24a5486466..8f30519697a37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala @@ -38,9 +38,13 @@ abstract class RDG extends LeafExpression with Nondeterministic { /** * Record ID within each partition. By being transient, the Random Number Generator is - * reset every time we serialize and deserialize it. + * reset every time we serialize and deserialize and initialize it. */ - @transient protected lazy val rng = new XORShiftRandom(seed + TaskContext.getPartitionId) + @transient protected var rng: XORShiftRandom = _ + + override protected def initInternal(): Unit = { + rng = new XORShiftRandom(seed + TaskContext.getPartitionId) + } override def nullable: Boolean = false @@ -49,7 +53,7 @@ abstract class RDG extends LeafExpression with Nondeterministic { /** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */ case class Rand(seed: Long) extends RDG { - override def eval(input: InternalRow): Double = rng.nextDouble() + override protected def evalInternal(input: InternalRow): Double = rng.nextDouble() def this() = this(Utils.random.nextLong()) @@ -72,7 +76,7 @@ case class Rand(seed: Long) extends RDG { /** Generate a random column with i.i.d. gaussian random distribution. */ case class Randn(seed: Long) extends RDG { - override def eval(input: InternalRow): Double = rng.nextGaussian() + override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian() def this() = this(Utils.random.nextLong()) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 4930219aa63cb..852a8b235f127 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -64,6 +64,10 @@ trait ExpressionEvalHelper { } protected def evaluate(expression: Expression, inputRow: InternalRow = EmptyRow): Any = { + expression.foreach { + case n: Nondeterministic => n.initialize() + case _ => + } expression.eval(inputRow) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 50c27def8ea54..dce724c1aafdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -194,6 +194,10 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } } } else { + expressions.foreach(_.foreach { + case n: Nondeterministic => n.initialize() + case _ => + }) new InterpretedProjection(expressions, inputSchema) } } @@ -216,6 +220,10 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } } } else { + expressions.foreach(_.foreach { + case n: Nondeterministic => n.initialize() + case _ => + }) () => new InterpretedMutableProjection(expressions, inputSchema) } } @@ -235,6 +243,10 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } } } else { + expressions.foreach(_.foreach { + case n: Nondeterministic => n.initialize() + case _ => + }) InterpretedPredicate.create(expression, inputSchema) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala index 2645eb1854bce..eca36b3274420 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala @@ -37,17 +37,22 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression with /** * Record ID within each partition. By being transient, count's value is reset to 0 every time - * we serialize and deserialize it. + * we serialize and deserialize and initialize it. */ - @transient private[this] var count: Long = 0L + @transient private[this] var count: Long = _ - @transient private lazy val partitionMask = TaskContext.getPartitionId().toLong << 33 + @transient private[this] var partitionMask: Long = _ + + override protected def initInternal(): Unit = { + count = 0L + partitionMask = TaskContext.getPartitionId().toLong << 33 + } override def nullable: Boolean = false override def dataType: DataType = LongType - override def eval(input: InternalRow): Long = { + override protected def evalInternal(input: InternalRow): Long = { val currentCount = count count += 1 partitionMask + currentCount diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala index 53ddd47e3e0c1..61ef079d89af5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala @@ -33,9 +33,13 @@ private[sql] case object SparkPartitionID extends LeafExpression with Nondetermi override def dataType: DataType = IntegerType - @transient private lazy val partitionId = TaskContext.getPartitionId() + @transient private[this] var partitionId: Int = _ - override def eval(input: InternalRow): Int = partitionId + override protected def initInternal(): Unit = { + partitionId = TaskContext.getPartitionId() + } + + override protected def evalInternal(input: InternalRow): Int = partitionId override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val idTerm = ctx.freshName("partitionId") From b4a4fc77647031cc5edbbe880db12137f2d51610 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 23 Jul 2015 00:47:19 +0800 Subject: [PATCH 2/6] revert a refactor --- .../sql/catalyst/analysis/CheckAnalysis.scala | 179 +++++++++--------- 1 file changed, 90 insertions(+), 89 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 05fd0a246fe6b..a02183dd3dd7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -47,96 +46,98 @@ trait CheckAnalysis { def checkAnalysis(plan: LogicalPlan): Unit = { // We transform up and order the rules so as to catch the first possible failure instead // of the result of cascading resolution failures. - plan.foreachUp { operator => - operator transformExpressionsUp { - case a: Attribute if !a.resolved => - val from = operator.inputSet.map(_.name).mkString(", ") - a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from") - - case e: Expression if e.checkInputDataTypes().isFailure => - e.checkInputDataTypes() match { - case TypeCheckResult.TypeCheckFailure(message) => - e.failAnalysis( - s"cannot resolve '${e.prettyString}' due to data type mismatch: $message") - } - - case c: Cast if !c.resolved => - failAnalysis( - s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}") - - case WindowExpression(UnresolvedWindowFunction(name, _), _) => - failAnalysis( - s"Could not resolve window function '$name'. " + - "Note that, using window functions currently requires a HiveContext") - - case w @ WindowExpression(windowFunction, windowSpec) if windowSpec.validate.nonEmpty => - // The window spec is not valid. - val reason = windowSpec.validate.get - failAnalysis(s"Window specification $windowSpec is not valid because $reason") - } - - operator match { - case f: Filter if f.condition.dataType != BooleanType => - failAnalysis( - s"filter expression '${f.condition.prettyString}' " + - s"of type ${f.condition.dataType.simpleString} is not a boolean.") - - case Aggregate(groupingExprs, aggregateExprs, child) => - def checkValidAggregateExpression(expr: Expression): Unit = expr match { - case _: AggregateExpression => // OK - case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) => - failAnalysis( - s"expression '${e.prettyString}' is neither present in the group by, " + - s"nor is it an aggregate function. " + - "Add to group by or wrap in first() if you don't care which value you get.") - case e if groupingExprs.exists(_.semanticEquals(e)) => // OK - case e if e.references.isEmpty => // OK - case e => e.children.foreach(checkValidAggregateExpression) - } - - aggregateExprs.foreach(checkValidAggregateExpression) - - case _ => // Fallbacks to the following checks - } - - operator match { - case o if o.children.nonEmpty && o.missingInput.nonEmpty => - val missingAttributes = o.missingInput.mkString(",") - val input = o.inputSet.mkString(",") - - failAnalysis( - s"resolved attribute(s) $missingAttributes missing from $input " + - s"in operator ${operator.simpleString}") - - case p @ Project(exprs, _) if containsMultipleGenerators(exprs) => - failAnalysis( - s"""Only a single table generating function is allowed in a SELECT clause, found: - | ${exprs.map(_.prettyString).mkString(",")}""".stripMargin) - - // Special handling for cases when self-join introduce duplicate expression ids. - case j @ Join(left, right, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty => - val conflictingAttributes = left.outputSet.intersect(right.outputSet) - failAnalysis( - s""" - |Failure when resolving conflicting references in Join: - |$plan - |Conflicting attributes: ${conflictingAttributes.mkString(",")} - |""".stripMargin) - - case o if !o.resolved => - failAnalysis( - s"unresolved operator ${operator.simpleString}") - - case o if o.expressions.exists(!_.deterministic) && - !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] => - failAnalysis( - s"""nondeterministic expressions are only allowed in Project or Filter, found: - | ${o.expressions.map(_.prettyString).mkString(",")} - |in operator ${operator.simpleString} + plan.foreachUp { + + case operator: LogicalPlan => + operator transformExpressionsUp { + case a: Attribute if !a.resolved => + val from = operator.inputSet.map(_.name).mkString(", ") + a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from") + + case e: Expression if e.checkInputDataTypes().isFailure => + e.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckFailure(message) => + e.failAnalysis( + s"cannot resolve '${e.prettyString}' due to data type mismatch: $message") + } + + case c: Cast if !c.resolved => + failAnalysis( + s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}") + + case WindowExpression(UnresolvedWindowFunction(name, _), _) => + failAnalysis( + s"Could not resolve window function '$name'. " + + "Note that, using window functions currently requires a HiveContext") + + case w @ WindowExpression(windowFunction, windowSpec) if windowSpec.validate.nonEmpty => + // The window spec is not valid. + val reason = windowSpec.validate.get + failAnalysis(s"Window specification $windowSpec is not valid because $reason") + } + + operator match { + case f: Filter if f.condition.dataType != BooleanType => + failAnalysis( + s"filter expression '${f.condition.prettyString}' " + + s"of type ${f.condition.dataType.simpleString} is not a boolean.") + + case Aggregate(groupingExprs, aggregateExprs, child) => + def checkValidAggregateExpression(expr: Expression): Unit = expr match { + case _: AggregateExpression => // OK + case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) => + failAnalysis( + s"expression '${e.prettyString}' is neither present in the group by, " + + s"nor is it an aggregate function. " + + "Add to group by or wrap in first() if you don't care which value you get.") + case e if groupingExprs.exists(_.semanticEquals(e)) => // OK + case e if e.references.isEmpty => // OK + case e => e.children.foreach(checkValidAggregateExpression) + } + + aggregateExprs.foreach(checkValidAggregateExpression) + + case _ => // Fallbacks to the following checks + } + + operator match { + case o if o.children.nonEmpty && o.missingInput.nonEmpty => + val missingAttributes = o.missingInput.mkString(",") + val input = o.inputSet.mkString(",") + + failAnalysis( + s"resolved attribute(s) $missingAttributes missing from $input " + + s"in operator ${operator.simpleString}") + + case p @ Project(exprs, _) if containsMultipleGenerators(exprs) => + failAnalysis( + s"""Only a single table generating function is allowed in a SELECT clause, found: + | ${exprs.map(_.prettyString).mkString(",")}""".stripMargin) + + // Special handling for cases when self-join introduce duplicate expression ids. + case j @ Join(left, right, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty => + val conflictingAttributes = left.outputSet.intersect(right.outputSet) + failAnalysis( + s""" + |Failure when resolving conflicting references in Join: + |$plan + |Conflicting attributes: ${conflictingAttributes.mkString(",")} + |""".stripMargin) + + case o if !o.resolved => + failAnalysis( + s"unresolved operator ${operator.simpleString}") + + case o if o.expressions.exists(!_.deterministic) && + !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] => + failAnalysis( + s"""nondeterministic expressions are only allowed in Project or Filter, found: + | ${o.expressions.map(_.prettyString).mkString(",")} + |in operator ${operator.simpleString} """.stripMargin) - case _ => // Analysis successful! - } + case _ => // Analysis successful! + } } extendedCheckRules.foreach(_(plan)) } From bb7d83812f5a20cdf7771bbeb70dbcbb80214855 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 23 Jul 2015 14:11:44 +0800 Subject: [PATCH 3/6] pulls out nondeterministic expressions into a project --- .../sql/catalyst/analysis/Analyzer.scala | 31 +++++++++++++++++-- .../plans/logical/basicOperators.scala | 1 - 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e916887187dc8..7bdd029ed4a17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -19,10 +19,11 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2, AggregateFunction2} -import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.catalyst.trees.TreeNodeRef +import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf} import org.apache.spark.sql.types._ import scala.collection.mutable.ArrayBuffer @@ -78,7 +79,9 @@ class Analyzer( GlobalAggregates :: UnresolvedHavingClauseAttributes :: HiveTypeCoercion.typeCoercionRules ++ - extendedResolutionRules : _*) + extendedResolutionRules : _*), + Batch("Nondeterministic", Once, + PullOutNondeterministic) ) /** @@ -910,6 +913,30 @@ class Analyzer( Project(finalProjectList, withWindow) } } + + /** + * Pulls out nondeterministic expressions from LogicalPlan which is not Project or Filter, + * put them into an inner Project and finally project them at the outer Project. + */ + object PullOutNondeterministic extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case p: Project => p + case f: Filter => f + case p: UnaryNode if p.expressions.exists(!_.deterministic) => + val nondeterministicExprs = p.expressions.filterNot(_.deterministic).map { e => + val ne = e match { + case n: NamedExpression => n + case _ => Alias(e, "_nondeterministic")() + } + new TreeNodeRef(e) -> ne + }.toMap + val newPlan = p.transformExpressions { case e => + nondeterministicExprs.get(new TreeNodeRef(e)).map(_.toAttribute).getOrElse(e) + } + val newChild = Project(p.child.output ++ nondeterministicExprs.values, p.child) + Project(p.output, newPlan.withNewChildren(newChild :: Nil)) + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 57a12820fa4c6..afc5ab4a70460 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet From 9eac85e3e3b5bc74b0a6c249bf246db25f762059 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 23 Jul 2015 16:10:05 +0800 Subject: [PATCH 4/6] move init code to interpreted class --- .../spark/sql/catalyst/expressions/Projection.scala | 10 ++++++++++ .../spark/sql/catalyst/expressions/predicates.scala | 4 ++++ .../org/apache/spark/sql/execution/SparkPlan.scala | 12 ------------ 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index fb873e7e99547..c1ed9cf7ed6a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -31,6 +31,11 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = this(expressions.map(BindReferences.bindReference(_, inputSchema))) + expressions.foreach(_.foreach { + case n: Nondeterministic => n.initialize() + case _ => + }) + // null check is required for when Kryo invokes the no-arg constructor. protected val exprArray = if (expressions != null) expressions.toArray else null @@ -57,6 +62,11 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = this(expressions.map(BindReferences.bindReference(_, inputSchema))) + expressions.foreach(_.foreach { + case n: Nondeterministic => n.initialize() + case _ => + }) + private[this] val exprArray = expressions.toArray private[this] var mutableRow: MutableRow = new GenericMutableRow(exprArray.length) def currentValue: InternalRow = mutableRow diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 3f1bd2a925fe7..5bfe1cad24a3e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -30,6 +30,10 @@ object InterpretedPredicate { create(BindReferences.bindReference(expression, inputSchema)) def create(expression: Expression): (InternalRow => Boolean) = { + expression.foreach { + case n: Nondeterministic => n.initialize() + case _ => + } (r: InternalRow) => expression.eval(r).asInstanceOf[Boolean] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index dce724c1aafdb..50c27def8ea54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -194,10 +194,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } } } else { - expressions.foreach(_.foreach { - case n: Nondeterministic => n.initialize() - case _ => - }) new InterpretedProjection(expressions, inputSchema) } } @@ -220,10 +216,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } } } else { - expressions.foreach(_.foreach { - case n: Nondeterministic => n.initialize() - case _ => - }) () => new InterpretedMutableProjection(expressions, inputSchema) } } @@ -243,10 +235,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } } } else { - expressions.foreach(_.foreach { - case n: Nondeterministic => n.initialize() - case _ => - }) InterpretedPredicate.create(expression, inputSchema) } } From ef68ff4268a8d14acbf8465e987db422173600b8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 23 Jul 2015 16:20:05 +0800 Subject: [PATCH 5/6] fix comments --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 7bdd029ed4a17..13577be214d7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -915,8 +915,8 @@ class Analyzer( } /** - * Pulls out nondeterministic expressions from LogicalPlan which is not Project or Filter, - * put them into an inner Project and finally project them at the outer Project. + * Pulls out nondeterministic expressions from unary LogicalPlan which is not Project or Filter, + * put them into an inner Project and finally project them away at the outer Project. */ object PullOutNondeterministic extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan transform { From 6c6f332db25335e61b06188e094a688f3886cae9 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 26 Jul 2015 00:05:35 +0800 Subject: [PATCH 6/6] add test --- .../sql/catalyst/analysis/Analyzer.scala | 8 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 21 +++- .../plans/logical/basicOperators.scala | 2 +- .../sql/catalyst/analysis/AnalysisSuite.scala | 96 +++++++--------- .../sql/catalyst/analysis/AnalysisTest.scala | 105 ++++++++++++++++++ 5 files changed, 171 insertions(+), 61 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 13577be214d7f..a723e92114b32 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -915,14 +915,18 @@ class Analyzer( } /** - * Pulls out nondeterministic expressions from unary LogicalPlan which is not Project or Filter, + * Pulls out nondeterministic expressions from LogicalPlan which is not Project or Filter, * put them into an inner Project and finally project them away at the outer Project. */ object PullOutNondeterministic extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan transform { case p: Project => p case f: Filter => f - case p: UnaryNode if p.expressions.exists(!_.deterministic) => + + // todo: It's hard to write a general rule to pull out nondeterministic expressions + // from LogicalPlan, currently we only do it for UnaryNode which has same output + // schema with its child. + case p: UnaryNode if p.output == p.child.output && p.expressions.exists(!_.deterministic) => val nondeterministicExprs = p.expressions.filterNot(_.deterministic).map { e => val ne = e match { case n: NamedExpression => n diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index a02183dd3dd7a..a373714832962 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -82,6 +82,11 @@ trait CheckAnalysis { s"filter expression '${f.condition.prettyString}' " + s"of type ${f.condition.dataType.simpleString} is not a boolean.") + case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType => + failAnalysis( + s"join condition '${condition.prettyString}' " + + s"of type ${condition.dataType.simpleString} is not a boolean.") + case Aggregate(groupingExprs, aggregateExprs, child) => def checkValidAggregateExpression(expr: Expression): Unit = expr match { case _: AggregateExpression => // OK @@ -97,6 +102,16 @@ trait CheckAnalysis { aggregateExprs.foreach(checkValidAggregateExpression) + case Sort(orders, _, _) => + orders.foreach { order => + order.dataType match { + case t: AtomicType => // OK + case NullType => // OK + case t => + failAnalysis(s"Sorting is not supported for columns of type ${t.simpleString}") + } + } + case _ => // Fallbacks to the following checks } @@ -121,8 +136,8 @@ trait CheckAnalysis { s""" |Failure when resolving conflicting references in Join: |$plan - |Conflicting attributes: ${conflictingAttributes.mkString(",")} - |""".stripMargin) + |Conflicting attributes: ${conflictingAttributes.mkString(",")} + |""".stripMargin) case o if !o.resolved => failAnalysis( @@ -133,7 +148,7 @@ trait CheckAnalysis { failAnalysis( s"""nondeterministic expressions are only allowed in Project or Filter, found: | ${o.expressions.map(_.prettyString).mkString(",")} - |in operator ${operator.simpleString} + |in operator ${operator.simpleString} """.stripMargin) case _ => // Analysis successful! diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index afc5ab4a70460..8e1a236e2988c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -378,7 +378,7 @@ case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output override lazy val statistics: Statistics = { - val limit = limitExpr.eval(null).asInstanceOf[Int] + val limit = limitExpr.eval().asInstanceOf[Int] val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum Statistics(sizeInBytes = sizeInBytes) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 7e67427237a65..ed645b618dc9b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -17,10 +17,6 @@ package org.apache.spark.sql.catalyst.analysis -import org.scalatest.BeforeAndAfter - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -28,6 +24,7 @@ import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +// todo: remove this and use AnalysisTest instead. object AnalysisSuite { val caseSensitiveConf = new SimpleCatalystConf(true) val caseInsensitiveConf = new SimpleCatalystConf(false) @@ -55,7 +52,7 @@ object AnalysisSuite { AttributeReference("a", StringType)(), AttributeReference("b", StringType)(), AttributeReference("c", DoubleType)(), - AttributeReference("d", DecimalType.SYSTEM_DEFAULT)(), + AttributeReference("d", DecimalType(10, 2))(), AttributeReference("e", ShortType)()) val nestedRelation = LocalRelation( @@ -81,8 +78,7 @@ object AnalysisSuite { } -class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { - import AnalysisSuite._ +class AnalysisSuite extends AnalysisTest { test("union project *") { val plan = (1 to 100) @@ -91,7 +87,7 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { a.select(UnresolvedStar(None)).select('a).unionAll(b.select(UnresolvedStar(None))) } - assert(caseInsensitiveAnalyzer.execute(plan).resolved) + assertAnalysisSuccess(plan) } test("check project's resolved") { @@ -106,61 +102,40 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { } test("analyze project") { - assert( - caseSensitiveAnalyzer.execute(Project(Seq(UnresolvedAttribute("a")), testRelation)) === - Project(testRelation.output, testRelation)) - - assert( - caseSensitiveAnalyzer.execute( - Project(Seq(UnresolvedAttribute("TbL.a")), - UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) === - Project(testRelation.output, testRelation)) - - val e = intercept[AnalysisException] { - caseSensitiveAnalyze( - Project(Seq(UnresolvedAttribute("tBl.a")), - UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) - } - assert(e.getMessage().toLowerCase.contains("cannot resolve")) - - assert( - caseInsensitiveAnalyzer.execute( - Project(Seq(UnresolvedAttribute("TbL.a")), - UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) === - Project(testRelation.output, testRelation)) - - assert( - caseInsensitiveAnalyzer.execute( - Project(Seq(UnresolvedAttribute("tBl.a")), - UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) === - Project(testRelation.output, testRelation)) + checkAnalysis( + Project(Seq(UnresolvedAttribute("a")), testRelation), + Project(testRelation.output, testRelation)) + + checkAnalysis( + Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))), + Project(testRelation.output, testRelation)) + + assertAnalysisError( + Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))), + Seq("cannot resolve")) + + checkAnalysis( + Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))), + Project(testRelation.output, testRelation), + caseSensitive = false) + + checkAnalysis( + Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))), + Project(testRelation.output, testRelation), + caseSensitive = false) } test("resolve relations") { - val e = intercept[RuntimeException] { - caseSensitiveAnalyze(UnresolvedRelation(Seq("tAbLe"), None)) - } - assert(e.getMessage == "Table Not Found: tAbLe") + assertAnalysisError(UnresolvedRelation(Seq("tAbLe"), None), Seq("Table Not Found: tAbLe")) - assert( - caseSensitiveAnalyzer.execute(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation) + checkAnalysis(UnresolvedRelation(Seq("TaBlE"), None), testRelation) - assert( - caseInsensitiveAnalyzer.execute(UnresolvedRelation(Seq("tAbLe"), None)) === testRelation) + checkAnalysis(UnresolvedRelation(Seq("tAbLe"), None), testRelation, caseSensitive = false) - assert( - caseInsensitiveAnalyzer.execute(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation) + checkAnalysis(UnresolvedRelation(Seq("TaBlE"), None), testRelation, caseSensitive = false) } - test("divide should be casted into fractional types") { - val testRelation2 = LocalRelation( - AttributeReference("a", StringType)(), - AttributeReference("b", StringType)(), - AttributeReference("c", DoubleType)(), - AttributeReference("d", DecimalType(10, 2))(), - AttributeReference("e", ShortType)()) - val plan = caseInsensitiveAnalyzer.execute( testRelation2.select( 'a / Literal(2) as 'div1, @@ -170,10 +145,21 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { 'e / 'e as 'div5)) val pl = plan.asInstanceOf[Project].projectList + // StringType will be promoted into Double assert(pl(0).dataType == DoubleType) assert(pl(1).dataType == DoubleType) assert(pl(2).dataType == DoubleType) - assert(pl(3).dataType == DoubleType) // StringType will be promoted into Double + assert(pl(3).dataType == DoubleType) assert(pl(4).dataType == DoubleType) } + + test("pull out nondeterministic expressions from unary LogicalPlan") { + val plan = RepartitionByExpression(Seq(Rand(33)), testRelation) + val projected = Alias(Rand(33), "_nondeterministic")() + val expected = + Project(testRelation.output, + RepartitionByExpression(Seq(projected.toAttribute), + Project(testRelation.output :+ projected, testRelation))) + checkAnalysis(plan, expected) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala new file mode 100644 index 0000000000000..fdb4f28950daf --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.SimpleCatalystConf +import org.apache.spark.sql.types._ + +trait AnalysisTest extends PlanTest { + val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) + + val testRelation2 = LocalRelation( + AttributeReference("a", StringType)(), + AttributeReference("b", StringType)(), + AttributeReference("c", DoubleType)(), + AttributeReference("d", DecimalType(10, 2))(), + AttributeReference("e", ShortType)()) + + val nestedRelation = LocalRelation( + AttributeReference("top", StructType( + StructField("duplicateField", StringType) :: + StructField("duplicateField", StringType) :: + StructField("differentCase", StringType) :: + StructField("differentcase", StringType) :: Nil + ))()) + + val nestedRelation2 = LocalRelation( + AttributeReference("top", StructType( + StructField("aField", StringType) :: + StructField("bField", StringType) :: + StructField("cField", StringType) :: Nil + ))()) + + val listRelation = LocalRelation( + AttributeReference("list", ArrayType(IntegerType))()) + + val (caseSensitiveAnalyzer, caseInsensitiveAnalyzer) = { + val caseSensitiveConf = new SimpleCatalystConf(true) + val caseInsensitiveConf = new SimpleCatalystConf(false) + + val caseSensitiveCatalog = new SimpleCatalog(caseSensitiveConf) + val caseInsensitiveCatalog = new SimpleCatalog(caseInsensitiveConf) + + caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) + caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) + + new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitiveConf) { + override val extendedResolutionRules = EliminateSubQueries :: Nil + } -> + new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseInsensitiveConf) { + override val extendedResolutionRules = EliminateSubQueries :: Nil + } + } + + protected def getAnalyzer(caseSensitive: Boolean) = { + if (caseSensitive) caseSensitiveAnalyzer else caseInsensitiveAnalyzer + } + + protected def checkAnalysis( + inputPlan: LogicalPlan, + expectedPlan: LogicalPlan, + caseSensitive: Boolean = true): Unit = { + val analyzer = getAnalyzer(caseSensitive) + val actualPlan = analyzer.execute(inputPlan) + analyzer.checkAnalysis(actualPlan) + comparePlans(actualPlan, expectedPlan) + } + + protected def assertAnalysisSuccess( + inputPlan: LogicalPlan, + caseSensitive: Boolean = true): Unit = { + val analyzer = getAnalyzer(caseSensitive) + analyzer.checkAnalysis(analyzer.execute(inputPlan)) + } + + protected def assertAnalysisError( + inputPlan: LogicalPlan, + expectedErrors: Seq[String], + caseSensitive: Boolean = true): Unit = { + val analyzer = getAnalyzer(caseSensitive) + // todo: make sure we throw AnalysisException during analysis + val e = intercept[Exception] { + analyzer.checkAnalysis(analyzer.execute(inputPlan)) + } + expectedErrors.forall(e.getMessage.contains) + } +}