diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala new file mode 100644 index 000000000000..e7380d21f98a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import scala.collection.mutable + +/** + * This class is used to compute equality of (sub)expression trees. Expressions can be added + * to this class and they subsequently query for expression equality. Expression trees are + * considered equal if for the same input(s), the same result is produced. + */ +class EquivalentExpressions { + /** + * Wrapper around an Expression that provides semantic equality. + */ + case class Expr(e: Expression) { + val hash = e.semanticHash() + override def equals(o: Any): Boolean = o match { + case other: Expr => e.semanticEquals(other.e) + case _ => false + } + override def hashCode: Int = hash + } + + // For each expression, the set of equivalent expressions. + private val equivalenceMap: mutable.HashMap[Expr, mutable.MutableList[Expression]] = + new mutable.HashMap[Expr, mutable.MutableList[Expression]] + + /** + * Adds each expression to this data structure, grouping them with existing equivalent + * expressions. Non-recursive. + * Returns if there was already a matching expression. + */ + def addExpr(expr: Expression): Boolean = { + if (expr.deterministic) { + val e: Expr = Expr(expr) + val f = equivalenceMap.get(e) + if (f.isDefined) { + f.get.+= (expr) + true + } else { + equivalenceMap.put(e, mutable.MutableList(expr)) + false + } + } else { + false + } + } + + /** + * Adds the expression to this datastructure recursively. Stops if a matching expression + * is found. That is, if `expr` has already been added, its children are not added. + * If ignoreLeaf is true, leaf nodes are ignored. + */ + def addExprTree(root: Expression, ignoreLeaf: Boolean = true): Unit = { + val skip = root.isInstanceOf[LeafExpression] && ignoreLeaf + if (!skip && root.deterministic && !addExpr(root)) { + root.children.foreach(addExprTree(_, ignoreLeaf)) + } + } + + /** + * Returns all fo the expression trees that are equivalent to `e`. Returns + * an empty collection if there are none. + */ + def getEquivalentExprs(e: Expression): Seq[Expression] = { + equivalenceMap.get(Expr(e)).getOrElse(mutable.MutableList()) + } + + /** + * Returns all the equivalent sets of expressions. + */ + def getAllEquivalentExprs: Seq[Seq[Expression]] = { + equivalenceMap.values.map(_.toSeq).toSeq + } + + /** + * Returns the state of the datastructure as a string. If all is false, skips sets of equivalent + * expressions with cardinality 1. + */ + def debugString(all: Boolean = false): String = { + val sb: mutable.StringBuilder = new StringBuilder() + sb.append("Equivalent expressions:\n") + equivalenceMap.foreach { case (k, v) => { + if (all || v.length > 1) { + sb.append(" " + v.mkString(", ")).append("\n") + } + }} + sb.toString() + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 96fcc799e537..7d5741eefcc7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -92,12 +92,24 @@ abstract class Expression extends TreeNode[Expression] { * @return [[GeneratedExpressionCode]] */ def gen(ctx: CodeGenContext): GeneratedExpressionCode = { - val isNull = ctx.freshName("isNull") - val primitive = ctx.freshName("primitive") - val ve = GeneratedExpressionCode("", isNull, primitive) - ve.code = genCode(ctx, ve) - // Add `this` in the comment. - ve.copy(s"/* $this */\n" + ve.code) + val subExprState = ctx.subExprEliminationExprs.get(this) + if (subExprState.isDefined) { + // This expression is repeated meaning the code to evaluated has already been added + // as a function, `subExprState.fnName`. Just call that. + val code = + s""" + |/* $this */ + |${subExprState.get.fnName}(${ctx.INPUT_ROW}); + |""".stripMargin.trim + GeneratedExpressionCode(code, subExprState.get.code.isNull, subExprState.get.code.value) + } else { + val isNull = ctx.freshName("isNull") + val primitive = ctx.freshName("primitive") + val ve = GeneratedExpressionCode("", isNull, primitive) + ve.code = genCode(ctx, ve) + // Add `this` in the comment. + ve.copy(s"/* $this */\n" + ve.code.trim) + } } /** @@ -145,11 +157,37 @@ abstract class Expression extends TreeNode[Expression] { case (i1, i2) => i1 == i2 } } + // Non-determinstic expressions cannot be equal + if (!deterministic || !other.deterministic) return false val elements1 = this.productIterator.toSeq val elements2 = other.asInstanceOf[Product].productIterator.toSeq checkSemantic(elements1, elements2) } + /** + * Returns the hash for this expression. Expressions that compute the same result, even if + * they differ cosmetically should return the same hash. + */ + def semanticHash() : Int = { + def computeHash(e: Seq[Any]): Int = { + // See http://stackoverflow.com/questions/113511/hash-code-implementation + var hash: Int = 17 + e.foreach(i => { + val h: Int = i match { + case (e: Expression) => e.semanticHash() + case (Some(e: Expression)) => e.semanticHash() + case (t: Traversable[_]) => computeHash(t.toSeq) + case null => 0 + case (o) => o.hashCode() + } + hash = hash * 37 + h + }) + hash + } + + computeHash(this.productIterator.toSeq) + } + /** * Checks the input data types, returns `TypeCheckResult.success` if it's valid, * or returns a `TypeCheckResult` with an error message if invalid. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 79dabe8e925a..9f0b7821ae74 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -144,6 +144,22 @@ object UnsafeProjection { def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = { create(exprs.map(BindReferences.bindReference(_, inputSchema))) } + + /** + * Same as other create()'s but allowing enabling/disabling subexpression elimination. + * TODO: refactor the plumbing and clean this up. + */ + def create( + exprs: Seq[Expression], + inputSchema: Seq[Attribute], + subexpressionEliminationEnabled: Boolean): UnsafeProjection = { + val e = exprs.map(BindReferences.bindReference(_, inputSchema)) + .map(_ transform { + case CreateStruct(children) => CreateStructUnsafe(children) + case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) + }) + GenerateUnsafeProjection.generate(e, subexpressionEliminationEnabled) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index f0f7a6cf0cc4..60a3d6018496 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -92,6 +92,33 @@ class CodeGenContext { addedFunctions += ((funcName, funcCode)) } + /** + * Holds expressions that are equivalent. Used to perform subexpression elimination + * during codegen. + * + * For expressions that appear more than once, generate additional code to prevent + * recomputing the value. + * + * For example, consider two exprsesion generated from this SQL statement: + * SELECT (col1 + col2), (col1 + col2) / col3. + * + * equivalentExpressions will match the tree containing `col1 + col2` and it will only + * be evaluated once. + */ + val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions + + // State used for subexpression elimination. + case class SubExprEliminationState( + val isLoaded: String, code: GeneratedExpressionCode, val fnName: String) + + // Foreach expression that is participating in subexpression elimination, the state to use. + val subExprEliminationExprs: mutable.HashMap[Expression, SubExprEliminationState] = + mutable.HashMap[Expression, SubExprEliminationState]() + + // The collection of isLoaded variables that need to be reset on each row. + val subExprIsLoadedVariables: mutable.ArrayBuffer[String] = + mutable.ArrayBuffer.empty[String] + final val JAVA_BOOLEAN = "boolean" final val JAVA_BYTE = "byte" final val JAVA_SHORT = "short" @@ -317,6 +344,87 @@ class CodeGenContext { functions.map(name => s"$name($row);").mkString("\n") } } + + /** + * Checks and sets up the state and codegen for subexpression elimination. This finds the + * common subexpresses, generates the functions that evaluate those expressions and populates + * the mapping of common subexpressions to the generated functions. + */ + private def subexpressionElimination(expressions: Seq[Expression]) = { + // Add each expression tree and compute the common subexpressions. + expressions.foreach(equivalentExpressions.addExprTree(_)) + + // Get all the exprs that appear at least twice and set up the state for subexpression + // elimination. + val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1) + commonExprs.foreach(e => { + val expr = e.head + val isLoaded = freshName("isLoaded") + val isNull = freshName("isNull") + val primitive = freshName("primitive") + val fnName = freshName("evalExpr") + + // Generate the code for this expression tree and wrap it in a function. + val code = expr.gen(this) + val fn = + s""" + |private void $fnName(InternalRow ${INPUT_ROW}) { + | if (!$isLoaded) { + | ${code.code.trim} + | $isLoaded = true; + | $isNull = ${code.isNull}; + | $primitive = ${code.value}; + | } + |} + """.stripMargin + code.code = fn + code.isNull = isNull + code.value = primitive + + addNewFunction(fnName, fn) + + // Add a state and a mapping of the common subexpressions that are associate with this + // state. Adding this expression to subExprEliminationExprMap means it will call `fn` + // when it is code generated. This decision should be a cost based one. + // + // The cost of doing subexpression elimination is: + // 1. Extra function call, although this is probably *good* as the JIT can decide to + // inline or not. + // 2. Extra branch to check isLoaded. This branch is likely to be predicted correctly + // very often. The reason it is not loaded is because of a prior branch. + // 3. Extra store into isLoaded. + // The benefit doing subexpression elimination is: + // 1. Running the expression logic. Even for a simple expression, it is likely more than 3 + // above. + // 2. Less code. + // Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with + // at least two nodes) as the cost of doing it is expected to be low. + + // Maintain the loaded value and isNull as member variables. This is necessary if the codegen + // function is split across multiple functions. + // TODO: maintaining this as a local variable probably allows the compiler to do better + // optimizations. + addMutableState("boolean", isLoaded, s"$isLoaded = false;") + addMutableState("boolean", isNull, s"$isNull = false;") + addMutableState(javaType(expr.dataType), primitive, + s"$primitive = ${defaultValue(expr.dataType)};") + subExprIsLoadedVariables += isLoaded + + val state = SubExprEliminationState(isLoaded, code, fnName) + e.foreach(subExprEliminationExprs.put(_, state)) + }) + } + + /** + * Generates code for expressions. If doSubexpressionElimination is true, subexpression + * elimination will be performed. Subexpression elimination assumes that the code will for each + * expression will be combined in the `expressions` order. + */ + def generateExpressions(expressions: Seq[Expression], + doSubexpressionElimination: Boolean = false): Seq[GeneratedExpressionCode] = { + if (doSubexpressionElimination) subexpressionElimination(expressions) + expressions.map(e => e.gen(this)) + } } /** @@ -349,7 +457,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin } protected def declareAddedFunctions(ctx: CodeGenContext): String = { - ctx.addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n") + ctx.addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n").trim } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 2136f82ba475..9ef226141421 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -139,9 +139,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s""" ${input.code} if (${input.isNull}) { - $setNull + ${setNull.trim} } else { - $writeField + ${writeField.trim} } """ } @@ -149,7 +149,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s""" $rowWriter.initialize($bufferHolder, ${inputs.length}); ${ctx.splitExpressions(row, writeFields)} - """ + """.trim } // TODO: if the nullability of array element is correct, we can use it to save null check. @@ -275,8 +275,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro """ } - def createCode(ctx: CodeGenContext, expressions: Seq[Expression]): GeneratedExpressionCode = { - val exprEvals = expressions.map(e => e.gen(ctx)) + def createCode( + ctx: CodeGenContext, + expressions: Seq[Expression], + useSubexprElimination: Boolean = false): GeneratedExpressionCode = { + val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination) val exprTypes = expressions.map(_.dataType) val result = ctx.freshName("result") @@ -285,10 +288,15 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val holderClass = classOf[BufferHolder].getName ctx.addMutableState(holderClass, bufferHolder, s"this.$bufferHolder = new $holderClass();") + // Reset the isLoaded flag for each row. + val subexprReset = ctx.subExprIsLoadedVariables.map { v => s"${v} = false;" }.mkString("\n") + val code = s""" $bufferHolder.reset(); + $subexprReset ${writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, bufferHolder)} + $result.pointTo($bufferHolder.buffer, ${expressions.length}, $bufferHolder.totalSize()); """ GeneratedExpressionCode(code, "false", result) @@ -300,10 +308,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = in.map(BindReferences.bindReference(_, inputSchema)) + def generate( + expressions: Seq[Expression], + subexpressionEliminationEnabled: Boolean): UnsafeProjection = { + create(canonicalize(expressions), subexpressionEliminationEnabled) + } + protected def create(expressions: Seq[Expression]): UnsafeProjection = { - val ctx = newCodeGenContext() + create(expressions, false) + } - val eval = createCode(ctx, expressions) + private def create( + expressions: Seq[Expression], + subexpressionEliminationEnabled: Boolean): UnsafeProjection = { + val ctx = newCodeGenContext() + val eval = createCode(ctx, expressions, subexpressionEliminationEnabled) val code = s""" public Object generate($exprType[] exprs) { @@ -315,6 +334,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro private $exprType[] expressions; ${declareMutableStates(ctx)} + ${declareAddedFunctions(ctx)} public SpecificUnsafeProjection($exprType[] expressions) { @@ -328,7 +348,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) { - ${eval.code} + ${eval.code.trim} return ${eval.value}; } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 9ab5c299d0f5..f80bcfcb0b0b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -203,6 +203,10 @@ case class AttributeReference( case _ => false } + override def semanticHash(): Int = { + this.exprId.hashCode() + } + override def hashCode: Int = { // See http://stackoverflow.com/questions/113511/hash-code-implementation var h = 17 diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala new file mode 100644 index 000000000000..9de066e99d63 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.IntegerType + +class SubexpressionEliminationSuite extends SparkFunSuite { + test("Semantic equals and hash") { + val id = ExprId(1) + val a: AttributeReference = AttributeReference("name", IntegerType)() + val b1 = a.withName("name2").withExprId(id) + val b2 = a.withExprId(id) + + assert(b1 != b2) + assert(a != b1) + assert(b1.semanticEquals(b2)) + assert(!b1.semanticEquals(a)) + assert(a.hashCode != b1.hashCode) + assert(b1.hashCode == b2.hashCode) + assert(b1.semanticHash() == b2.semanticHash()) + } + + test("Expression Equivalence - basic") { + val equivalence = new EquivalentExpressions + assert(equivalence.getAllEquivalentExprs.isEmpty) + + val oneA = Literal(1) + val oneB = Literal(1) + val twoA = Literal(2) + var twoB = Literal(2) + + assert(equivalence.getEquivalentExprs(oneA).isEmpty) + assert(equivalence.getEquivalentExprs(twoA).isEmpty) + + // Add oneA and test if it is returned. Since it is a group of one, it does not. + assert(!equivalence.addExpr(oneA)) + assert(equivalence.getEquivalentExprs(oneA).size == 1) + assert(equivalence.getEquivalentExprs(twoA).isEmpty) + assert(equivalence.addExpr((oneA))) + assert(equivalence.getEquivalentExprs(oneA).size == 2) + + // Add B and make sure they can see each other. + assert(equivalence.addExpr(oneB)) + // Use exists and reference equality because of how equals is defined. + assert(equivalence.getEquivalentExprs(oneA).exists(_ eq oneB)) + assert(equivalence.getEquivalentExprs(oneA).exists(_ eq oneA)) + assert(equivalence.getEquivalentExprs(oneB).exists(_ eq oneA)) + assert(equivalence.getEquivalentExprs(oneB).exists(_ eq oneB)) + assert(equivalence.getEquivalentExprs(twoA).isEmpty) + assert(equivalence.getAllEquivalentExprs.size == 1) + assert(equivalence.getAllEquivalentExprs.head.size == 3) + assert(equivalence.getAllEquivalentExprs.head.contains(oneA)) + assert(equivalence.getAllEquivalentExprs.head.contains(oneB)) + + val add1 = Add(oneA, oneB) + val add2 = Add(oneA, oneB) + + equivalence.addExpr(add1) + equivalence.addExpr(add2) + + assert(equivalence.getAllEquivalentExprs.size == 2) + assert(equivalence.getEquivalentExprs(add2).exists(_ eq add1)) + assert(equivalence.getEquivalentExprs(add2).size == 2) + assert(equivalence.getEquivalentExprs(add1).exists(_ eq add2)) + } + + test("Expression Equivalence - Trees") { + val one = Literal(1) + val two = Literal(2) + + val add = Add(one, two) + val abs = Abs(add) + val add2 = Add(add, add) + + var equivalence = new EquivalentExpressions + equivalence.addExprTree(add, true) + equivalence.addExprTree(abs, true) + equivalence.addExprTree(add2, true) + + // Should only have one equivalence for `one + two` + assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).size == 1) + assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).head.size == 4) + + // Set up the expressions + // one * two, + // (one * two) * (one * two) + // sqrt( (one * two) * (one * two) ) + // (one * two) + sqrt( (one * two) * (one * two) ) + equivalence = new EquivalentExpressions + val mul = Multiply(one, two) + val mul2 = Multiply(mul, mul) + val sqrt = Sqrt(mul2) + val sum = Add(mul2, sqrt) + equivalence.addExprTree(mul, true) + equivalence.addExprTree(mul2, true) + equivalence.addExprTree(sqrt, true) + equivalence.addExprTree(sum, true) + + // (one * two), (one * two) * (one * two) and sqrt( (one * two) * (one * two) ) should be found + assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).size == 3) + assert(equivalence.getEquivalentExprs(mul).size == 3) + assert(equivalence.getEquivalentExprs(mul2).size == 3) + assert(equivalence.getEquivalentExprs(sqrt).size == 2) + assert(equivalence.getEquivalentExprs(sum).size == 1) + + // Some expressions inspired by TPCH-Q1 + // sum(l_quantity) as sum_qty, + // sum(l_extendedprice) as sum_base_price, + // sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, + // sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, + // avg(l_extendedprice) as avg_price, + // avg(l_discount) as avg_disc + equivalence = new EquivalentExpressions + val quantity = Literal(1) + val price = Literal(1.1) + val discount = Literal(.24) + val tax = Literal(0.1) + equivalence.addExprTree(quantity, false) + equivalence.addExprTree(price, false) + equivalence.addExprTree(Multiply(price, Subtract(Literal(1), discount)), false) + equivalence.addExprTree( + Multiply( + Multiply(price, Subtract(Literal(1), discount)), + Add(Literal(1), tax)), false) + equivalence.addExprTree(price, false) + equivalence.addExprTree(discount, false) + // quantity, price, discount and (price * (1 - discount)) + assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).size == 4) + } + + test("Expression equivalence - non deterministic") { + val sum = Add(Rand(0), Rand(0)) + val equivalence = new EquivalentExpressions + equivalence.addExpr(sum) + equivalence.addExpr(sum) + assert(equivalence.getAllEquivalentExprs.isEmpty) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index ed8b634ad563..f1aa9dcd6abe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -268,6 +268,11 @@ private[spark] object SQLConf { doc = "When true, use the new optimized Tungsten physical execution backend.", isPublic = false) + val SUBEXPRESSION_ELIMINATION_ENABLED = booleanConf("spark.sql.subexpressionElimination.enabled", + defaultValue = Some(true), // use CODEGEN_ENABLED as default + doc = "When true, common subexpressions will be eliminated.", + isPublic = false) + val DIALECT = stringConf( "spark.sql.dialect", defaultValue = Some("sql"), @@ -532,6 +537,9 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED, getConf(TUNGSTEN_ENABLED)) + private[spark] def subexpressionEliminationEnabled: Boolean = + getConf(SUBEXPRESSION_ELIMINATION_ENABLED, codegenEnabled) + private[spark] def useSqlAggregate2: Boolean = getConf(USE_SQL_AGGREGATE2) private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 8bb293ae87e6..8650ac500b65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -66,6 +66,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } else { false } + val subexpressionEliminationEnabled: Boolean = if (sqlContext != null) { + sqlContext.conf.subexpressionEliminationEnabled + } else { + false + } /** * Whether the "prepare" method is called. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 799650a4f784..152eaa1496b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -70,7 +70,8 @@ case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan) protected override def doExecute(): RDD[InternalRow] = { val numRows = longMetric("numRows") child.execute().mapPartitions { iter => - val project = UnsafeProjection.create(projectList, child.output) + val project = UnsafeProjection.create(projectList, child.output, + subexpressionEliminationEnabled) iter.map { row => numRows += 1 project(row) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 3de277a79a52..012089e43fd8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2001,4 +2001,52 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) } } + + test("Common subexpression elimination") { + // select from a table to prevent constant folding. + val df = sql("SELECT a, b from testData2 limit 1") + checkAnswer(df, Row(1, 1)) + + checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2)) + checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3)) + + // This does not work because the expressions get grouped like (a + a) + 1 + checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3)) + checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3)) + + // Identity udf that tracks the number of times it is called. + val countAcc = sparkContext.accumulator(0, "CallCount") + sqlContext.udf.register("testUdf", (x: Int) => { + countAcc.++=(1) + x + }) + + // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value + // is correct. + def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = { + countAcc.setValue(0) + checkAnswer(df, expectedResult) + assert(countAcc.value == expectedCount) + } + + verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1) + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) + verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1) + verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2) + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1) + + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2) + + // Would be nice if semantic equals for `+` understood commutative + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2) + + // Try disabling it via configuration. + sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false") + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2) + sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true") + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) + } }