diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 1313ccd120c1f..4f0ba8457fd0f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -17,11 +17,11 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import scala.collection.immutable.HashSet
 import org.apache.spark.sql.catalyst.analysis.UnresolvedException
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.types.BooleanType
 
-
 object InterpretedPredicate {
   def apply(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) =
     apply(BindReferences.bindReference(expression, inputSchema))
@@ -95,6 +95,23 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
   }
 }
 
+/**
+ * Optimized version of In clause, when all filter values of In clause are
+ * static.
+ */
+case class InSet(value: Expression, hset: HashSet[Any], child: Seq[Expression])
+  extends Predicate {
+
+  def children = child
+
+  def nullable = true // TODO: Figure out correct nullability semantics of IN.
+  override def toString = s"$value INSET ${hset.mkString("(", ",", ")")}"
+
+  override def eval(input: Row): Any = {
+    hset.contains(value.eval(input))
+  }
+}
+
 case class And(left: Expression, right: Expression) extends BinaryPredicate {
   def symbol = "&&"
 
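A note on the new expression above: In has to check its value list for every input row, while InSet answers each row with a single probe of a HashSet that is built once, ahead of time. The standalone sketch below is not part of the patch (the object name InVsInSetSketch and its sample values are made up for illustration); it only shows the shape of that difference using plain Scala collections.

// Illustrative sketch: a per-row linear scan over a Seq (roughly the work In does)
// versus a single probe of a prebuilt immutable HashSet (the work InSet does).
import scala.collection.immutable.HashSet

object InVsInSetSketch {
  def main(args: Array[String]): Unit = {
    val literals: Seq[Any] = (1 to 1000).toList
    val hset: HashSet[Any] = HashSet[Any]() ++ literals // built once, like OptimizedIn does

    val probe: Any = 999
    println(literals.contains(probe)) // linear scan of the list
    println(hset.contains(probe))     // hash lookup against the precomputed set
  }
}

Precomputing the set is only safe because the rewrite (in Optimizer.scala below) fires when every list element is a static Literal, so the values can be evaluated once at optimization time.
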
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index ddd4b3755d629..9500c12125258 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
+import scala.collection.immutable.HashSet
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.FullOuter
@@ -38,7 +39,8 @@ object Optimizer extends RuleExecutor[LogicalPlan] {
       BooleanSimplification,
       SimplifyFilters,
       SimplifyCasts,
-      SimplifyCaseConversionExpressions) ::
+      SimplifyCaseConversionExpressions,
+      OptimizedIn) ::
     Batch("Filter Pushdown", FixedPoint(100),
       CombineFilters,
       PushPredicateThroughProject,
@@ -225,6 +227,22 @@ object ConstantFolding extends Rule[LogicalPlan] {
   }
 }
 
+/**
+ * Replaces [[In (value, Seq[Literal])]] with optimized version [[InSet (value, HashSet[Literal])]],
+ * which is much faster.
+ */
+object OptimizedIn extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case q: LogicalPlan => q transformExpressionsDown {
+      case In(v, list) if !list.exists(!_.isInstanceOf[Literal])
+          => {
+        val hSet = list.map(e => e.eval(null))
+        InSet(v, HashSet() ++ hSet, v +: list)
+      }
+    }
+  }
+}
+
 /**
  * Simplifies boolean expressions where the answer can be determined without evaluating both sides.
  * Note that this rule can eliminate expressions that might otherwise have been evaluated and thus
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index f1df817c41362..75eef398e7267 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -19,10 +19,13 @@ package org.apache.spark.sql.catalyst.expressions
 
 import java.sql.Timestamp
 
+import scala.collection.immutable.HashSet
+
 import org.scalatest.FunSuite
 
 import org.apache.spark.sql.catalyst.types._
 
+
 /* Implicit conversions */
 import org.apache.spark.sql.catalyst.dsl.expressions._
 
@@ -136,6 +139,24 @@ class ExpressionEvaluationSuite extends FunSuite {
     checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))) && In(Literal(2), Seq(Literal(1), Literal(2))), true)
   }
 
+  test("INSET") {
+    val hS = HashSet[Any]() + 1 + 2
+    val nS = HashSet[Any]() + 1 + 2 + null
+    val one = Literal(1)
+    val two = Literal(2)
+    val three = Literal(3)
+    val nl = Literal(null)
+    val s = Seq(one, two)
+    val nullS = Seq(one, two, null)
+    checkEvaluation(InSet(one, hS, one +: s), true)
+    checkEvaluation(InSet(two, hS, two +: s), true)
+    checkEvaluation(InSet(two, nS, two +: nullS), true)
+    checkEvaluation(InSet(nl, nS, nl +: nullS), true)
+    checkEvaluation(InSet(three, hS, three +: s), false)
+    checkEvaluation(InSet(three, nS, three +: nullS), false)
+    checkEvaluation(InSet(one, hS, one +: s) && InSet(two, hS, two +: s), true)
+  }
+
   test("MaxOf") {
     checkEvaluation(MaxOf(1, 2), 2)
     checkEvaluation(MaxOf(2, 1), 2)
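One note on the null cases in the INSET test above: scala.collection.immutable.HashSet stores and finds null like any other element, which is why InSet(nl, nS, ...) is expected to evaluate to true here; whether that is the right null behaviour for IN is exactly what the TODO on InSet.nullable leaves open. A minimal standalone sketch of the HashSet behaviour the test relies on (NullProbeSketch is an illustrative name, not part of the patch):

// Illustrative sketch: the set built for INSET answers null probes with plain true/false.
import scala.collection.immutable.HashSet

object NullProbeSketch {
  def main(args: Array[String]): Unit = {
    val nS: HashSet[Any] = HashSet[Any]() + 1 + 2 + null
    println(nS.contains(2))    // true,  as in checkEvaluation(InSet(two, nS, ...), true)
    println(nS.contains(null)) // true,  as in checkEvaluation(InSet(nl, nS, ...), true)
    println(nS.contains(3))    // false, as in checkEvaluation(InSet(three, nS, ...), false)
  }
}
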
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizedInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizedInSuite.scala
new file mode 100644
index 0000000000000..b8a885f39d421
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizedInSuite.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import scala.collection.immutable.HashSet
+import org.apache.spark.sql.catalyst.analysis.{EliminateAnalysisOperators, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.catalyst.types._
+
+// For implicit conversions
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+
+class OptimizedInSuite extends PlanTest {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("AnalysisNodes", Once,
+        EliminateAnalysisOperators) ::
+      Batch("ConstantFolding", Once,
+        ConstantFolding,
+        BooleanSimplification,
+        OptimizedIn) :: Nil
+  }
+
+  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+
+  test("OptimizedIn test: In clause optimized to InSet") {
+    val originalQuery =
+      testRelation
+        .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2))))
+        .analyze
+
+    val optimized = Optimize(originalQuery.analyze)
+    val correctAnswer =
+      testRelation
+        .where(InSet(UnresolvedAttribute("a"), HashSet[Any]() + 1 + 2,
+          UnresolvedAttribute("a") +: Seq(Literal(1), Literal(2))))
+        .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("OptimizedIn test: In clause not optimized in case filter has attributes") {
+    val originalQuery =
+      testRelation
+        .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b"))))
+        .analyze
+
+    val optimized = Optimize(originalQuery.analyze)
+    val correctAnswer =
+      testRelation
+        .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b"))))
+        .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+}
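
The two tests in OptimizedInSuite exercise the guard of OptimizedIn, !list.exists(!_.isInstanceOf[Literal]): the rewrite to InSet fires only when every element of the IN list is a Literal, so adding UnresolvedAttribute("b") leaves the In expression untouched. Below is a standalone sketch of that predicate; Expr, Lit and Attr are illustrative stand-ins for Catalyst's Expression, Literal and UnresolvedAttribute, not classes from the patch.

// Illustrative sketch: the all-literals guard used by OptimizedIn, on a toy expression type.
sealed trait Expr
case class Lit(value: Int) extends Expr
case class Attr(name: String) extends Expr

object GuardSketch {
  // Same shape as the guard in the rule: no element that is not a literal.
  def allLiterals(list: Seq[Expr]): Boolean = !list.exists(!_.isInstanceOf[Lit])

  def main(args: Array[String]): Unit = {
    println(allLiterals(Seq(Lit(1), Lit(2))))    // true  -> In would be rewritten to InSet
    println(allLiterals(Seq(Lit(1), Attr("b")))) // false -> In is left as-is
    // An equivalent, arguably clearer, formulation of the same guard:
    println(Seq[Expr](Lit(1), Lit(2)).forall(_.isInstanceOf[Lit])) // true
  }
}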