Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@

package org.apache.spark.sql.catalyst.expressions

import scala.collection.immutable.HashSet
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.types.BooleanType


object InterpretedPredicate {
def apply(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) =
apply(BindReferences.bindReference(expression, inputSchema))
Expand Down Expand Up @@ -95,6 +95,23 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
}
}

/**
* Optimized version of In clause, when all filter values of In clause are
* static.
*/
case class InSet(value: Expression, hset: HashSet[Any], child: Seq[Expression])
extends Predicate {

def children = child

def nullable = true // TODO: Figure out correct nullability semantics of IN.
override def toString = s"$value INSET ${hset.mkString("(", ",", ")")}"

override def eval(input: Row): Any = {
hset.contains(value.eval(input))
}
}

case class And(left: Expression, right: Expression) extends BinaryPredicate {
def symbol = "&&"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.optimizer

import scala.collection.immutable.HashSet
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.FullOuter
Expand All @@ -38,7 +39,8 @@ object Optimizer extends RuleExecutor[LogicalPlan] {
BooleanSimplification,
SimplifyFilters,
SimplifyCasts,
SimplifyCaseConversionExpressions) ::
SimplifyCaseConversionExpressions,
OptimizedIn) ::
Batch("Filter Pushdown", FixedPoint(100),
CombineFilters,
PushPredicateThroughProject,
Expand Down Expand Up @@ -225,6 +227,22 @@ object ConstantFolding extends Rule[LogicalPlan] {
}
}

/**
* Replaces [[In (value, seq[Literal])]] with optimized version[[InSet (value, HashSet[Literal])]]
* which is much faster
*/
object OptimizedIn extends Rule[LogicalPlan] {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably be OptimizeIn.

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsDown {
case In(v, list) if !list.exists(!_.isInstanceOf[Literal])
=> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When pattern matching put the => on the same line and omit the {}.

val hSet = list.map(e => e.eval(null))
InSet(v, HashSet() ++ hSet, v +: list)
}
}
}
}

/**
* Simplifies boolean expressions where the answer can be determined without evaluating both sides.
* Note that this rule can eliminate expressions that might otherwise have been evaluated and thus
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@ package org.apache.spark.sql.catalyst.expressions

import java.sql.Timestamp

import scala.collection.immutable.HashSet

import org.scalatest.FunSuite

import org.apache.spark.sql.catalyst.types._


/* Implicit conversions */
import org.apache.spark.sql.catalyst.dsl.expressions._

Expand Down Expand Up @@ -136,6 +139,24 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))) && In(Literal(2), Seq(Literal(1), Literal(2))), true)
}

test("INSET") {
val hS = HashSet[Any]() + 1 + 2
val nS = HashSet[Any]() + 1 + 2 + null
val one = Literal(1)
val two = Literal(2)
val three = Literal(3)
val nl = Literal(null)
val s = Seq(one, two)
val nullS = Seq(one, two, null)
checkEvaluation(InSet(one, hS, one +: s), true)
checkEvaluation(InSet(two, hS, two +: s), true)
checkEvaluation(InSet(two, nS, two +: nullS), true)
checkEvaluation(InSet(nl, nS, nl +: nullS), true)
checkEvaluation(InSet(three, hS, three +: s), false)
checkEvaluation(InSet(three, nS, three +: nullS), false)
checkEvaluation(InSet(one, hS, one +: s) && InSet(two, hS, two +: s), true)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should also double check that there isn't a problem if a null is in the in list. I'm not sure if thats actually valid SQL (and it should never change the result), but we shouldn't throw an exception.

}

test("MaxOf") {
checkEvaluation(MaxOf(1, 2), 2)
checkEvaluation(MaxOf(2, 1), 2)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import scala.collection.immutable.HashSet
import org.apache.spark.sql.catalyst.analysis.{EliminateAnalysisOperators, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.types._

// For implicit conversions
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._

class OptimizedInSuite extends PlanTest {

object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("AnalysisNodes", Once,
EliminateAnalysisOperators) ::
Batch("ConstantFolding", Once,
ConstantFolding,
BooleanSimplification,
OptimizedIn) :: Nil
}

val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

test("OptimizedIn test: In clause optimized to InSet") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps a test to make sure that in clauses with attributes are not corrupted.

val originalQuery =
testRelation
.where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2))))
.analyze

val optimized = Optimize(originalQuery.analyze)
val correctAnswer =
testRelation
.where(InSet(UnresolvedAttribute("a"), HashSet[Any]()+1+2,
UnresolvedAttribute("a") +: Seq(Literal(1),Literal(2))))
.analyze

comparePlans(optimized, correctAnswer)
}

test("OptimizedIn test: In clause not optimized in case filter has attributes") {
val originalQuery =
testRelation
.where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2), UnresolvedAttribute("b"))))
.analyze

val optimized = Optimize(originalQuery.analyze)
val correctAnswer =
testRelation
.where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2), UnresolvedAttribute("b"))))
.analyze

comparePlans(optimized, correctAnswer)
}
}