From ffe4ca8b41416d82677dae72e3e74faeb0f721f8 Mon Sep 17 00:00:00 2001 From: Kun Wan Date: Fri, 24 Feb 2023 14:49:25 +0800 Subject: [PATCH 01/10] [SPARK-42551][SQL] Support more subexpression elimination cases --- .../expressions/EquivalentExpressions.scala | 113 ++----- .../sql/catalyst/expressions/Expression.scala | 38 ++- .../expressions/codegen/CodeGenerator.scala | 280 ++++-------------- .../codegen/GenerateMutableProjection.scala | 3 +- .../codegen/GeneratePredicate.scala | 4 +- .../codegen/GenerateUnsafeProjection.scala | 3 +- .../expressions/CodeGenerationSuite.scala | 61 ---- .../SubexpressionEliminationSuite.scala | 140 ++------- .../spark/sql/execution/ExpandExec.scala | 2 + .../sql/execution/WholeStageCodegenExec.scala | 3 + .../aggregate/AggregateCodegenSupport.scala | 109 +++---- .../aggregate/HashAggregateExec.scala | 25 +- .../execution/basicPhysicalOperators.scala | 18 +- .../execution/joins/JoinCodegenSupport.scala | 6 +- .../execution/WholeStageCodegenSuite.scala | 22 -- 15 files changed, 209 insertions(+), 618 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index 1a84859cc3a15..dfb5d95238c9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -21,9 +21,8 @@ import java.util.Objects import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.EquivalentExpressions.supportedExpression import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils /** @@ -31,9 +30,7 @@ import org.apache.spark.util.Utils * to this class and they 
subsequently query for expression equality. Expression trees are * considered equal if for the same input(s), the same result is produced. */ -class EquivalentExpressions( - skipForShortcutEnable: Boolean = SQLConf.get.subexpressionEliminationSkipForShotcutExpr) { - +class EquivalentExpressions { // For each expression, the set of equivalent expressions. private val equivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats] @@ -91,92 +88,6 @@ class EquivalentExpressions( } } - /** - * Adds or removes only expressions which are common in each of given expressions, in a recursive - * way. - * For example, given two expressions `(a + (b + (c + 1)))` and `(d + (e + (c + 1)))`, the common - * expression `(c + 1)` will be added into `equivalenceMap`. - * - * Note that as we don't know in advance if any child node of an expression will be common across - * all given expressions, we compute local equivalence maps for all given expressions and filter - * only the common nodes. - * Those common nodes are then removed from the local map and added to the final map of - * expressions. - */ - private def updateCommonExprs( - exprs: Seq[Expression], - map: mutable.HashMap[ExpressionEquals, ExpressionStats], - useCount: Int): Unit = { - assert(exprs.length > 1) - var localEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats] - updateExprTree(exprs.head, localEquivalenceMap) - - exprs.tail.foreach { expr => - val otherLocalEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats] - updateExprTree(expr, otherLocalEquivalenceMap) - localEquivalenceMap = localEquivalenceMap.filter { case (key, _) => - otherLocalEquivalenceMap.contains(key) - } - } - - // Start with the highest expression, remove it from `localEquivalenceMap` and add it to `map`. - // The remaining highest expression in `localEquivalenceMap` is also common expression so loop - // until `localEquivalenceMap` is not empty. 
- var statsOption = Some(localEquivalenceMap).filter(_.nonEmpty).map(_.maxBy(_._1.height)._2) - while (statsOption.nonEmpty) { - val stats = statsOption.get - updateExprTree(stats.expr, localEquivalenceMap, -stats.useCount) - updateExprTree(stats.expr, map, useCount) - - statsOption = Some(localEquivalenceMap).filter(_.nonEmpty).map(_.maxBy(_._1.height)._2) - } - } - - private def skipForShortcut(expr: Expression): Expression = { - if (skipForShortcutEnable) { - // The subexpression may not need to eval even if it appears more than once. - // e.g., `if(or(a, and(b, b)))`, the expression `b` would be skipped if `a` is true. - expr match { - case and: And => and.left - case or: Or => or.left - case other => other - } - } else { - expr - } - } - - // There are some special expressions that we should not recurse into all of its children. - // 1. CodegenFallback: it's children will not be used to generate code (call eval() instead) - // 2. ConditionalExpression: use its children that will always be evaluated. - private def childrenToRecurse(expr: Expression): Seq[Expression] = expr match { - case _: CodegenFallback => Nil - case c: ConditionalExpression => c.alwaysEvaluatedInputs.map(skipForShortcut) - case other => skipForShortcut(other).children - } - - // For some special expressions we cannot just recurse into all of its children, but we can - // recursively add the common expressions shared between all of its children. - private def commonChildrenToRecurse(expr: Expression): Seq[Seq[Expression]] = expr match { - case _: CodegenFallback => Nil - case c: ConditionalExpression => c.branchGroups - case _ => Nil - } - - private def supportedExpression(e: Expression) = { - !e.exists { - // `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the - // loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning. - case _: LambdaVariable => true - - // `PlanExpression` wraps query plan. 
To compare query plans of `PlanExpression` on executor, - // can cause error like NPE. - case _: PlanExpression[_] => Utils.isInRunningSparkTask - - case _ => false - } - } - /** * Adds the expression to this data structure recursively. Stops if a matching expression * is found. That is, if `expr` has already been added, its children are not added. @@ -197,8 +108,7 @@ class EquivalentExpressions( if (!skip && !updateExprInMap(expr, map, useCount)) { val uc = useCount.signum - childrenToRecurse(expr).foreach(updateExprTree(_, map, uc)) - commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(updateCommonExprs(_, map, uc)) + expr.children.foreach(updateExprTree(_, map, uc)) } } @@ -240,6 +150,23 @@ class EquivalentExpressions( } } +object EquivalentExpressions { + def supportedExpression(e: Expression): Boolean = { + !e.exists { + // `LambdaVariable` is usually used as a loop variable and `NamedLambdaVariable` is used in + // higher-order functions, which can't be evaluated ahead of the execution. + case _: LambdaVariable => true + case _: NamedLambdaVariable => true + + // `PlanExpression` wraps query plan. To compare query plans of `PlanExpression` on executor, + // can cause error like NPE. + case _: PlanExpression[_] => Utils.isInRunningSparkTask + + case _ => false + } + } +} + /** * Wrapper around an Expression that provides semantic equality. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index c2330cdb59dbc..1f080182d17eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -196,9 +196,41 @@ abstract class Expression extends TreeNode[Expression] { }.getOrElse { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") - val eval = doGenCode(ctx, ExprCode( - JavaCode.isNullVariable(isNull), - JavaCode.variable(value, dataType))) + val exprKey = ExpressionEquals(this) + val eval = if (EquivalentExpressions.supportedExpression(this)) { + ctx.commonExpressions.get(exprKey) match { + case Some((useCount, genFunc, Some(reuseExprCode))) => + ctx.commonExpressions -= exprKey + if (useCount <= 1) { + ctx.commonExpressions -= exprKey + } else { + ctx.commonExpressions += exprKey -> + (useCount - 1, genFunc, Some(reuseExprCode)) + } + reuseExprCode + case Some((useCount, genFunc, None)) => + val eval = doGenCode(ctx, ExprCode( + JavaCode.isNullVariable(isNull), + JavaCode.variable(value, dataType))) + val reuseExprCode = genFunc(eval) + ctx.commonExpressions -= exprKey + if (useCount <= 1) { + ctx.commonExpressions -= exprKey + } else { + ctx.commonExpressions += exprKey -> + (useCount - 1, genFunc, Some(reuseExprCode)) + } + reuseExprCode + case None => + doGenCode(ctx, ExprCode( + JavaCode.isNullVariable(isNull), + JavaCode.variable(value, dataType))) + } + } else { + doGenCode(ctx, ExprCode( + JavaCode.isNullVariable(isNull), + JavaCode.variable(value, dataType))) + } reduceCodeSize(ctx, eval) if (eval.code.toString.nonEmpty) { // Add `this` in the comment. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 5651a30515f28..1959d405ae92c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1027,131 +1027,24 @@ class CodegenContext extends Logging { splitExpressions(subexprFunctions.toSeq, "subexprFunc_split", Seq("InternalRow" -> INPUT_ROW)) } - /** - * Perform a function which generates a sequence of ExprCodes with a given mapping between - * expressions and common expressions, instead of using the mapping in current context. - */ - def withSubExprEliminationExprs( - newSubExprEliminationExprs: Map[ExpressionEquals, SubExprEliminationState])( - f: => Seq[ExprCode]): Seq[ExprCode] = { - val oldsubExprEliminationExprs = subExprEliminationExprs - subExprEliminationExprs = newSubExprEliminationExprs - - val genCodes = f - - // Restore previous subExprEliminationExprs - subExprEliminationExprs = oldsubExprEliminationExprs - genCodes - } - - /** - * Evaluates a sequence of `SubExprEliminationState` which represent subexpressions. After - * evaluating a subexpression, this method will clean up the code block to avoid duplicate - * evaluation. - */ - def evaluateSubExprEliminationState(subExprStates: Iterable[SubExprEliminationState]): String = { - val code = new StringBuilder() - - subExprStates.foreach { state => - val currentCode = evaluateSubExprEliminationState(state.children) + "\n" + state.eval.code - code.append(currentCode + "\n") - state.eval.code = EmptyBlock - } - - code.toString() - } - - /** - * Checks and sets up the state and codegen for subexpression elimination in whole-stage codegen. 
- * - * This finds the common subexpressions, generates the code snippets that evaluate those - * expressions and populates the mapping of common subexpressions to the generated code snippets. - * - * The generated code snippet for subexpression is wrapped in `SubExprEliminationState`, which - * contains an `ExprCode` and the children `SubExprEliminationState` if any. The `ExprCode` - * includes java source code, result variable name and is-null variable name of the subexpression. - * - * Besides, this also returns a sequences of `ExprCode` which are expression codes that need to - * be evaluated (as their input parameters) before evaluating subexpressions. - * - * To evaluate the returned subexpressions, please call `evaluateSubExprEliminationState` with - * the `SubExprEliminationState`s to be evaluated. During generating the code, it will cleanup - * the states to avoid duplicate evaluation. - * - * The details of subexpression generation: - * 1. Gets subexpression set. See `EquivalentExpressions`. - * 2. Generate code of subexpressions as a whole block of code (non-split case) - * 3. Check if the total length of the above block is larger than the split-threshold. If so, - * try to split it in step 4, otherwise returning the non-split code block. - * 4. Check if parameter lengths of all subexpressions satisfy the JVM limitation, if so, - * try to split, otherwise returning the non-split code block. - * 5. For each subexpression, generating a function and put the code into it. To evaluate the - * subexpression, just call the function. - * - * The explanation of subexpression codegen: - * 1. Wrapping in `withSubExprEliminationExprs` call with current subexpression map. Each - * subexpression may depends on other subexpressions (children). So when generating code - * for subexpressions, we iterate over each subexpression and put the mapping between - * (subexpression -> `SubExprEliminationState`) into the map. 
So in next subexpression - * evaluation, we can look for generated subexpressions and do replacement. - */ - def subexpressionEliminationForWholeStageCodegen(expressions: Seq[Expression]): SubExprCodes = { - // Create a clear EquivalentExpressions and SubExprEliminationState mapping - val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions - val localSubExprEliminationExprsForNonSplit = - mutable.HashMap.empty[ExpressionEquals, SubExprEliminationState] - - // Add each expression tree and compute the common subexpressions. - expressions.foreach(equivalentExpressions.addExprTree(_)) - - // Get all the expressions that appear at least twice and set up the state for subexpression - // elimination. - val commonExprs = equivalentExpressions.getCommonSubexpressions - - val nonSplitCode = { - val allStates = mutable.ArrayBuffer.empty[SubExprEliminationState] - commonExprs.map { expr => - withSubExprEliminationExprs(localSubExprEliminationExprsForNonSplit.toMap) { - val eval = expr.genCode(this) - // Collects other subexpressions from the children. - val childrenSubExprs = mutable.ArrayBuffer.empty[SubExprEliminationState] - expr.foreach { e => - subExprEliminationExprs.get(ExpressionEquals(e)) match { - case Some(state) => childrenSubExprs += state - case _ => - } + def subexpressionElimination(expressions: Expression*): Block = { + var initBlock: Block = EmptyBlock + if (SQLConf.get.subexpressionEliminationEnabled) { + // Create a clear EquivalentExpressions and SubExprEliminationState mapping + val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions + // Add current expression tree and compute the common subexpressions. 
+ expressions.map(equivalentExpressions.addExprTree(_)) + + commonExpressions.clear() + commonExpressions ++= equivalentExpressions.getAllExprStates(1).map { stats => + val expr = stats.expr + val initialized = addMutableState(JAVA_BOOLEAN, "subExprInit") + initBlock += code"$initialized = false;\n" + val wrapperFunc: ExprCode => ExprCode = { eval => + val (inputVars, exprCodes) = { + val (inputVars, exprCodes) = getLocalInputVariableValues(this, expr) + (inputVars.toSeq, exprCodes.toSeq) } - val state = SubExprEliminationState(eval, childrenSubExprs.toSeq) - localSubExprEliminationExprsForNonSplit.put(ExpressionEquals(expr), state) - allStates += state - Seq(eval) - } - } - allStates.toSeq - } - - // For some operators, they do not require all its child's outputs to be evaluated in advance. - // Instead it only early evaluates part of outputs, for example, `ProjectExec` only early - // evaluate the outputs used more than twice. So we need to extract these variables used by - // subexpressions and evaluate them before subexpressions. 
- val (inputVarsForAllFuncs, exprCodesNeedEvaluate) = commonExprs.map { expr => - val (inputVars, exprCodes) = getLocalInputVariableValues(this, expr) - (inputVars.toSeq, exprCodes.toSeq) - }.unzip - - val needSplit = nonSplitCode.map(_.eval.code.length).sum > SQLConf.get.methodSplitThreshold - val (subExprsMap, exprCodes) = if (needSplit) { - if (inputVarsForAllFuncs.map(calculateParamLengthFromExprValues).forall(isValidParamLength)) { - val localSubExprEliminationExprs = - mutable.HashMap.empty[ExpressionEquals, SubExprEliminationState] - - commonExprs.zipWithIndex.foreach { case (expr, i) => - val eval = withSubExprEliminationExprs(localSubExprEliminationExprs.toMap) { - Seq(expr.genCode(this)) - }.head - - val value = addMutableState(javaType(expr.dataType), "subExprValue") - val isNullLiteral = eval.isNull match { case TrueLiteral | FalseLiteral => true case _ => false @@ -1162,105 +1055,41 @@ class CodegenContext extends Logging { } else { (eval.isNull, "") } - - // Generate the code for this expression tree and wrap it in a function. - val fnName = freshName("subExpr") - val inputVars = inputVarsForAllFuncs(i) - val argList = - inputVars.map(v => s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}") - val fn = - s""" - |private void $fnName(${argList.mkString(", ")}) { - | ${eval.code} - | $isNullEvalCode - | $value = ${eval.value}; - |} + val value = addMutableState(javaType(expr.dataType), "subExprValue") + val code = if (isValidParamLength(calculateParamLengthFromExprValues(inputVars))) { + // Generate the code for this expression tree and wrap it in a function. 
+ val fnName = freshName("subExpr") + val argList = + inputVars.map(v => s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}") + val fn = + s""" + |private void $fnName(${argList.mkString(", ")}) { + | if (!$initialized) { + | ${eval.code} + | $initialized = true; + | $isNullEvalCode + | $value = ${eval.value}; + | } + |} + """.stripMargin + val inputVariables = inputVars.map(_.variableName).mkString(", ") + code"${addNewFunction(fnName, fn)}($inputVariables);" + } else { + code""" + |if (!$initialized) { + | ${eval.code} + | $initialized = true; + | $isNullEvalCode + | $value = ${eval.value}; + |} """.stripMargin - - // Collects other subexpressions from the children. - val childrenSubExprs = mutable.ArrayBuffer.empty[SubExprEliminationState] - expr.foreach { e => - localSubExprEliminationExprs.get(ExpressionEquals(e)) match { - case Some(state) => childrenSubExprs += state - case _ => - } } - - val inputVariables = inputVars.map(_.variableName).mkString(", ") - val code = code"${addNewFunction(fnName, fn)}($inputVariables);" - val state = SubExprEliminationState( - ExprCode(code, isNull, JavaCode.global(value, expr.dataType)), - childrenSubExprs.toSeq) - localSubExprEliminationExprs.put(ExpressionEquals(expr), state) + ExprCode(code, isNull, JavaCode.global(value, expr.dataType)) } - (localSubExprEliminationExprs, exprCodesNeedEvaluate) - } else { - val errMsg = "Failed to split subexpression code into small functions because " + - "the parameter length of at least one split function went over the JVM limit: " + - MAX_JVM_METHOD_PARAMS_LENGTH - if (Utils.isTesting) { - throw new IllegalStateException(errMsg) - } else { - logInfo(errMsg) - (localSubExprEliminationExprsForNonSplit, Seq.empty) - } - } - } else { - (localSubExprEliminationExprsForNonSplit, Seq.empty) - } - SubExprCodes(subExprsMap.toMap, exprCodes.flatten) - } - - /** - * Checks and sets up the state and codegen for subexpression elimination. 
This finds the - * common subexpressions, generates the functions that evaluate those expressions and populates - * the mapping of common subexpressions to the generated functions. - */ - private def subexpressionElimination(expressions: Seq[Expression]): Unit = { - // Add each expression tree and compute the common subexpressions. - expressions.foreach(equivalentExpressions.addExprTree(_)) - - // Get all the expressions that appear at least twice and set up the state for subexpression - // elimination. - val commonExprs = equivalentExpressions.getCommonSubexpressions - commonExprs.foreach { expr => - val fnName = freshName("subExpr") - val isNull = addMutableState(JAVA_BOOLEAN, "subExprIsNull") - val value = addMutableState(javaType(expr.dataType), "subExprValue") - - // Generate the code for this expression tree and wrap it in a function. - val eval = expr.genCode(this) - val fn = - s""" - |private void $fnName(InternalRow $INPUT_ROW) { - | ${eval.code} - | $isNull = ${eval.isNull}; - | $value = ${eval.value}; - |} - """.stripMargin - - // Add a state and a mapping of the common subexpressions that are associate with this - // state. Adding this expression to subExprEliminationExprMap means it will call `fn` - // when it is code generated. This decision should be a cost based one. - // - // The cost of doing subexpression elimination is: - // 1. Extra function call, although this is probably *good* as the JIT can decide to - // inline or not. - // The benefit doing subexpression elimination is: - // 1. Running the expression logic. Even for a simple expression, it is likely more than 3 - // above. - // 2. Less code. - // Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with - // at least two nodes) as the cost of doing it is expected to be low. 
- - val subExprCode = s"${addNewFunction(fnName, fn)}($INPUT_ROW);" - subexprFunctions += subExprCode - val state = SubExprEliminationState( - ExprCode(code"$subExprCode", - JavaCode.isNullGlobal(isNull), - JavaCode.global(value, expr.dataType))) - subExprEliminationExprs += ExpressionEquals(expr) -> state + ExpressionEquals(expr) -> (stats.useCount, wrapperFunc, None) + }.toMap } + initBlock } /** @@ -1270,12 +1099,16 @@ class CodegenContext extends Logging { */ def generateExpressions( expressions: Seq[Expression], - doSubexpressionElimination: Boolean = false): Seq[ExprCode] = { + doSubexpressionElimination: Boolean = false): (Seq[ExprCode], Block) = { // We need to make sure that we do not reuse stateful expressions. This is needed for codegen // as well because some expressions may implement `CodegenFallback`. val cleanedExpressions = expressions.map(_.freshCopyIfContainsStatefulExpression()) - if (doSubexpressionElimination) subexpressionElimination(cleanedExpressions) - cleanedExpressions.map(e => e.genCode(this)) + val initBlock = if (doSubexpressionElimination) { + subexpressionElimination(cleanedExpressions: _*) + } else { + EmptyBlock + } + (cleanedExpressions.map(e => e.genCode(this)), initBlock) } /** @@ -1314,6 +1147,9 @@ class CodegenContext extends Logging { EmptyBlock } } + + private[spark] val commonExpressions = + new mutable.HashMap[ExpressionEquals, (Int, ExprCode => ExprCode, Option[ExprCode])] } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 2e018de07101e..4829348aacf39 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -61,7 +61,7 @@ object 
GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP case (NoOp, _) => false case _ => true } - val exprVals = ctx.generateExpressions(validExpr.map(_._1), useSubexprElimination) + val (exprVals, initBlock) = ctx.generateExpressions(validExpr.map(_._1), useSubexprElimination) // 4-tuples: (code for projection, isNull variable name, value variable name, column index) val projectionCodes: Seq[(String, String)] = validExpr.zip(exprVals).map { @@ -130,6 +130,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP public java.lang.Object apply(java.lang.Object _i) { InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i; + $initBlock $evalSubexpr $allProjections // copy all the results into MutableRow diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index c246d07f189b4..ab2f6f47e847a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -38,7 +38,8 @@ object GeneratePredicate extends CodeGenerator[Expression, BasePredicate] { val ctx = newCodeGenContext() // Do sub-expression elimination for predicates. 
- val eval = ctx.generateExpressions(Seq(predicate), useSubexprElimination).head + val (evals, initBlock) = ctx.generateExpressions(Seq(predicate), useSubexprElimination) + val eval = evals.head val evalSubexpr = ctx.subexprFunctionsCode val codeBody = s""" @@ -60,6 +61,7 @@ object GeneratePredicate extends CodeGenerator[Expression, BasePredicate] { } public boolean eval(InternalRow ${ctx.INPUT_ROW}) { + $initBlock $evalSubexpr ${eval.code} return !${eval.isNull} && ${eval.value}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 459c1d9a8ba11..68bed063ab61e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -287,7 +287,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx: CodegenContext, expressions: Seq[Expression], useSubexprElimination: Boolean = false): ExprCode = { - val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination) + val (exprEvals, initBlock) = ctx.generateExpressions(expressions, useSubexprElimination) val exprSchemas = expressions.map(e => Schema(e.dataType, e.nullable)) val numVarLenFields = exprSchemas.count { @@ -307,6 +307,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val code = code""" + |$initBlock |$rowWriter.reset(); |$evalSubexpr |$writeExpressions diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 265b0eeb8bdf8..7e73becbe0dfe 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -463,55 +463,6 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { private def wrap(expr: Expression): ExpressionEquals = ExpressionEquals(expr) - test("SPARK-23760: CodegenContext.withSubExprEliminationExprs should save/restore correctly") { - - val ref = BoundReference(0, IntegerType, true) - val add1 = Add(ref, ref) - val add2 = Add(add1, add1) - val dummy = SubExprEliminationState( - ExprCode(EmptyBlock, - JavaCode.variable("dummy", BooleanType), - JavaCode.variable("dummy", BooleanType))) - - // raw testing of basic functionality - { - val ctx = new CodegenContext - val e = ref.genCode(ctx) - // before - ctx.subExprEliminationExprs += wrap(ref) -> SubExprEliminationState( - ExprCode(EmptyBlock, e.isNull, e.value)) - assert(ctx.subExprEliminationExprs.contains(wrap(ref))) - // call withSubExprEliminationExprs - ctx.withSubExprEliminationExprs(Map(wrap(add1) -> dummy)) { - assert(ctx.subExprEliminationExprs.contains(wrap(add1))) - assert(!ctx.subExprEliminationExprs.contains(wrap(ref))) - Seq.empty - } - // after - assert(ctx.subExprEliminationExprs.nonEmpty) - assert(ctx.subExprEliminationExprs.contains(wrap(ref))) - assert(!ctx.subExprEliminationExprs.contains(wrap(add1))) - } - - // emulate an actual codegen workload - { - val ctx = new CodegenContext - // before - ctx.generateExpressions(Seq(add2, add1), doSubexpressionElimination = true) // trigger CSE - assert(ctx.subExprEliminationExprs.contains(wrap(add1))) - // call withSubExprEliminationExprs - ctx.withSubExprEliminationExprs(Map(wrap(ref) -> dummy)) { - assert(ctx.subExprEliminationExprs.contains(wrap(ref))) - assert(!ctx.subExprEliminationExprs.contains(wrap(add1))) - Seq.empty - } - // after - assert(ctx.subExprEliminationExprs.nonEmpty) - 
assert(ctx.subExprEliminationExprs.contains(wrap(add1))) - assert(!ctx.subExprEliminationExprs.contains(wrap(ref))) - } - } - test("SPARK-23986: freshName can generate duplicated names") { val ctx = new CodegenContext val names1 = ctx.freshName("myName1") :: ctx.freshName("myName1") :: @@ -536,18 +487,6 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { .exists(_.getMessage().getFormattedMessage.contains("Generated method too long"))) } - test("SPARK-28916: subexpression elimination can cause 64kb code limit on UnsafeProjection") { - val numOfExprs = 10000 - val exprs = (0 to numOfExprs).flatMap(colIndex => - Seq(Add(BoundReference(colIndex, DoubleType, true), - BoundReference(numOfExprs + colIndex, DoubleType, true)), - Add(BoundReference(colIndex, DoubleType, true), - BoundReference(numOfExprs + colIndex, DoubleType, true)))) - // these should not fail to compile due to 64K limit - GenerateUnsafeProjection.generate(exprs, true) - GenerateMutableProjection.generate(exprs, true) - } - test("SPARK-32624: Use CodeGenerator.typeName() to fix byte[] compile issue") { val ctx = new CodegenContext val bytes = new Array[Byte](3) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index f369635a32671..0566972b29005 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -22,8 +22,7 @@ import org.apache.spark.{SparkFunSuite, TaskContext, TaskContextImpl} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.internal.SQLConf -import 
org.apache.spark.sql.types.{BinaryType, DataType, IntegerType, ObjectType} +import org.apache.spark.sql.types.{DataType, IntegerType, ObjectType} class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHelper { test("Semantic equals and hash") { @@ -147,15 +146,15 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel val equivalence = new EquivalentExpressions equivalence.addExprTree(add) // the `two` inside `fallback` should not be added - assert(equivalence.getAllExprStates(1).size == 0) - assert(equivalence.getAllExprStates().count(_.useCount == 1) == 3) // add, two, explode + assert(equivalence.getAllExprStates(1).size == 1) + assert(equivalence.getAllExprStates().count(_.useCount == 1) == 2) // add, two, explode } test("Children of conditional expressions: If") { val add = Add(Literal(1), Literal(2)) val condition = GreaterThan(add, Literal(3)) - val ifExpr1 = If(condition, add, add) + val ifExpr1 = If(condition, add, Literal(1)) val equivalence1 = new EquivalentExpressions equivalence1.addExprTree(ifExpr1) @@ -172,28 +171,27 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel val equivalence2 = new EquivalentExpressions equivalence2.addExprTree(ifExpr2) - assert(equivalence2.getAllExprStates(1).isEmpty) - assert(equivalence2.getAllExprStates().count(_.useCount == 1) == 3) + assert(equivalence2.getAllExprStates(1).nonEmpty) + assert(equivalence2.getAllExprStates().count(_.useCount == 1) == 4) val ifExpr3 = If(condition, ifExpr1, ifExpr1) val equivalence3 = new EquivalentExpressions equivalence3.addExprTree(ifExpr3) // `add`: 2, `condition`: 2 - assert(equivalence3.getAllExprStates().count(_.useCount == 2) == 2) + assert(equivalence3.getAllExprStates().count(_.useCount == 2) == 3) assert(equivalence3.getAllExprStates().filter(_.useCount == 2).exists(_.expr eq condition)) assert(equivalence3.getAllExprStates().filter(_.useCount == 2).exists(_.expr eq add)) // `ifExpr1`, `ifExpr3` 
- assert(equivalence3.getAllExprStates().count(_.useCount == 1) == 2) - assert(equivalence3.getAllExprStates().filter(_.useCount == 1).exists(_.expr eq ifExpr1)) + assert(equivalence3.getAllExprStates().count(_.useCount == 1) == 1) assert(equivalence3.getAllExprStates().filter(_.useCount == 1).exists(_.expr eq ifExpr3)) } test("Children of conditional expressions: CaseWhen") { val add1 = Add(Literal(1), Literal(2)) val add2 = Add(Literal(2), Literal(3)) - val conditions1 = (GreaterThan(add2, Literal(3)), add1) :: + val conditions1 = (GreaterThan(add1, Literal(3)), add1) :: (GreaterThan(add2, Literal(4)), add1) :: (GreaterThan(add2, Literal(5)), add1) :: Nil @@ -215,7 +213,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel // `add1` is repeatedly in all branch values, and first predicate. assert(equivalence2.getAllExprStates().count(_.useCount == 2) == 1) - assert(equivalence2.getAllExprStates().filter(_.useCount == 2).head.expr eq add1) + assert(equivalence2.getAllExprStates().filter(_.useCount == 2).head.expr eq add2) // Negative case. `add1` or `add2` is not commonly used in all predicates/branch values. val conditions3 = (GreaterThan(add1, Literal(3)), add2) :: @@ -240,8 +238,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel equivalence1.addExprTree(coalesceExpr1) // `add2` is repeatedly in all conditions. - assert(equivalence1.getAllExprStates().count(_.useCount == 2) == 1) - assert(equivalence1.getAllExprStates().filter(_.useCount == 2).head.expr eq add2) + assert(equivalence1.getAllExprStates().count(_.useCount == 2) == 0) // Negative case. `add1` and `add2` both are not used in all branches. 
val conditions2 = GreaterThan(add1, Literal(3)) :: @@ -252,62 +249,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel val equivalence2 = new EquivalentExpressions equivalence2.addExprTree(coalesceExpr2) - assert(equivalence2.getAllExprStates().count(_.useCount == 2) == 0) - } - - test("SPARK-34723: Correct parameter type for subexpression elimination under whole-stage") { - withSQLConf(SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> "1") { - val str = BoundReference(0, BinaryType, false) - val pos = BoundReference(1, IntegerType, false) - - val substr = new Substring(str, pos) - - val add = Add(Length(substr), Literal(1)) - val add2 = Add(Length(substr), Literal(2)) - - val ctx = new CodegenContext() - val exprs = Seq(add, add2) - - val oneVar = ctx.freshVariable("str", BinaryType) - val twoVar = ctx.freshVariable("pos", IntegerType) - ctx.addMutableState("byte[]", oneVar, forceInline = true, useFreshName = false) - ctx.addMutableState("int", twoVar, useFreshName = false) - - ctx.INPUT_ROW = null - ctx.currentVars = Seq( - ExprCode(TrueLiteral, oneVar), - ExprCode(TrueLiteral, twoVar)) - - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs) - ctx.withSubExprEliminationExprs(subExprs.states) { - exprs.map(_.genCode(ctx)) - } - val subExprsCode = ctx.evaluateSubExprEliminationState(subExprs.states.values) - - val codeBody = s""" - public java.lang.Object generate(Object[] references) { - return new TestCode(references); - } - - class TestCode { - ${ctx.declareMutableStates()} - - public TestCode(Object[] references) { - } - - public void initialize(int partitionIndex) { - ${subExprsCode} - } - - ${ctx.declareAddedFunctions()} - } - """ - - val code = CodeFormatter.stripOverlappingComments( - new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) - - CodeGenerator.compile(code) - } + assert(equivalence2.getAllExprStates().count(_.useCount == 2) == 1) } test("SPARK-35410: SubExpr elimination should not include 
redundant child exprs " + @@ -323,7 +265,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel val commonExprs = equivalence.getAllExprStates(1) assert(commonExprs.size == 1) - assert(commonExprs.head.useCount == 2) + assert(commonExprs.head.useCount == 3) assert(commonExprs.head.expr eq add3) } @@ -337,8 +279,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel equivalence.addExprTree(ifExpr3) val commonExprs = equivalence.getAllExprStates(1) - assert(commonExprs.size == 1) - assert(commonExprs.head.useCount == 2) + assert(commonExprs.size == 2) + assert(commonExprs.head.useCount == 4) assert(commonExprs.head.expr eq add) } @@ -399,28 +341,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel equivalence.addExprTree(caseWhenExpr) // `add1` is not in the elseValue, so we can't extract it from the branches - assert(equivalence.getAllExprStates().count(_.useCount == 2) == 0) - } - - test("SPARK-35829: SubExprEliminationState keeps children sub exprs") { - val add1 = Add(Literal(1), Literal(2)) - val add2 = Add(add1, add1) - - val exprs = Seq(add1, add1, add2, add2) - val ctx = new CodegenContext() - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs) - - val add2State = subExprs.states(ExpressionEquals(add2)) - val add1State = subExprs.states(ExpressionEquals(add1)) - assert(add2State.children.contains(add1State)) - - subExprs.states.values.foreach { state => - assert(state.eval.code != EmptyBlock) - } - ctx.evaluateSubExprEliminationState(subExprs.states.values) - subExprs.states.values.foreach { state => - assert(state.eval.code == EmptyBlock) - } + assert(equivalence.getAllExprStates().count(_.useCount == 2) == 1) } test("SPARK-38333: PlanExpression expression should skip addExprTree function in Executor") { @@ -443,7 +364,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel val n1 = NaNvl(Literal(1.0d), Add(add, 
add)) val e1 = new EquivalentExpressions e1.addExprTree(n1) - assert(e1.getCommonSubexpressions.isEmpty) + assert(e1.getCommonSubexpressions.nonEmpty) val n2 = NaNvl(add, add) val e2 = new EquivalentExpressions @@ -467,33 +388,6 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel val cseState = equivalence.getExprState(expr) assert(hasMatching == cseState.isDefined) } - - test("SPARK-42815: Subexpression elimination support shortcut conditional expression") { - val add = Add(Literal(1), Literal(0)) - val equal = EqualTo(add, add) - - def checkShortcut(expr: Expression, numCommonExpr: Int): Unit = { - val e1 = If(expr, Literal(1), Literal(2)) - val ee1 = new EquivalentExpressions(true) - ee1.addExprTree(e1) - assert(ee1.getCommonSubexpressions.size == numCommonExpr) - - val e2 = expr - val ee2 = new EquivalentExpressions(true) - ee2.addExprTree(e2) - assert(ee2.getCommonSubexpressions.size == numCommonExpr) - } - - // shortcut right child - checkShortcut(And(Literal(false), equal), 0) - checkShortcut(Or(Literal(true), equal), 0) - checkShortcut(Not(And(Literal(true), equal)), 0) - - // always eliminate subexpression for left child - checkShortcut((And(equal, Literal(false))), 1) - checkShortcut(Or(equal, Literal(true)), 1) - checkShortcut(Not(And(equal, Literal(false))), 1) - } } case class CodegenFallbackExpression(child: Expression) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index c087fdf5f962b..d362ac48a1c28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -168,6 +168,8 @@ case class ExpandExec( } // Part 2: switch/case statements + initBlock += ctx.subexpressionElimination( + projections.flatten.map(BindReferences.bindReference(_, attributeSeq)): _*) val switchCaseExprs = projections.zipWithIndex.map { 
case (exprs, row) => val (exprCodesWithIndices, inputVarSets) = exprs.indices.flatMap { col => if (!sameOutput(col)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index ddc2cfb56d4f6..1b3fe92adea2c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -198,6 +198,7 @@ trait CodegenSupport extends SparkPlan { s""" |${ctx.registerComment(s"CONSUME: ${parent.simpleString(conf.maxToStringFields)}")} |$evaluated + |${parent.initBlock} |$consumeFunc """.stripMargin } @@ -342,6 +343,8 @@ trait CodegenSupport extends SparkPlan { throw new UnsupportedOperationException } + var initBlock: Block = EmptyBlock + /** * Whether or not the result rows of this operator should be copied before putting into a buffer. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala index 1377a98422317..25ea1b71d279b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, ExpressionEquals, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen._ @@ 
-207,12 +207,9 @@ trait AggregateCodegenSupport val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc => bindReferences(updateExprsForOneFunc, inputAttrs) } - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) - val effectiveCodes = ctx.evaluateSubExprEliminationState(subExprs.states.values) + val initBlock = ctx.subexpressionElimination(boundUpdateExprs.flatten: _*) val bufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc => - ctx.withSubExprEliminationExprs(subExprs.states) { - boundUpdateExprsForOneFunc.map(_.genCode(ctx)) - } + boundUpdateExprsForOneFunc.map(_.genCode(ctx)) } val aggNames = functions.map(_.prettyName) @@ -236,11 +233,11 @@ trait AggregateCodegenSupport } val codeToEvalAggFuncs = generateEvalCodeForAggFuncs( - ctx, input, inputAttrs, boundUpdateExprs, aggNames, aggCodeBlocks, subExprs) + ctx, input, inputAttrs, boundUpdateExprs, aggNames, aggCodeBlocks) s""" |// do aggregate |// common sub-expressions - |$effectiveCodes + |$initBlock |// evaluate aggregate functions and update aggregation buffers |$codeToEvalAggFuncs """.stripMargin @@ -255,19 +252,21 @@ trait AggregateCodegenSupport inputAttrs: Seq[Attribute], boundUpdateExprs: Seq[Seq[Expression]], aggNames: Seq[String], - aggCodeBlocks: Seq[Block], - subExprs: SubExprCodes): String = { + aggCodeBlocks: Seq[Block]): String = { + val evaluated = boundUpdateExprs.flatten.map { e => + evaluateRequiredVariables(inputAttrs, input, e.references) + }.mkString("", "\n", "\n") val aggCodes = if (conf.codegenSplitAggregateFunc && aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) { val maybeSplitCodes = splitAggregateExpressions( - ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states) + ctx, aggNames, boundUpdateExprs, aggCodeBlocks) maybeSplitCodes.getOrElse(aggCodeBlocks.map(_.code)) } else { aggCodeBlocks.map(_.code) } - aggCodes.zip(aggregateExpressions.map(ae => (ae.mode, ae.filter))).map { + evaluated + 
aggCodes.zip(aggregateExpressions.map(ae => (ae.mode, ae.filter))).map { case (aggCode, (Partial | Complete, Some(condition))) => // Note: wrap in "do { } while(false);", so the generated checks can jump out // with "continue;" @@ -295,59 +294,49 @@ trait AggregateCodegenSupport ctx: CodegenContext, aggNames: Seq[String], aggBufferUpdatingExprs: Seq[Seq[Expression]], - aggCodeBlocks: Seq[Block], - subExprs: Map[ExpressionEquals, SubExprEliminationState]): Option[Seq[String]] = { - val exprValsInSubExprs = subExprs.flatMap { case (_, s) => - s.eval.value :: s.eval.isNull :: Nil - } - if (exprValsInSubExprs.exists(_.isInstanceOf[SimpleExprValue])) { - // `SimpleExprValue`s cannot be used as an input variable for split functions, so - // we give up splitting functions if it exists in `subExprs`. - None - } else { - val inputVars = aggBufferUpdatingExprs.map { aggExprsForOneFunc => - val inputVarsForOneFunc = aggExprsForOneFunc.map( - CodeGenerator.getLocalInputVariableValues(ctx, _, subExprs)._1).reduce(_ ++ _).toSeq - val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVarsForOneFunc) + aggCodeBlocks: Seq[Block]): Option[Seq[String]] = { + val inputVars = aggBufferUpdatingExprs.map { aggExprsForOneFunc => + val inputVarsForOneFunc = aggExprsForOneFunc.map( + CodeGenerator.getLocalInputVariableValues(ctx, _)._1).reduce(_ ++ _).toSeq + val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVarsForOneFunc) - // Checks if a parameter length for the `aggExprsForOneFunc` does not go over the JVM limit - if (CodeGenerator.isValidParamLength(paramLength)) { - Some(inputVarsForOneFunc) - } else { - None - } + // Checks if a parameter length for the `aggExprsForOneFunc` does not go over the JVM limit + if (CodeGenerator.isValidParamLength(paramLength)) { + Some(inputVarsForOneFunc) + } else { + None } + } - // Checks if all the aggregate code can be split into pieces. 
- // If the parameter length of at lease one `aggExprsForOneFunc` goes over the limit, - // we totally give up splitting aggregate code. - if (inputVars.forall(_.isDefined)) { - val splitCodes = inputVars.flatten.zipWithIndex.map { case (args, i) => - val doAggFunc = ctx.freshName(s"doAggregate_${aggNames(i)}") - val argList = args.map { v => - s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}" - }.mkString(", ") - val doAggFuncName = ctx.addNewFunction(doAggFunc, - s""" - |private void $doAggFunc($argList) throws java.io.IOException { - | ${aggCodeBlocks(i)} - |} - """.stripMargin) + // Checks if all the aggregate code can be split into pieces. + // If the parameter length of at least one `aggExprsForOneFunc` goes over the limit, + // we totally give up splitting aggregate code. + if (inputVars.forall(_.isDefined)) { + val splitCodes = inputVars.flatten.zipWithIndex.map { case (args, i) => + val doAggFunc = ctx.freshName(s"doAggregate_${aggNames(i)}") + val argList = args.map { v => + s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}" + }.mkString(", ") + val doAggFuncName = ctx.addNewFunction(doAggFunc, + s""" + |private void $doAggFunc($argList) throws java.io.IOException { + | ${aggCodeBlocks(i)} + |} + """.stripMargin) - val inputVariables = args.map(_.variableName).mkString(", ") - s"$doAggFuncName($inputVariables);" - } - Some(splitCodes) + val inputVariables = args.map(_.variableName).mkString(", ") + s"$doAggFuncName($inputVariables);" + } + Some(splitCodes) + } else { + val errMsg = "Failed to split aggregate code into small functions because the parameter " + + "length of at least one split function went over the JVM limit: " + + CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH + if (Utils.isTesting) { + throw new IllegalStateException(errMsg) } else { - val errMsg = "Failed to split aggregate code into small functions because the parameter " + - "length of at least one split function went over the JVM limit: " + -
CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH - if (Utils.isTesting) { - throw new IllegalStateException(errMsg) - } else { - logInfo(errMsg) - None - } + logInfo(errMsg) + None } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 6c83ba5546d2a..a47239865c6c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -625,7 +625,7 @@ case class HashAggregateExec( // create grouping key val unsafeRowKeyCode = GenerateUnsafeProjection.createCode( ctx, bindReferences[Expression](groupingExpressions, child.output)) - val fastRowKeys = ctx.generateExpressions( + val (fastRowKeys, initBlock) = ctx.generateExpressions( bindReferences[Expression](groupingExpressions, child.output)) val unsafeRowKeys = unsafeRowKeyCode.value val unsafeRowKeyHash = ctx.freshName("unsafeRowKeyHash") @@ -688,6 +688,7 @@ case class HashAggregateExec( // If fast hash map is on, we first generate code to probe and update the fast hash map. // If the probe is successful the corresponding fast row buffer will hold the mutable row. s""" + |$initBlock |${fastRowKeys.map(_.code).mkString("\n")} |if (${fastRowKeys.map("!" 
+ _.isNull).mkString(" && ")}) { | $fastRowBuffer = $fastHashMapTerm.findOrInsert( @@ -727,12 +728,9 @@ case class HashAggregateExec( val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc => bindReferences(updateExprsForOneFunc, inputAttrs) } - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) - val effectiveCodes = ctx.evaluateSubExprEliminationState(subExprs.states.values) + val initBlock = ctx.subexpressionElimination(boundUpdateExprs.flatten: _*) val unsafeRowBufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc => - ctx.withSubExprEliminationExprs(subExprs.states) { - boundUpdateExprsForOneFunc.map(_.genCode(ctx)) - } + boundUpdateExprsForOneFunc.map(_.genCode(ctx)) } val aggCodeBlocks = updateExprs.indices.map { i => @@ -757,10 +755,10 @@ case class HashAggregateExec( } val codeToEvalAggFuncs = generateEvalCodeForAggFuncs( - ctx, input, inputAttrs, boundUpdateExprs, aggNames, aggCodeBlocks, subExprs) + ctx, input, inputAttrs, boundUpdateExprs, aggNames, aggCodeBlocks) s""" |// common sub-expressions - |$effectiveCodes + |$initBlock |// evaluate aggregate functions and update aggregation buffers |$codeToEvalAggFuncs """.stripMargin @@ -773,12 +771,9 @@ case class HashAggregateExec( val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc => bindReferences(updateExprsForOneFunc, inputAttrs) } - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) - val effectiveCodes = ctx.evaluateSubExprEliminationState(subExprs.states.values) + val initBlock = ctx.subexpressionElimination(boundUpdateExprs.flatten: _*) val fastRowEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc => - ctx.withSubExprEliminationExprs(subExprs.states) { - boundUpdateExprsForOneFunc.map(_.genCode(ctx)) - } + boundUpdateExprsForOneFunc.map(_.genCode(ctx)) } val aggCodeBlocks = fastRowEvals.zipWithIndex.map { case (fastRowEvalsForOneFunc, i) => @@ -802,7 +797,7 @@ case class 
HashAggregateExec( } val codeToEvalAggFuncs = generateEvalCodeForAggFuncs( - ctx, input, inputAttrs, boundUpdateExprs, aggNames, aggCodeBlocks, subExprs) + ctx, input, inputAttrs, boundUpdateExprs, aggNames, aggCodeBlocks) // If vectorized fast hash map is on, we first generate code to update row // in vectorized fast hash map, if the previous loop up hit vectorized fast hash map. @@ -810,7 +805,7 @@ case class HashAggregateExec( s""" |if ($fastRowBuffer != null) { | // common sub-expressions - | $effectiveCodes + | $initBlock | // evaluate aggregate functions and update aggregation buffers | $codeToEvalAggFuncs |} else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 68f056d894b9f..2c9b61cebcfe7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -67,24 +67,12 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { val exprs = bindReferences[Expression](projectList, child.output) - val (subExprsCode, resultVars, localValInputs) = if (conf.subexpressionEliminationEnabled) { - // subexpression elimination - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs) - val genVars = ctx.withSubExprEliminationExprs(subExprs.states) { - exprs.map(_.genCode(ctx)) - } - (ctx.evaluateSubExprEliminationState(subExprs.states.values), genVars, - subExprs.exprCodesNeedEvaluate) - } else { - ("", exprs.map(_.genCode(ctx)), Seq.empty) - } + initBlock += ctx.subexpressionElimination(exprs: _*) + val resultVars = exprs.map(_.genCode(ctx)) // Evaluation of non-deterministic expressions can't be deferred. 
val nonDeterministicAttrs = projectList.filterNot(_.deterministic).map(_.toAttribute) s""" - |// common sub-expressions - |${evaluateVariables(localValInputs)} - |$subExprsCode |${evaluateRequiredVariables(output, resultVars, AttributeSet(nonDeterministicAttrs))} |${consume(ctx, resultVars)} """.stripMargin @@ -178,6 +166,8 @@ trait GeneratePredicateHelper extends PredicateHelper { // TODO: revisit this. We can consider reordering predicates as well. val generatedIsNotNullChecks = new Array[Boolean](notNullPreds.length) val extraIsNotNullAttrs = mutable.Set[Attribute]() + initBlock += + ctx.subexpressionElimination(otherPreds.map(BindReferences.bindReference(_, inputAttrs)): _*) val generated = otherPreds.map { c => val nullChecks = c.references.map { r => val idx = notNullPreds.indexWhere { n => n.asInstanceOf[IsNotNull].child.semanticEquals(r)} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinCodegenSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinCodegenSupport.scala index a7d1edefcd611..4a11a7b9709dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinCodegenSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinCodegenSupport.scala @@ -55,11 +55,13 @@ trait JoinCodegenSupport extends CodegenSupport with BaseJoinExec { // filter the output via condition ctx.currentVars = streamVars2 ++ buildVars - val ev = - BindReferences.bindReference(expr, streamPlan.output ++ buildPlan.output).genCode(ctx) + val bondExpr = BindReferences.bindReference(expr, streamPlan.output ++ buildPlan.output) + initBlock += ctx.subexpressionElimination(bondExpr) + val ev = bondExpr.genCode(ctx) val skipRow = s"${ev.isNull} || !${ev.value}" s""" |$eval + |$initBlock |${ev.code} |if (!($skipRow)) """.stripMargin diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index ac710c3229647..72ed5d5b70e4d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -770,26 +770,4 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession } } } - - test("Give up splitting subexpression code if a parameter length goes over the limit") { - withSQLConf( - SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC.key -> "false", - SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> "1", - "spark.sql.CodeGenerator.validParamLength" -> "0") { - withTable("t") { - val expectedErrMsg = "Failed to split subexpression code into small functions" - Seq( - // Test case without keys - "SELECT AVG(a + b), SUM(a + b + c) FROM VALUES((1, 1, 1)) t(a, b, c)", - // Tet case with keys - "SELECT k, AVG(a + b), SUM(a + b + c) FROM VALUES((1, 1, 1, 1)) t(k, a, b, c) " + - "GROUP BY k").foreach { query => - val e = intercept[IllegalStateException] { - sql(query).collect - } - assert(e.getMessage.contains(expectedErrMsg)) - } - } - } - } } From db77ea2920fe8f0d7361b9c09b72700335ffa4eb Mon Sep 17 00:00:00 2001 From: Kun Wan Date: Sat, 8 Apr 2023 20:44:29 +0800 Subject: [PATCH 02/10] Support whole stage subexpressions elimination --- .../expressions/EquivalentExpressions.scala | 10 +- .../sql/catalyst/expressions/Expression.scala | 55 +++---- .../expressions/codegen/CodeGenerator.scala | 144 ++++++++++-------- .../spark/sql/execution/ExpandExec.scala | 2 +- .../sql/execution/WholeStageCodegenExec.scala | 52 ++++++- .../aggregate/AggregateCodegenSupport.scala | 4 +- .../aggregate/HashAggregateExec.scala | 7 +- .../execution/basicPhysicalOperators.scala | 10 +- .../joins/BroadcastNestedLoopJoinExec.scala | 3 + .../spark/sql/execution/joins/HashJoin.scala | 2 + .../execution/joins/JoinCodegenSupport.scala | 4 +- 11 files changed, 180 insertions(+), 113 
deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index dfb5d95238c9c..d8f7368038a12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -22,6 +22,7 @@ import java.util.Objects import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions.EquivalentExpressions.supportedExpression +import org.apache.spark.sql.catalyst.expressions.codegen.ExprValue import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable import org.apache.spark.util.Utils @@ -194,4 +195,11 @@ case class ExpressionEquals(e: Expression) { * Instead of appending to a mutable list/buffer of Expressions, just update the "flattened" * useCount in this wrapper in-place. 
*/ -case class ExpressionStats(expr: Expression)(var useCount: Int) +case class ExpressionStats(expr: Expression)( + var useCount: Int, + var initialized: Option[String] = None, + var isNull: Option[ExprValue] = None, + var value: Option[ExprValue] = None, + var funcName: Option[String] = None, + var params: Option[Seq[Class[_]]] = None, + var addedFunction: Boolean = false) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 1f080182d17eb..28b2569d9067a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -197,39 +197,28 @@ abstract class Expression extends TreeNode[Expression] { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") val exprKey = ExpressionEquals(this) - val eval = if (EquivalentExpressions.supportedExpression(this)) { - ctx.commonExpressions.get(exprKey) match { - case Some((useCount, genFunc, Some(reuseExprCode))) => - ctx.commonExpressions -= exprKey - if (useCount <= 1) { - ctx.commonExpressions -= exprKey - } else { - ctx.commonExpressions += exprKey -> - (useCount - 1, genFunc, Some(reuseExprCode)) - } - reuseExprCode - case Some((useCount, genFunc, None)) => - val eval = doGenCode(ctx, ExprCode( - JavaCode.isNullVariable(isNull), - JavaCode.variable(value, dataType))) - val reuseExprCode = genFunc(eval) - ctx.commonExpressions -= exprKey - if (useCount <= 1) { - ctx.commonExpressions -= exprKey - } else { - ctx.commonExpressions += exprKey -> - (useCount - 1, genFunc, Some(reuseExprCode)) - } - reuseExprCode - case None => - doGenCode(ctx, ExprCode( - JavaCode.isNullVariable(isNull), - JavaCode.variable(value, dataType))) - } - } else { - doGenCode(ctx, ExprCode( - JavaCode.isNullVariable(isNull), - JavaCode.variable(value, dataType))) + val 
eval = ctx.commonExpressions.get(exprKey) match { + case Some(stats) => + // We should reuse the currentVar references which code is not empty + val nonEmptyRefs = this.exists { + case BoundReference(ordinal, _, _) => + ctx.currentVars != null && ctx.currentVars(ordinal) != null && + ctx.currentVars(ordinal).code != EmptyBlock + case _ => false + } + val eval = doGenCode(ctx, ExprCode( + JavaCode.isNullVariable(isNull), + JavaCode.variable(value, dataType))) + if (eval.code != EmptyBlock && !nonEmptyRefs) { + ctx.genReusedCode(stats, eval) + } else { + eval + } + + case None => + doGenCode(ctx, ExprCode( + JavaCode.isNullVariable(isNull), + JavaCode.variable(value, dataType))) } reduceCodeSize(ctx, eval) if (eval.code.toString.nonEmpty) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 1959d405ae92c..6bb542ca090ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -422,7 +422,7 @@ class CodegenContext extends Logging { * equivalentExpressions will match the tree containing `col1 + col2` and it will only * be evaluated once. */ - private val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions + private[sql] val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions // Foreach expression that is participating in subexpression elimination, the state to use. // Visible for testing. 
@@ -1027,71 +1027,82 @@ class CodegenContext extends Logging { splitExpressions(subexprFunctions.toSeq, "subexprFunc_split", Seq("InternalRow" -> INPUT_ROW)) } - def subexpressionElimination(expressions: Expression*): Block = { + def subexpressionElimination(expressions: Seq[Expression]): Block = { var initBlock: Block = EmptyBlock if (SQLConf.get.subexpressionEliminationEnabled) { - // Create a clear EquivalentExpressions and SubExprEliminationState mapping - val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions - // Add current expression tree and compute the common subexpressions. - expressions.map(equivalentExpressions.addExprTree(_)) - - commonExpressions.clear() - commonExpressions ++= equivalentExpressions.getAllExprStates(1).map { stats => - val expr = stats.expr - val initialized = addMutableState(JAVA_BOOLEAN, "subExprInit") - initBlock += code"$initialized = false;\n" - val wrapperFunc: ExprCode => ExprCode = { eval => - val (inputVars, exprCodes) = { - val (inputVars, exprCodes) = getLocalInputVariableValues(this, expr) - (inputVars.toSeq, exprCodes.toSeq) - } - val isNullLiteral = eval.isNull match { - case TrueLiteral | FalseLiteral => true - case _ => false - } - val (isNull, isNullEvalCode) = if (!isNullLiteral) { - val v = addMutableState(JAVA_BOOLEAN, "subExprIsNull") - (JavaCode.isNullGlobal(v), s"$v = ${eval.isNull};") - } else { - (eval.isNull, "") - } - val value = addMutableState(javaType(expr.dataType), "subExprValue") - val code = if (isValidParamLength(calculateParamLengthFromExprValues(inputVars))) { - // Generate the code for this expression tree and wrap it in a function. 
- val fnName = freshName("subExpr") - val argList = - inputVars.map(v => s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}") - val fn = - s""" - |private void $fnName(${argList.mkString(", ")}) { - | if (!$initialized) { - | ${eval.code} - | $initialized = true; - | $isNullEvalCode - | $value = ${eval.value}; - | } - |} - """.stripMargin - val inputVariables = inputVars.map(_.variableName).mkString(", ") - code"${addNewFunction(fnName, fn)}($inputVariables);" - } else { - code""" - |if (!$initialized) { - | ${eval.code} - | $initialized = true; - | $isNullEvalCode - | $value = ${eval.value}; - |} - """.stripMargin - } - ExprCode(code, isNull, JavaCode.global(value, expr.dataType)) - } - ExpressionEquals(expr) -> (stats.useCount, wrapperFunc, None) - }.toMap + val equivalence = new EquivalentExpressions + wholeStageSubexpressionElimination(expressions, equivalence) + equivalence.getAllExprStates(1).map(initBlock += initCommonExpression(_)) } initBlock } + def wholeStageSubexpressionElimination( + expressions: Seq[Expression], + equivalence: EquivalentExpressions): Unit = { + expressions.map(equivalence.addExprTree(_)) + } + + def initCommonExpression(stats: ExpressionStats): Block = { + if (stats.initialized.isEmpty) { + val expr = stats.expr + stats.initialized = Some(addMutableState(JAVA_BOOLEAN, "subExprInit")) + stats.isNull = Some(JavaCode.isNullGlobal(addMutableState(JAVA_BOOLEAN, "subExprIsNull"))) + stats.value = Some(JavaCode.global(addMutableState(javaType(expr.dataType), "subExprValue"), + expr.dataType)) + stats.funcName = Some(freshName("subExpr")) + commonExpressions += ExpressionEquals(expr) -> stats + code"${stats.initialized.get} = false;\n" + } else { + EmptyBlock + } + } + + def genReusedCode(stats: ExpressionStats, eval: ExprCode): ExprCode = { + val (inputVars, _) = getLocalInputVariableValues(this, stats.expr) + val (initialized, isNull, value) = (stats.initialized.get, stats.isNull.get, stats.value.get) + val validParamLength = 
isValidParamLength(calculateParamLengthFromExprValues(inputVars)) + if(!stats.addedFunction && validParamLength) { + // Generate the code for this expression tree and wrap it in a function. + val argList = + inputVars.map(v => s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}") + val fn = + s""" + |private void ${stats.funcName.get}(${argList.mkString(", ")}) { + | if (!$initialized) { + | ${eval.code} + | $initialized = true; + | $isNull = ${eval.isNull}; + | $value = ${eval.value}; + | } + |} + """.stripMargin + stats.funcName = Some(addNewFunction(stats.funcName.get, fn)) + stats.params = Some(inputVars.map(_.javaType)) + stats.addedFunction = true + } + // input vars changed, e.g. some input vars now are GlobalValue. + if (inputVars.map(_.javaType) != stats.params.get) { + eval + } else { + val code = + if (validParamLength) { + val inputVariables = inputVars.map(_.variableName).mkString(", ") + code"${stats.funcName.get}($inputVariables);" + } else { + code""" + |if (!$initialized) { + | ${eval.code} + | $initialized = true; + | $isNull = ${eval.isNull}; + | $value = ${eval.value}; + |} + """.stripMargin + } + ExprCode(code, isNull, value) + } + } + /** * Generates code for expressions. If doSubexpressionElimination is true, subexpression * elimination will be performed. Subexpression elimination assumes that the code for each @@ -1104,7 +1115,7 @@ class CodegenContext extends Logging { // as well because some expressions may implement `CodegenFallback`. 
val cleanedExpressions = expressions.map(_.freshCopyIfContainsStatefulExpression()) val initBlock = if (doSubexpressionElimination) { - subexpressionElimination(cleanedExpressions: _*) + subexpressionElimination(cleanedExpressions) } else { EmptyBlock } @@ -1148,8 +1159,7 @@ class CodegenContext extends Logging { } } - private[spark] val commonExpressions = - new mutable.HashMap[ExpressionEquals, (Int, ExprCode => ExprCode, Option[ExprCode])] + var commonExpressions = Map[ExpressionEquals, ExpressionStats]() } /** @@ -1686,9 +1696,9 @@ object CodeGenerator extends Logging { ctx: CodegenContext, expr: Expression, subExprs: Map[ExpressionEquals, SubExprEliminationState] = Map.empty) - : (Set[VariableValue], Set[ExprCode]) = { - val argSet = mutable.Set[VariableValue]() - val exprCodesNeedEvaluate = mutable.Set[ExprCode]() + : (Seq[VariableValue], Seq[ExprCode]) = { + val argSet = mutable.LinkedHashSet[VariableValue]() + val exprCodesNeedEvaluate = mutable.LinkedHashSet[ExprCode]() if (ctx.INPUT_ROW != null) { argSet += JavaCode.variable(ctx.INPUT_ROW, classOf[InternalRow]) @@ -1725,7 +1735,7 @@ object CodeGenerator extends Logging { } } - (argSet.toSet, exprCodesNeedEvaluate.toSet) + (argSet.toSeq, exprCodesNeedEvaluate.toSeq) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index d362ac48a1c28..2138dbe89c046 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -169,7 +169,7 @@ case class ExpandExec( // Part 2: switch/case statements initBlock += ctx.subexpressionElimination( - projections.flatten.map(BindReferences.bindReference(_, attributeSeq)): _*) + projections.flatten.map(BindReferences.bindReference(_, attributeSeq))) val switchCaseExprs = projections.zipWithIndex.map { case (exprs, row) => val (exprCodesWithIndices, inputVarSets) = 
exprs.indices.flatMap { col => if (!sameOutput(col)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 1b3fe92adea2c..3b8c31af5cdf3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -45,6 +45,11 @@ import org.apache.spark.util.Utils */ trait CodegenSupport extends SparkPlan { + def reusableExpressions(): (Seq[Expression], AttributeSet) = (Seq(), AttributeSet.empty) + + var initBlock: Block = EmptyBlock + var commonExpressions = Map[ExpressionEquals, ExpressionStats]() + /** Prefix used in the current operator's variable names. */ private def variablePrefix: String = this match { case _: HashAggregateExec => "hashAgg" @@ -176,6 +181,7 @@ trait CodegenSupport extends SparkPlan { ctx.currentVars = inputVars ctx.INPUT_ROW = null ctx.freshNamePrefix = parent.variablePrefix + ctx.commonExpressions = parent.commonExpressions val evaluated = evaluateRequiredVariables(output, inputVars, parent.usedInputs) // Under certain conditions, we can put the logic to consume the rows of this operator into @@ -343,8 +349,6 @@ trait CodegenSupport extends SparkPlan { throw new UnsupportedOperationException } - var initBlock: Block = EmptyBlock - /** * Whether or not the result rows of this operator should be copied before putting into a buffer. 
* @@ -660,6 +664,50 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int) def doCodeGen(): (CodegenContext, CodeAndComment) = { val startTime = System.nanoTime() val ctx = new CodegenContext + + if (SQLConf.get.subexpressionEliminationEnabled) { + val stack = mutable.Stack[SparkPlan](child) + var attributeSet = AttributeSet.empty + val executeSeq = + new mutable.ArrayBuffer[(CodegenSupport, Seq[Expression], EquivalentExpressions)]() + var equivalence = new EquivalentExpressions + while (stack.nonEmpty) { + stack.pop() match { + case _: WholeStageCodegenExec => + case _: InputRDDCodegen => + case c: CodegenSupport => + val (newReusableExpressions, newAttributeSet) = c.reusableExpressions() + // If the input attributes changed, collect current common expressions and clear + // equivalentExpressions + if (!attributeSet.subsetOf(newAttributeSet)) { + equivalence = new EquivalentExpressions + } + if (newReusableExpressions.nonEmpty) { + val bondExpressions = + BindReferences.bindReferences(newReusableExpressions, newAttributeSet.toSeq) + executeSeq += ((c, bondExpressions, equivalence)) + ctx.wholeStageSubexpressionElimination(bondExpressions, equivalence) + } + attributeSet = newAttributeSet + stack.pushAll(c.children) + + case _ => + } + } + executeSeq.reverse.foreach { case (plan, bondExpressions, equivalence) => + val commonExprs = + equivalence.getAllExprStates(1) + .map(stat => ExpressionEquals(stat.expr) -> stat).toMap + plan.commonExpressions = commonExprs + bondExpressions.foreach { + _.foreach { expr => + commonExprs.get(ExpressionEquals(expr)).map { stat => + plan.initBlock += ctx.initCommonExpression(stat) + } + } + } + } + } val code = child.asInstanceOf[CodegenSupport].produce(ctx, this) // main next function. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala index 25ea1b71d279b..65730c410dd7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala @@ -207,7 +207,7 @@ trait AggregateCodegenSupport val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc => bindReferences(updateExprsForOneFunc, inputAttrs) } - val initBlock = ctx.subexpressionElimination(boundUpdateExprs.flatten: _*) + val initBlock = ctx.subexpressionElimination(boundUpdateExprs.flatten) val bufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc => boundUpdateExprsForOneFunc.map(_.genCode(ctx)) } @@ -297,7 +297,7 @@ trait AggregateCodegenSupport aggCodeBlocks: Seq[Block]): Option[Seq[String]] = { val inputVars = aggBufferUpdatingExprs.map { aggExprsForOneFunc => val inputVarsForOneFunc = aggExprsForOneFunc.map( - CodeGenerator.getLocalInputVariableValues(ctx, _)._1).reduce(_ ++ _).toSeq + CodeGenerator.getLocalInputVariableValues(ctx, _)._1.toSet).reduce(_ ++ _).toSeq val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVarsForOneFunc) // Checks if a parameter length for the `aggExprsForOneFunc` does not go over the JVM limit diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index a47239865c6c5..699d5f00cfa90 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -551,6 +551,7 @@ case class HashAggregateExec( def outputFromRowBasedMap: String = { s""" |while ($limitNotReachedCondition 
$iterTermForFastHashMap.next()) { + | ${initBlock} | UnsafeRow $keyTerm = (UnsafeRow) $iterTermForFastHashMap.getKey(); | UnsafeRow $bufferTerm = (UnsafeRow) $iterTermForFastHashMap.getValue(); | $outputFunc($keyTerm, $bufferTerm); @@ -577,6 +578,7 @@ case class HashAggregateExec( s""" |while ($limitNotReachedCondition $iterTermForFastHashMap.hasNext()) { | InternalRow $row = (InternalRow) $iterTermForFastHashMap.next(); + | ${initBlock} | ${generateKeyRow.code} | ${generateBufferRow.code} | $outputFunc(${generateKeyRow.value}, ${generateBufferRow.value}); @@ -591,6 +593,7 @@ case class HashAggregateExec( def outputFromRegularHashMap: String = { s""" |while ($limitNotReachedCondition $iterTerm.next()) { + | ${initBlock} | UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey(); | UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue(); | $outputFunc($keyTerm, $bufferTerm); @@ -728,7 +731,7 @@ case class HashAggregateExec( val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc => bindReferences(updateExprsForOneFunc, inputAttrs) } - val initBlock = ctx.subexpressionElimination(boundUpdateExprs.flatten: _*) + val initBlock = ctx.subexpressionElimination(boundUpdateExprs.flatten) val unsafeRowBufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc => boundUpdateExprsForOneFunc.map(_.genCode(ctx)) } @@ -771,7 +774,7 @@ case class HashAggregateExec( val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc => bindReferences(updateExprsForOneFunc, inputAttrs) } - val initBlock = ctx.subexpressionElimination(boundUpdateExprs.flatten: _*) + val initBlock = ctx.subexpressionElimination(boundUpdateExprs.flatten) val fastRowEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc => boundUpdateExprsForOneFunc.map(_.genCode(ctx)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 2c9b61cebcfe7..dbabade51597e 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -65,9 +65,12 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) references.filter(a => usedMoreThanOnce.contains(a.exprId)) } + override def reusableExpressions(): (Seq[Expression], AttributeSet) = + (projectList, AttributeSet(child.output)) + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { val exprs = bindReferences[Expression](projectList, child.output) - initBlock += ctx.subexpressionElimination(exprs: _*) val resultVars = exprs.map(_.genCode(ctx)) // Evaluation of non-deterministic expressions can't be deferred. @@ -166,8 +169,6 @@ trait GeneratePredicateHelper extends PredicateHelper { // TODO: revisit this. We can consider reordering predicates as well. val generatedIsNotNullChecks = new Array[Boolean](notNullPreds.length) val extraIsNotNullAttrs = mutable.Set[Attribute]() - initBlock += - ctx.subexpressionElimination(otherPreds.map(BindReferences.bindReference(_, inputAttrs)): _*) val generated = otherPreds.map { c => val nullChecks = c.references.map { r => val idx = notNullPreds.indexWhere { n => n.asInstanceOf[IsNotNull].child.semanticEquals(r)} @@ -236,6 +237,9 @@ case class FilterExec(condition: Expression, child: SparkPlan) child.asInstanceOf[CodegenSupport].produce(ctx, this) } + override def reusableExpressions(): (Seq[Expression], AttributeSet) = + (otherPreds, AttributeSet(child.output)) + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { val numOutput = metricTerm(ctx, "numOutputRows") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala index 84c0cd127f45a..3114a5125845b 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala @@ -464,6 +464,7 @@ case class BroadcastNestedLoopJoinExec( s""" |for (int $arrayIndex = 0; $arrayIndex < $buildRowArrayTerm.length; $arrayIndex++) { | UnsafeRow $buildRow = (UnsafeRow) $buildRowArrayTerm[$arrayIndex]; + | ${initBlock} | $checkCondition { | $numOutput.add(1); | ${consume(ctx, resultVars)} @@ -497,6 +498,7 @@ case class BroadcastNestedLoopJoinExec( |boolean $foundMatch = false; |for (int $arrayIndex = 0; $arrayIndex < $buildRowArrayTerm.length; $arrayIndex++) { | UnsafeRow $buildRow = (UnsafeRow) $buildRowArrayTerm[$arrayIndex]; + | ${initBlock} | boolean $shouldOutputRow = false; | $checkCondition { | $shouldOutputRow = true; @@ -548,6 +550,7 @@ case class BroadcastNestedLoopJoinExec( |boolean $foundMatch = false; |for (int $arrayIndex = 0; $arrayIndex < $buildRowArrayTerm.length; $arrayIndex++) { | UnsafeRow $buildRow = (UnsafeRow) $buildRowArrayTerm[$arrayIndex]; + | ${initBlock} | $checkCondition { | $foundMatch = true; | break; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 7c48baf99ef83..f3a1f52c612e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -458,6 +458,7 @@ trait HashJoin extends JoinCodegenSupport { val ev = BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).genCode(ctx) s""" + |$initBlock |boolean $conditionPassed = true; |${eval.trim} |if ($matched != null) { @@ -657,6 +658,7 @@ trait HashJoin extends JoinCodegenSupport { val ev = BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).genCode(ctx) s""" + |$initBlock |$eval |${ev.code} |$existsVar 
= !${ev.isNull} && ${ev.value}; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinCodegenSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinCodegenSupport.scala index 4a11a7b9709dc..29ed51c6247f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinCodegenSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinCodegenSupport.scala @@ -56,12 +56,12 @@ trait JoinCodegenSupport extends CodegenSupport with BaseJoinExec { // filter the output via condition ctx.currentVars = streamVars2 ++ buildVars val bondExpr = BindReferences.bindReference(expr, streamPlan.output ++ buildPlan.output) - initBlock += ctx.subexpressionElimination(bondExpr) + val initBlock = ctx.subexpressionElimination(Seq(bondExpr)) val ev = bondExpr.genCode(ctx) val skipRow = s"${ev.isNull} || !${ev.value}" s""" - |$eval |$initBlock + |$eval |${ev.code} |if (!($skipRow)) """.stripMargin From e5ee05ef3457e4fb22c3e01c42b167436b070623 Mon Sep 17 00:00:00 2001 From: Kun Wan Date: Thu, 13 Apr 2023 18:34:35 +0800 Subject: [PATCH 03/10] Bug fix for whole stage subexpression elimination --- .../expressions/codegen/CodeGenerator.scala | 2 +- .../sql/execution/WholeStageCodegenExec.scala | 17 +++++++++-------- .../sql/execution/basicPhysicalOperators.scala | 8 ++++---- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 6bb542ca090ea..da9a351b41082 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1159,7 +1159,7 @@ class CodegenContext extends Logging { } } - var commonExpressions = Map[ExpressionEquals, ExpressionStats]() + 
var commonExpressions = mutable.Map[ExpressionEquals, ExpressionStats]() } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 3b8c31af5cdf3..fe688d4ebbaae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -45,10 +45,10 @@ import org.apache.spark.util.Utils */ trait CodegenSupport extends SparkPlan { - def reusableExpressions(): (Seq[Expression], AttributeSet) = (Seq(), AttributeSet.empty) + def reusableExpressions(): (Seq[Expression], Seq[Attribute]) = (Seq(), Seq()) var initBlock: Block = EmptyBlock - var commonExpressions = Map[ExpressionEquals, ExpressionStats]() + var commonExpressions = mutable.Map.empty[ExpressionEquals, ExpressionStats] /** Prefix used in the current operator's variable names. */ private def variablePrefix: String = this match { @@ -667,7 +667,7 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int) if (SQLConf.get.subexpressionEliminationEnabled) { val stack = mutable.Stack[SparkPlan](child) - var attributeSet = AttributeSet.empty + var attributeSeq = Seq[Attribute]() val executeSeq = new mutable.ArrayBuffer[(CodegenSupport, Seq[Expression], EquivalentExpressions)]() var equivalence = new EquivalentExpressions @@ -676,19 +676,20 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int) case _: WholeStageCodegenExec => case _: InputRDDCodegen => case c: CodegenSupport => - val (newReusableExpressions, newAttributeSet) = c.reusableExpressions() + val (newReusableExpressions, newAttributeSeq) = c.reusableExpressions() // If the input attributes changed, collect current common expressions and clear // equivalentExpressions - if (!attributeSet.subsetOf(newAttributeSet)) { + if (attributeSeq.size != newAttributeSeq.size || + 
attributeSeq.zip(newAttributeSeq).exists(tup => !tup._1.equals(tup._2))) { equivalence = new EquivalentExpressions } if (newReusableExpressions.nonEmpty) { val bondExpressions = - BindReferences.bindReferences(newReusableExpressions, newAttributeSet.toSeq) + BindReferences.bindReferences(newReusableExpressions, newAttributeSeq.toSeq) executeSeq += ((c, bondExpressions, equivalence)) ctx.wholeStageSubexpressionElimination(bondExpressions, equivalence) } - attributeSet = newAttributeSet + attributeSeq = newAttributeSeq stack.pushAll(c.children) case _ => @@ -698,11 +699,11 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int) val commonExprs = equivalence.getAllExprStates(1) .map(stat => ExpressionEquals(stat.expr) -> stat).toMap - plan.commonExpressions = commonExprs bondExpressions.foreach { _.foreach { expr => commonExprs.get(ExpressionEquals(expr)).map { stat => plan.initBlock += ctx.initCommonExpression(stat) + plan.commonExpressions += ExpressionEquals(expr) -> stat } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index dbabade51597e..7f8abbea23bcd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -65,8 +65,8 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) references.filter(a => usedMoreThanOnce.contains(a.exprId)) } - override def reusableExpressions(): (Seq[Expression], AttributeSet) = - (projectList, AttributeSet(child.output)) + override def reusableExpressions(): (Seq[Expression], Seq[Attribute]) = + (projectList, child.output) override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { @@ -237,8 +237,8 @@ case class FilterExec(condition: Expression, child: SparkPlan) 
child.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def reusableExpressions(): (Seq[Expression], AttributeSet) = - (otherPreds, AttributeSet(child.output)) + override def reusableExpressions(): (Seq[Expression], Seq[Attribute]) = + (otherPreds, child.output) override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { val numOutput = metricTerm(ctx, "numOutputRows") From f152b2aaa64cd2f0803cbd8f27a2190414314678 Mon Sep 17 00:00:00 2001 From: Kun Wan Date: Thu, 13 Apr 2023 22:45:23 +0800 Subject: [PATCH 04/10] Bug fix --- .../sql/catalyst/expressions/Expression.scala | 46 +++++++++++-------- .../sql/execution/WholeStageCodegenExec.scala | 4 +- 2 files changed, 28 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 28b2569d9067a..884c9ec3c255c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -196,30 +196,36 @@ abstract class Expression extends TreeNode[Expression] { }.getOrElse { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") - val exprKey = ExpressionEquals(this) - val eval = ctx.commonExpressions.get(exprKey) match { - case Some(stats) => - // We should reuse the currentVar references which code is not empty - val nonEmptyRefs = this.exists { - case BoundReference(ordinal, _, _) => - ctx.currentVars != null && ctx.currentVars(ordinal) != null && - ctx.currentVars(ordinal).code != EmptyBlock - case _ => false - } - val eval = doGenCode(ctx, ExprCode( - JavaCode.isNullVariable(isNull), - JavaCode.variable(value, dataType))) - if (eval.code != EmptyBlock && !nonEmptyRefs) { - ctx.genReusedCode(stats, eval) - } else { - eval - } + val eval = + if 
(EquivalentExpressions.supportedExpression(this)) { + ctx.commonExpressions.get(ExpressionEquals(this)) match { + case Some(stats) => + // We should reuse the currentVar references which code is not empty + val nonEmptyRefs = this.exists { + case BoundReference(ordinal, _, _) => + ctx.currentVars != null && ctx.currentVars(ordinal) != null && + ctx.currentVars(ordinal).code != EmptyBlock + case _ => false + } + val eval = doGenCode(ctx, ExprCode( + JavaCode.isNullVariable(isNull), + JavaCode.variable(value, dataType))) + if (eval.code != EmptyBlock && !nonEmptyRefs) { + ctx.genReusedCode(stats, eval) + } else { + eval + } - case None => + case None => + doGenCode(ctx, ExprCode( + JavaCode.isNullVariable(isNull), + JavaCode.variable(value, dataType))) + } + } else { doGenCode(ctx, ExprCode( JavaCode.isNullVariable(isNull), JavaCode.variable(value, dataType))) - } + } reduceCodeSize(ctx, eval) if (eval.code.toString.nonEmpty) { // Add `this` in the comment. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index fe688d4ebbaae..cf1b490255f10 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -680,12 +680,12 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int) // If the input attributes changed, collect current common expressions and clear // equivalentExpressions if (attributeSeq.size != newAttributeSeq.size || - attributeSeq.zip(newAttributeSeq).exists(tup => !tup._1.equals(tup._2))) { + attributeSeq.zip(newAttributeSeq).exists { case (left, right) => left != right }) { equivalence = new EquivalentExpressions } if (newReusableExpressions.nonEmpty) { val bondExpressions = - BindReferences.bindReferences(newReusableExpressions, newAttributeSeq.toSeq) + 
BindReferences.bindReferences(newReusableExpressions, newAttributeSeq) executeSeq += ((c, bondExpressions, equivalence)) ctx.wholeStageSubexpressionElimination(bondExpressions, equivalence) } From 9a7cdc6e673cb7fe69255dc6abafe2302ad48ba2 Mon Sep 17 00:00:00 2001 From: Kun Wan Date: Fri, 14 Apr 2023 17:51:19 +0800 Subject: [PATCH 05/10] Do not support CSE in produce method --- .../org/apache/spark/sql/execution/WholeStageCodegenExec.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index cf1b490255f10..e10950b4d978b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -708,6 +708,8 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int) } } } + // Do not support CSE in produce method. 
+ ctx.commonExpressions.clear() } val code = child.asInstanceOf[CodegenSupport].produce(ctx, this) From 076792feef18842cbaf7fa3acca4d99d8ad19861 Mon Sep 17 00:00:00 2001 From: Kun Wan Date: Fri, 14 Apr 2023 19:01:13 +0800 Subject: [PATCH 06/10] Clear stale commonExpressions when reoptimize plan --- .../sql/catalyst/expressions/Expression.scala | 14 ++++++-------- .../sql/execution/WholeStageCodegenExec.scala | 11 +++++++---- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 884c9ec3c255c..2e410cef6fe8d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -194,8 +194,6 @@ abstract class Expression extends TreeNode[Expression] { subExprState.eval.isNull, subExprState.eval.value) }.getOrElse { - val isNull = ctx.freshName("isNull") - val value = ctx.freshName("value") val eval = if (EquivalentExpressions.supportedExpression(this)) { ctx.commonExpressions.get(ExpressionEquals(this)) match { @@ -208,8 +206,8 @@ abstract class Expression extends TreeNode[Expression] { case _ => false } val eval = doGenCode(ctx, ExprCode( - JavaCode.isNullVariable(isNull), - JavaCode.variable(value, dataType))) + JavaCode.isNullVariable(ctx.freshName("isNull")), + JavaCode.variable(ctx.freshName("value"), dataType))) if (eval.code != EmptyBlock && !nonEmptyRefs) { ctx.genReusedCode(stats, eval) } else { @@ -218,13 +216,13 @@ abstract class Expression extends TreeNode[Expression] { case None => doGenCode(ctx, ExprCode( - JavaCode.isNullVariable(isNull), - JavaCode.variable(value, dataType))) + JavaCode.isNullVariable(ctx.freshName("isNull")), + JavaCode.variable(ctx.freshName("value"), dataType))) } } else { doGenCode(ctx, ExprCode( - 
JavaCode.isNullVariable(isNull), - JavaCode.variable(value, dataType))) + JavaCode.isNullVariable(ctx.freshName("isNull")), + JavaCode.variable(ctx.freshName("value"), dataType))) } reduceCodeSize(ctx, eval) if (eval.code.toString.nonEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index e10950b4d978b..afaa43ce9088f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -675,8 +675,11 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int) stack.pop() match { case _: WholeStageCodegenExec => case _: InputRDDCodegen => - case c: CodegenSupport => - val (newReusableExpressions, newAttributeSeq) = c.reusableExpressions() + case plan: CodegenSupport => + // Because this plan may already be optimized before, so remove stale commonExpressions + plan.initBlock = EmptyBlock + plan.commonExpressions.clear() + val (newReusableExpressions, newAttributeSeq) = plan.reusableExpressions() // If the input attributes changed, collect current common expressions and clear // equivalentExpressions if (attributeSeq.size != newAttributeSeq.size || @@ -686,11 +689,11 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int) if (newReusableExpressions.nonEmpty) { val bondExpressions = BindReferences.bindReferences(newReusableExpressions, newAttributeSeq) - executeSeq += ((c, bondExpressions, equivalence)) + executeSeq += ((plan, bondExpressions, equivalence)) ctx.wholeStageSubexpressionElimination(bondExpressions, equivalence) } attributeSeq = newAttributeSeq - stack.pushAll(c.children) + stack.pushAll(plan.children) case _ => } From f524ef51b5b118796225a120f81d7fe579f23f28 Mon Sep 17 00:00:00 2001 From: Kun Wan Date: Tue, 9 May 2023 19:59:20 +0800 Subject: [PATCH 07/10] 
Remove whole-stage expression elimination --- .../expressions/codegen/CodeGenerator.scala | 13 +++-- .../sql/execution/WholeStageCodegenExec.scala | 52 ------------------- .../execution/basicPhysicalOperators.scala | 10 ++-- 3 files changed, 9 insertions(+), 66 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index da9a351b41082..0afd788c2c2c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1027,22 +1027,21 @@ class CodegenContext extends Logging { splitExpressions(subexprFunctions.toSeq, "subexprFunc_split", Seq("InternalRow" -> INPUT_ROW)) } + /** + * Collect all commons expressions and return the initialization code block. + * @param expressions + * @return + */ def subexpressionElimination(expressions: Seq[Expression]): Block = { var initBlock: Block = EmptyBlock if (SQLConf.get.subexpressionEliminationEnabled) { val equivalence = new EquivalentExpressions - wholeStageSubexpressionElimination(expressions, equivalence) + expressions.map(equivalence.addExprTree(_)) equivalence.getAllExprStates(1).map(initBlock += initCommonExpression(_)) } initBlock } - def wholeStageSubexpressionElimination( - expressions: Seq[Expression], - equivalence: EquivalentExpressions): Unit = { - expressions.map(equivalence.addExprTree(_)) - } - def initCommonExpression(stats: ExpressionStats): Block = { if (stats.initialized.isEmpty) { val expr = stats.expr diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index afaa43ce9088f..556299d46cf28 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -45,8 +45,6 @@ import org.apache.spark.util.Utils */ trait CodegenSupport extends SparkPlan { - def reusableExpressions(): (Seq[Expression], Seq[Attribute]) = (Seq(), Seq()) - var initBlock: Block = EmptyBlock var commonExpressions = mutable.Map.empty[ExpressionEquals, ExpressionStats] @@ -664,56 +662,6 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int) def doCodeGen(): (CodegenContext, CodeAndComment) = { val startTime = System.nanoTime() val ctx = new CodegenContext - - if (SQLConf.get.subexpressionEliminationEnabled) { - val stack = mutable.Stack[SparkPlan](child) - var attributeSeq = Seq[Attribute]() - val executeSeq = - new mutable.ArrayBuffer[(CodegenSupport, Seq[Expression], EquivalentExpressions)]() - var equivalence = new EquivalentExpressions - while (stack.nonEmpty) { - stack.pop() match { - case _: WholeStageCodegenExec => - case _: InputRDDCodegen => - case plan: CodegenSupport => - // Because this plan may already be optimized before, so remove stale commonExpressions - plan.initBlock = EmptyBlock - plan.commonExpressions.clear() - val (newReusableExpressions, newAttributeSeq) = plan.reusableExpressions() - // If the input attributes changed, collect current common expressions and clear - // equivalentExpressions - if (attributeSeq.size != newAttributeSeq.size || - attributeSeq.zip(newAttributeSeq).exists { case (left, right) => left != right }) { - equivalence = new EquivalentExpressions - } - if (newReusableExpressions.nonEmpty) { - val bondExpressions = - BindReferences.bindReferences(newReusableExpressions, newAttributeSeq) - executeSeq += ((plan, bondExpressions, equivalence)) - ctx.wholeStageSubexpressionElimination(bondExpressions, equivalence) - } - attributeSeq = newAttributeSeq - stack.pushAll(plan.children) - - case _ => - } - } - 
executeSeq.reverse.foreach { case (plan, bondExpressions, equivalence) => - val commonExprs = - equivalence.getAllExprStates(1) - .map(stat => ExpressionEquals(stat.expr) -> stat).toMap - bondExpressions.foreach { - _.foreach { expr => - commonExprs.get(ExpressionEquals(expr)).map { stat => - plan.initBlock += ctx.initCommonExpression(stat) - plan.commonExpressions += ExpressionEquals(expr) -> stat - } - } - } - } - // Do not support CSE in produce method. - ctx.commonExpressions.clear() - } val code = child.asInstanceOf[CodegenSupport].produce(ctx, this) // main next function. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 7f8abbea23bcd..2c9b61cebcfe7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -65,12 +65,9 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) references.filter(a => usedMoreThanOnce.contains(a.exprId)) } - override def reusableExpressions(): (Seq[Expression], Seq[Attribute]) = - (projectList, child.output) - - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { val exprs = bindReferences[Expression](projectList, child.output) + initBlock += ctx.subexpressionElimination(exprs: _*) val resultVars = exprs.map(_.genCode(ctx)) // Evaluation of non-deterministic expressions can't be deferred. @@ -169,6 +166,8 @@ trait GeneratePredicateHelper extends PredicateHelper { // TODO: revisit this. We can consider reordering predicates as well. 
val generatedIsNotNullChecks = new Array[Boolean](notNullPreds.length) val extraIsNotNullAttrs = mutable.Set[Attribute]() + initBlock += + ctx.subexpressionElimination(otherPreds.map(BindReferences.bindReference(_, inputAttrs)): _*) val generated = otherPreds.map { c => val nullChecks = c.references.map { r => val idx = notNullPreds.indexWhere { n => n.asInstanceOf[IsNotNull].child.semanticEquals(r)} @@ -237,9 +236,6 @@ case class FilterExec(condition: Expression, child: SparkPlan) child.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def reusableExpressions(): (Seq[Expression], Seq[Attribute]) = - (otherPreds, child.output) - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { val numOutput = metricTerm(ctx, "numOutputRows") From 670446a40ef99980975571cbf4e472528ed8a791 Mon Sep 17 00:00:00 2001 From: Kun Wan Date: Wed, 10 May 2023 18:54:07 +0800 Subject: [PATCH 08/10] Merge code into origin CSE --- .../expressions/EquivalentExpressions.scala | 105 ++++++- .../expressions/codegen/CodeGenerator.scala | 266 +++++++++++++++++- .../expressions/CodeGenerationSuite.scala | 61 ++++ .../SubexpressionEliminationSuite.scala | 140 +++++++-- .../spark/sql/execution/ExpandExec.scala | 2 +- .../aggregate/AggregateCodegenSupport.scala | 109 +++---- .../aggregate/HashAggregateExec.scala | 22 +- .../execution/basicPhysicalOperators.scala | 21 +- .../execution/joins/JoinCodegenSupport.scala | 2 +- .../execution/WholeStageCodegenSuite.scala | 22 ++ 10 files changed, 663 insertions(+), 87 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index d8f7368038a12..38a61196d0dc0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -22,8 +22,10 @@ import java.util.Objects import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions.EquivalentExpressions.supportedExpression +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.codegen.ExprValue import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils /** @@ -31,7 +33,9 @@ import org.apache.spark.util.Utils * to this class and they subsequently query for expression equality. Expression trees are * considered equal if for the same input(s), the same result is produced. */ -class EquivalentExpressions { +class EquivalentExpressions( + skipForShortcutEnable: Boolean = SQLConf.get.subexpressionEliminationSkipForShotcutExpr) { + // For each expression, the set of equivalent expressions. private val equivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats] @@ -89,6 +93,78 @@ class EquivalentExpressions { } } + /** + * Adds or removes only expressions which are common in each of given expressions, in a recursive + * way. + * For example, given two expressions `(a + (b + (c + 1)))` and `(d + (e + (c + 1)))`, the common + * expression `(c + 1)` will be added into `equivalenceMap`. + * + * Note that as we don't know in advance if any child node of an expression will be common across + * all given expressions, we compute local equivalence maps for all given expressions and filter + * only the common nodes. + * Those common nodes are then removed from the local map and added to the final map of + * expressions. 
+ */ + private def updateCommonExprs( + exprs: Seq[Expression], + map: mutable.HashMap[ExpressionEquals, ExpressionStats], + useCount: Int): Unit = { + assert(exprs.length > 1) + var localEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats] + updateExprTree(exprs.head, localEquivalenceMap) + + exprs.tail.foreach { expr => + val otherLocalEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats] + updateExprTree(expr, otherLocalEquivalenceMap) + localEquivalenceMap = localEquivalenceMap.filter { case (key, _) => + otherLocalEquivalenceMap.contains(key) + } + } + + // Start with the highest expression, remove it from `localEquivalenceMap` and add it to `map`. + // The remaining highest expression in `localEquivalenceMap` is also common expression so loop + // until `localEquivalenceMap` is not empty. + var statsOption = Some(localEquivalenceMap).filter(_.nonEmpty).map(_.maxBy(_._1.height)._2) + while (statsOption.nonEmpty) { + val stats = statsOption.get + updateExprTree(stats.expr, localEquivalenceMap, -stats.useCount) + updateExprTree(stats.expr, map, useCount) + + statsOption = Some(localEquivalenceMap).filter(_.nonEmpty).map(_.maxBy(_._1.height)._2) + } + } + + private def skipForShortcut(expr: Expression): Expression = { + if (skipForShortcutEnable) { + // The subexpression may not need to eval even if it appears more than once. + // e.g., `if(or(a, and(b, b)))`, the expression `b` would be skipped if `a` is true. + expr match { + case and: And => and.left + case or: Or => or.left + case other => other + } + } else { + expr + } + } + + // There are some special expressions that we should not recurse into all of its children. + // 1. CodegenFallback: it's children will not be used to generate code (call eval() instead) + // 2. ConditionalExpression: use its children that will always be evaluated. 
+ private def childrenToRecurse(expr: Expression): Seq[Expression] = expr match { + case _: CodegenFallback => Nil + case c: ConditionalExpression => c.alwaysEvaluatedInputs.map(skipForShortcut) + case other => skipForShortcut(other).children + } + + // For some special expressions we cannot just recurse into all of its children, but we can + // recursively add the common expressions shared between all of its children. + private def commonChildrenToRecurse(expr: Expression): Seq[Seq[Expression]] = expr match { + case _: CodegenFallback => Nil + case c: ConditionalExpression => c.branchGroups + case _ => Nil + } + /** * Adds the expression to this data structure recursively. Stops if a matching expression * is found. That is, if `expr` has already been added, its children are not added. @@ -109,7 +185,32 @@ class EquivalentExpressions { if (!skip && !updateExprInMap(expr, map, useCount)) { val uc = useCount.signum - expr.children.foreach(updateExprTree(_, map, uc)) + childrenToRecurse(expr).foreach(updateExprTree(_, map, uc)) + commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(updateCommonExprs(_, map, uc)) + } + } + + /** + * Adds the expression to this data structure recursively. Stops if a matching expression + * is found. That is, if `expr` has already been added, its children are not added. 
+ */ + def addConditionalExprTree( + expr: Expression, + map: mutable.HashMap[ExpressionEquals, ExpressionStats] = equivalenceMap): Unit = { + if (supportedExpression(expr)) { + updateConditionalExprTree(expr, map) + } + } + + private def updateConditionalExprTree( + expr: Expression, + map: mutable.HashMap[ExpressionEquals, ExpressionStats] = equivalenceMap, + useCount: Int = 1): Unit = { + val skip = useCount == 0 || expr.isInstanceOf[LeafExpression] + + if (!skip && !updateExprInMap(expr, map, useCount)) { + val uc = useCount.signum + expr.children.foreach(updateConditionalExprTree(_, map, uc)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 0afd788c2c2c4..7945fcf51005e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -422,7 +422,7 @@ class CodegenContext extends Logging { * equivalentExpressions will match the tree containing `col1 + col2` and it will only * be evaluated once. */ - private[sql] val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions + private val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions // Foreach expression that is participating in subexpression elimination, the state to use. // Visible for testing. @@ -1028,16 +1028,265 @@ class CodegenContext extends Logging { } /** - * Collect all commons expressions and return the initialization code block. + * Perform a function which generates a sequence of ExprCodes with a given mapping between + * expressions and common expressions, instead of using the mapping in current context. 
+ */ + def withSubExprEliminationExprs( + newSubExprEliminationExprs: Map[ExpressionEquals, SubExprEliminationState])( + f: => Seq[ExprCode]): Seq[ExprCode] = { + val oldsubExprEliminationExprs = subExprEliminationExprs + subExprEliminationExprs = newSubExprEliminationExprs + + val genCodes = f + + // Restore previous subExprEliminationExprs + subExprEliminationExprs = oldsubExprEliminationExprs + genCodes + } + + /** + * Evaluates a sequence of `SubExprEliminationState` which represent subexpressions. After + * evaluating a subexpression, this method will clean up the code block to avoid duplicate + * evaluation. + */ + def evaluateSubExprEliminationState(subExprStates: Iterable[SubExprEliminationState]): String = { + val code = new StringBuilder() + + subExprStates.foreach { state => + val currentCode = evaluateSubExprEliminationState(state.children) + "\n" + state.eval.code + code.append(currentCode + "\n") + state.eval.code = EmptyBlock + } + + code.toString() + } + + /** + * Checks and sets up the state and codegen for subexpression elimination in whole-stage codegen. + * + * This finds the common subexpressions, generates the code snippets that evaluate those + * expressions and populates the mapping of common subexpressions to the generated code snippets. + * + * The generated code snippet for subexpression is wrapped in `SubExprEliminationState`, which + * contains an `ExprCode` and the children `SubExprEliminationState` if any. The `ExprCode` + * includes java source code, result variable name and is-null variable name of the subexpression. + * + * Besides, this also returns a sequences of `ExprCode` which are expression codes that need to + * be evaluated (as their input parameters) before evaluating subexpressions. + * + * To evaluate the returned subexpressions, please call `evaluateSubExprEliminationState` with + * the `SubExprEliminationState`s to be evaluated. During generating the code, it will cleanup + * the states to avoid duplicate evaluation. 
+ * + * The details of subexpression generation: + * 1. Gets subexpression set. See `EquivalentExpressions`. + * 2. Generate code of subexpressions as a whole block of code (non-split case) + * 3. Check if the total length of the above block is larger than the split-threshold. If so, + * try to split it in step 4, otherwise returning the non-split code block. + * 4. Check if parameter lengths of all subexpressions satisfy the JVM limitation, if so, + * try to split, otherwise returning the non-split code block. + * 5. For each subexpression, generating a function and put the code into it. To evaluate the + * subexpression, just call the function. + * + * The explanation of subexpression codegen: + * 1. Wrapping in `withSubExprEliminationExprs` call with current subexpression map. Each + * subexpression may depends on other subexpressions (children). So when generating code + * for subexpressions, we iterate over each subexpression and put the mapping between + * (subexpression -> `SubExprEliminationState`) into the map. So in next subexpression + * evaluation, we can look for generated subexpressions and do replacement. + */ + def subexpressionEliminationForWholeStageCodegen(expressions: Seq[Expression]): SubExprCodes = { + // Create a clear EquivalentExpressions and SubExprEliminationState mapping + val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions + val localSubExprEliminationExprsForNonSplit = + mutable.HashMap.empty[ExpressionEquals, SubExprEliminationState] + + // Add each expression tree and compute the common subexpressions. + expressions.foreach(equivalentExpressions.addExprTree(_)) + + // Get all the expressions that appear at least twice and set up the state for subexpression + // elimination. 
+ val commonExprs = equivalentExpressions.getCommonSubexpressions + + val nonSplitCode = { + val allStates = mutable.ArrayBuffer.empty[SubExprEliminationState] + commonExprs.map { expr => + withSubExprEliminationExprs(localSubExprEliminationExprsForNonSplit.toMap) { + val eval = expr.genCode(this) + // Collects other subexpressions from the children. + val childrenSubExprs = mutable.ArrayBuffer.empty[SubExprEliminationState] + expr.foreach { e => + subExprEliminationExprs.get(ExpressionEquals(e)) match { + case Some(state) => childrenSubExprs += state + case _ => + } + } + val state = SubExprEliminationState(eval, childrenSubExprs.toSeq) + localSubExprEliminationExprsForNonSplit.put(ExpressionEquals(expr), state) + allStates += state + Seq(eval) + } + } + allStates.toSeq + } + + // For some operators, they do not require all its child's outputs to be evaluated in advance. + // Instead it only early evaluates part of outputs, for example, `ProjectExec` only early + // evaluate the outputs used more than twice. So we need to extract these variables used by + // subexpressions and evaluate them before subexpressions. 
+ val (inputVarsForAllFuncs, exprCodesNeedEvaluate) = commonExprs.map { expr => + val (inputVars, exprCodes) = getLocalInputVariableValues(this, expr) + (inputVars.toSeq, exprCodes.toSeq) + }.unzip + + val needSplit = nonSplitCode.map(_.eval.code.length).sum > SQLConf.get.methodSplitThreshold + val (subExprsMap, exprCodes) = if (needSplit) { + if (inputVarsForAllFuncs.map(calculateParamLengthFromExprValues).forall(isValidParamLength)) { + val localSubExprEliminationExprs = + mutable.HashMap.empty[ExpressionEquals, SubExprEliminationState] + + commonExprs.zipWithIndex.foreach { case (expr, i) => + val eval = withSubExprEliminationExprs(localSubExprEliminationExprs.toMap) { + Seq(expr.genCode(this)) + }.head + + val value = addMutableState(javaType(expr.dataType), "subExprValue") + + val isNullLiteral = eval.isNull match { + case TrueLiteral | FalseLiteral => true + case _ => false + } + val (isNull, isNullEvalCode) = if (!isNullLiteral) { + val v = addMutableState(JAVA_BOOLEAN, "subExprIsNull") + (JavaCode.isNullGlobal(v), s"$v = ${eval.isNull};") + } else { + (eval.isNull, "") + } + + // Generate the code for this expression tree and wrap it in a function. + val fnName = freshName("subExpr") + val inputVars = inputVarsForAllFuncs(i) + val argList = + inputVars.map(v => s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}") + val fn = + s""" + |private void $fnName(${argList.mkString(", ")}) { + | ${eval.code} + | $isNullEvalCode + | $value = ${eval.value}; + |} + """.stripMargin + + // Collects other subexpressions from the children. 
+ val childrenSubExprs = mutable.ArrayBuffer.empty[SubExprEliminationState] + expr.foreach { e => + localSubExprEliminationExprs.get(ExpressionEquals(e)) match { + case Some(state) => childrenSubExprs += state + case _ => + } + } + + val inputVariables = inputVars.map(_.variableName).mkString(", ") + val code = code"${addNewFunction(fnName, fn)}($inputVariables);" + val state = SubExprEliminationState( + ExprCode(code, isNull, JavaCode.global(value, expr.dataType)), + childrenSubExprs.toSeq) + localSubExprEliminationExprs.put(ExpressionEquals(expr), state) + } + (localSubExprEliminationExprs, exprCodesNeedEvaluate) + } else { + val errMsg = "Failed to split subexpression code into small functions because " + + "the parameter length of at least one split function went over the JVM limit: " + + MAX_JVM_METHOD_PARAMS_LENGTH + if (Utils.isTesting) { + throw new IllegalStateException(errMsg) + } else { + logInfo(errMsg) + (localSubExprEliminationExprsForNonSplit, Seq.empty) + } + } + } else { + (localSubExprEliminationExprsForNonSplit, Seq.empty) + } + SubExprCodes(subExprsMap.toMap, exprCodes.flatten) + } + + /** + * Checks and sets up the state and codegen for subexpression elimination. This finds the + * common subexpressions, generates the functions that evaluate those expressions and populates + * the mapping of common subexpressions to the generated functions. + */ + private def subexpressionElimination(expressions: Seq[Expression]): Unit = { + // Add each expression tree and compute the common subexpressions. + expressions.foreach(equivalentExpressions.addExprTree(_)) + + // Get all the expressions that appear at least twice and set up the state for subexpression + // elimination. 
+ val commonExprs = equivalentExpressions.getCommonSubexpressions + commonExprs.foreach { expr => + val fnName = freshName("subExpr") + val isNull = addMutableState(JAVA_BOOLEAN, "subExprIsNull") + val value = addMutableState(javaType(expr.dataType), "subExprValue") + + // Generate the code for this expression tree and wrap it in a function. + val eval = expr.genCode(this) + val fn = + s""" + |private void $fnName(InternalRow $INPUT_ROW) { + | ${eval.code} + | $isNull = ${eval.isNull}; + | $value = ${eval.value}; + |} + """.stripMargin + + // Add a state and a mapping of the common subexpressions that are associate with this + // state. Adding this expression to subExprEliminationExprMap means it will call `fn` + // when it is code generated. This decision should be a cost based one. + // + // The cost of doing subexpression elimination is: + // 1. Extra function call, although this is probably *good* as the JIT can decide to + // inline or not. + // The benefit doing subexpression elimination is: + // 1. Running the expression logic. Even for a simple expression, it is likely more than 3 + // above. + // 2. Less code. + // Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with + // at least two nodes) as the cost of doing it is expected to be low. + + val subExprCode = s"${addNewFunction(fnName, fn)}($INPUT_ROW);" + subexprFunctions += subExprCode + val state = SubExprEliminationState( + ExprCode(code"$subExprCode", + JavaCode.isNullGlobal(isNull), + JavaCode.global(value, expr.dataType))) + subExprEliminationExprs += ExpressionEquals(expr) -> state + } + } + + /** + * If includeDefiniteExpression is true, collect all commons expressions whether or not the + * expressions will definite be executed and return the initialization code block. + * If includeDefiniteExpression is false, we will exclude the common expressions which will + * definite be executed. 
* @param expressions * @return */ - def subexpressionElimination(expressions: Seq[Expression]): Block = { + def conditionalSubexpressionElimination( + expressions: Seq[Expression], + includeDefiniteExpression: Boolean = true): Block = { var initBlock: Block = EmptyBlock - if (SQLConf.get.subexpressionEliminationEnabled) { - val equivalence = new EquivalentExpressions - expressions.map(equivalence.addExprTree(_)) - equivalence.getAllExprStates(1).map(initBlock += initCommonExpression(_)) + if (!SQLConf.get.subexpressionEliminationEnabled) return initBlock + + val equivalence = new EquivalentExpressions + expressions.map(equivalence.addConditionalExprTree(_)) + val commonExpressions = equivalence.getAllExprStates(1) + if (includeDefiniteExpression) { + commonExpressions.map(initBlock += initCommonExpression(_)) + } else { + val definiteEquivalence = new EquivalentExpressions + expressions.foreach(definiteEquivalence.addExprTree(_)) + (commonExpressions diff definiteEquivalence.getAllExprStates(1)) + .map(initBlock += initCommonExpression(_)) } initBlock } @@ -1115,6 +1364,7 @@ class CodegenContext extends Logging { val cleanedExpressions = expressions.map(_.freshCopyIfContainsStatefulExpression()) val initBlock = if (doSubexpressionElimination) { subexpressionElimination(cleanedExpressions) + conditionalSubexpressionElimination(cleanedExpressions, false) } else { EmptyBlock } @@ -1688,7 +1938,7 @@ object CodeGenerator extends Logging { * elimination states for a given `expr`. This result will be used to split the * generated code of expressions into multiple functions. * - * Second value: Returns the set of `ExprCodes`s which are necessary codes before + * Second value: Returns the seq of `ExprCodes`s which are necessary codes before * evaluating subexpressions. 
*/ def getLocalInputVariableValues( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 7e73becbe0dfe..265b0eeb8bdf8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -463,6 +463,55 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { private def wrap(expr: Expression): ExpressionEquals = ExpressionEquals(expr) + test("SPARK-23760: CodegenContext.withSubExprEliminationExprs should save/restore correctly") { + + val ref = BoundReference(0, IntegerType, true) + val add1 = Add(ref, ref) + val add2 = Add(add1, add1) + val dummy = SubExprEliminationState( + ExprCode(EmptyBlock, + JavaCode.variable("dummy", BooleanType), + JavaCode.variable("dummy", BooleanType))) + + // raw testing of basic functionality + { + val ctx = new CodegenContext + val e = ref.genCode(ctx) + // before + ctx.subExprEliminationExprs += wrap(ref) -> SubExprEliminationState( + ExprCode(EmptyBlock, e.isNull, e.value)) + assert(ctx.subExprEliminationExprs.contains(wrap(ref))) + // call withSubExprEliminationExprs + ctx.withSubExprEliminationExprs(Map(wrap(add1) -> dummy)) { + assert(ctx.subExprEliminationExprs.contains(wrap(add1))) + assert(!ctx.subExprEliminationExprs.contains(wrap(ref))) + Seq.empty + } + // after + assert(ctx.subExprEliminationExprs.nonEmpty) + assert(ctx.subExprEliminationExprs.contains(wrap(ref))) + assert(!ctx.subExprEliminationExprs.contains(wrap(add1))) + } + + // emulate an actual codegen workload + { + val ctx = new CodegenContext + // before + ctx.generateExpressions(Seq(add2, add1), doSubexpressionElimination = true) // trigger CSE + assert(ctx.subExprEliminationExprs.contains(wrap(add1))) + // call withSubExprEliminationExprs 
+ ctx.withSubExprEliminationExprs(Map(wrap(ref) -> dummy)) { + assert(ctx.subExprEliminationExprs.contains(wrap(ref))) + assert(!ctx.subExprEliminationExprs.contains(wrap(add1))) + Seq.empty + } + // after + assert(ctx.subExprEliminationExprs.nonEmpty) + assert(ctx.subExprEliminationExprs.contains(wrap(add1))) + assert(!ctx.subExprEliminationExprs.contains(wrap(ref))) + } + } + test("SPARK-23986: freshName can generate duplicated names") { val ctx = new CodegenContext val names1 = ctx.freshName("myName1") :: ctx.freshName("myName1") :: @@ -487,6 +536,18 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { .exists(_.getMessage().getFormattedMessage.contains("Generated method too long"))) } + test("SPARK-28916: subexpression elimination can cause 64kb code limit on UnsafeProjection") { + val numOfExprs = 10000 + val exprs = (0 to numOfExprs).flatMap(colIndex => + Seq(Add(BoundReference(colIndex, DoubleType, true), + BoundReference(numOfExprs + colIndex, DoubleType, true)), + Add(BoundReference(colIndex, DoubleType, true), + BoundReference(numOfExprs + colIndex, DoubleType, true)))) + // these should not fail to compile due to 64K limit + GenerateUnsafeProjection.generate(exprs, true) + GenerateMutableProjection.generate(exprs, true) + } + test("SPARK-32624: Use CodeGenerator.typeName() to fix byte[] compile issue") { val ctx = new CodegenContext val bytes = new Array[Byte](3) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index 0566972b29005..f369635a32671 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -22,7 +22,8 @@ import org.apache.spark.{SparkFunSuite, TaskContext, 
TaskContextImpl} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.types.{DataType, IntegerType, ObjectType} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType, ObjectType} class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHelper { test("Semantic equals and hash") { @@ -146,15 +147,15 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel val equivalence = new EquivalentExpressions equivalence.addExprTree(add) // the `two` inside `fallback` should not be added - assert(equivalence.getAllExprStates(1).size == 1) - assert(equivalence.getAllExprStates().count(_.useCount == 1) == 2) // add, two, explode + assert(equivalence.getAllExprStates(1).size == 0) + assert(equivalence.getAllExprStates().count(_.useCount == 1) == 3) // add, two, explode } test("Children of conditional expressions: If") { val add = Add(Literal(1), Literal(2)) val condition = GreaterThan(add, Literal(3)) - val ifExpr1 = If(condition, add, Literal(1)) + val ifExpr1 = If(condition, add, add) val equivalence1 = new EquivalentExpressions equivalence1.addExprTree(ifExpr1) @@ -171,27 +172,28 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel val equivalence2 = new EquivalentExpressions equivalence2.addExprTree(ifExpr2) - assert(equivalence2.getAllExprStates(1).nonEmpty) - assert(equivalence2.getAllExprStates().count(_.useCount == 1) == 4) + assert(equivalence2.getAllExprStates(1).isEmpty) + assert(equivalence2.getAllExprStates().count(_.useCount == 1) == 3) val ifExpr3 = If(condition, ifExpr1, ifExpr1) val equivalence3 = new EquivalentExpressions equivalence3.addExprTree(ifExpr3) // `add`: 2, `condition`: 2 - assert(equivalence3.getAllExprStates().count(_.useCount == 2) == 3) + 
assert(equivalence3.getAllExprStates().count(_.useCount == 2) == 2) assert(equivalence3.getAllExprStates().filter(_.useCount == 2).exists(_.expr eq condition)) assert(equivalence3.getAllExprStates().filter(_.useCount == 2).exists(_.expr eq add)) // `ifExpr1`, `ifExpr3` - assert(equivalence3.getAllExprStates().count(_.useCount == 1) == 1) + assert(equivalence3.getAllExprStates().count(_.useCount == 1) == 2) + assert(equivalence3.getAllExprStates().filter(_.useCount == 1).exists(_.expr eq ifExpr1)) assert(equivalence3.getAllExprStates().filter(_.useCount == 1).exists(_.expr eq ifExpr3)) } test("Children of conditional expressions: CaseWhen") { val add1 = Add(Literal(1), Literal(2)) val add2 = Add(Literal(2), Literal(3)) - val conditions1 = (GreaterThan(add1, Literal(3)), add1) :: + val conditions1 = (GreaterThan(add2, Literal(3)), add1) :: (GreaterThan(add2, Literal(4)), add1) :: (GreaterThan(add2, Literal(5)), add1) :: Nil @@ -213,7 +215,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel // `add1` is repeatedly in all branch values, and first predicate. assert(equivalence2.getAllExprStates().count(_.useCount == 2) == 1) - assert(equivalence2.getAllExprStates().filter(_.useCount == 2).head.expr eq add2) + assert(equivalence2.getAllExprStates().filter(_.useCount == 2).head.expr eq add1) // Negative case. `add1` or `add2` is not commonly used in all predicates/branch values. val conditions3 = (GreaterThan(add1, Literal(3)), add2) :: @@ -238,7 +240,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel equivalence1.addExprTree(coalesceExpr1) // `add2` is repeatedly in all conditions. - assert(equivalence1.getAllExprStates().count(_.useCount == 2) == 0) + assert(equivalence1.getAllExprStates().count(_.useCount == 2) == 1) + assert(equivalence1.getAllExprStates().filter(_.useCount == 2).head.expr eq add2) // Negative case. `add1` and `add2` both are not used in all branches. 
val conditions2 = GreaterThan(add1, Literal(3)) :: @@ -249,7 +252,62 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel val equivalence2 = new EquivalentExpressions equivalence2.addExprTree(coalesceExpr2) - assert(equivalence2.getAllExprStates().count(_.useCount == 2) == 1) + assert(equivalence2.getAllExprStates().count(_.useCount == 2) == 0) + } + + test("SPARK-34723: Correct parameter type for subexpression elimination under whole-stage") { + withSQLConf(SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> "1") { + val str = BoundReference(0, BinaryType, false) + val pos = BoundReference(1, IntegerType, false) + + val substr = new Substring(str, pos) + + val add = Add(Length(substr), Literal(1)) + val add2 = Add(Length(substr), Literal(2)) + + val ctx = new CodegenContext() + val exprs = Seq(add, add2) + + val oneVar = ctx.freshVariable("str", BinaryType) + val twoVar = ctx.freshVariable("pos", IntegerType) + ctx.addMutableState("byte[]", oneVar, forceInline = true, useFreshName = false) + ctx.addMutableState("int", twoVar, useFreshName = false) + + ctx.INPUT_ROW = null + ctx.currentVars = Seq( + ExprCode(TrueLiteral, oneVar), + ExprCode(TrueLiteral, twoVar)) + + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs) + ctx.withSubExprEliminationExprs(subExprs.states) { + exprs.map(_.genCode(ctx)) + } + val subExprsCode = ctx.evaluateSubExprEliminationState(subExprs.states.values) + + val codeBody = s""" + public java.lang.Object generate(Object[] references) { + return new TestCode(references); + } + + class TestCode { + ${ctx.declareMutableStates()} + + public TestCode(Object[] references) { + } + + public void initialize(int partitionIndex) { + ${subExprsCode} + } + + ${ctx.declareAddedFunctions()} + } + """ + + val code = CodeFormatter.stripOverlappingComments( + new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) + + CodeGenerator.compile(code) + } } test("SPARK-35410: SubExpr elimination should not include 
redundant child exprs " + @@ -265,7 +323,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel val commonExprs = equivalence.getAllExprStates(1) assert(commonExprs.size == 1) - assert(commonExprs.head.useCount == 3) + assert(commonExprs.head.useCount == 2) assert(commonExprs.head.expr eq add3) } @@ -279,8 +337,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel equivalence.addExprTree(ifExpr3) val commonExprs = equivalence.getAllExprStates(1) - assert(commonExprs.size == 2) - assert(commonExprs.head.useCount == 4) + assert(commonExprs.size == 1) + assert(commonExprs.head.useCount == 2) assert(commonExprs.head.expr eq add) } @@ -341,7 +399,28 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel equivalence.addExprTree(caseWhenExpr) // `add1` is not in the elseValue, so we can't extract it from the branches - assert(equivalence.getAllExprStates().count(_.useCount == 2) == 1) + assert(equivalence.getAllExprStates().count(_.useCount == 2) == 0) + } + + test("SPARK-35829: SubExprEliminationState keeps children sub exprs") { + val add1 = Add(Literal(1), Literal(2)) + val add2 = Add(add1, add1) + + val exprs = Seq(add1, add1, add2, add2) + val ctx = new CodegenContext() + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs) + + val add2State = subExprs.states(ExpressionEquals(add2)) + val add1State = subExprs.states(ExpressionEquals(add1)) + assert(add2State.children.contains(add1State)) + + subExprs.states.values.foreach { state => + assert(state.eval.code != EmptyBlock) + } + ctx.evaluateSubExprEliminationState(subExprs.states.values) + subExprs.states.values.foreach { state => + assert(state.eval.code == EmptyBlock) + } } test("SPARK-38333: PlanExpression expression should skip addExprTree function in Executor") { @@ -364,7 +443,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel val n1 = NaNvl(Literal(1.0d), Add(add, 
add)) val e1 = new EquivalentExpressions e1.addExprTree(n1) - assert(e1.getCommonSubexpressions.nonEmpty) + assert(e1.getCommonSubexpressions.isEmpty) val n2 = NaNvl(add, add) val e2 = new EquivalentExpressions @@ -388,6 +467,33 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel val cseState = equivalence.getExprState(expr) assert(hasMatching == cseState.isDefined) } + + test("SPARK-42815: Subexpression elimination support shortcut conditional expression") { + val add = Add(Literal(1), Literal(0)) + val equal = EqualTo(add, add) + + def checkShortcut(expr: Expression, numCommonExpr: Int): Unit = { + val e1 = If(expr, Literal(1), Literal(2)) + val ee1 = new EquivalentExpressions(true) + ee1.addExprTree(e1) + assert(ee1.getCommonSubexpressions.size == numCommonExpr) + + val e2 = expr + val ee2 = new EquivalentExpressions(true) + ee2.addExprTree(e2) + assert(ee2.getCommonSubexpressions.size == numCommonExpr) + } + + // shortcut right child + checkShortcut(And(Literal(false), equal), 0) + checkShortcut(Or(Literal(true), equal), 0) + checkShortcut(Not(And(Literal(true), equal)), 0) + + // always eliminate subexpression for left child + checkShortcut((And(equal, Literal(false))), 1) + checkShortcut(Or(equal, Literal(true)), 1) + checkShortcut(Not(And(equal, Literal(false))), 1) + } } case class CodegenFallbackExpression(child: Expression) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index 2138dbe89c046..847ca85d9beeb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -168,7 +168,7 @@ case class ExpandExec( } // Part 2: switch/case statements - initBlock += ctx.subexpressionElimination( + initBlock += ctx.conditionalSubexpressionElimination( projections.flatten.map(BindReferences.bindReference(_, attributeSeq))) 
val switchCaseExprs = projections.zipWithIndex.map { case (exprs, row) => val (exprCodesWithIndices, inputVarSets) = exprs.indices.flatMap { col => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala index 65730c410dd7a..bd59c0d21fd1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, ExpressionEquals, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -207,9 +207,13 @@ trait AggregateCodegenSupport val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc => bindReferences(updateExprsForOneFunc, inputAttrs) } - val initBlock = ctx.subexpressionElimination(boundUpdateExprs.flatten) + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) + val initBlock = ctx.conditionalSubexpressionElimination(boundUpdateExprs.flatten, false) + val effectiveCodes = ctx.evaluateSubExprEliminationState(subExprs.states.values) val bufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc => - boundUpdateExprsForOneFunc.map(_.genCode(ctx)) + ctx.withSubExprEliminationExprs(subExprs.states) { + boundUpdateExprsForOneFunc.map(_.genCode(ctx)) + } } val aggNames = functions.map(_.prettyName) @@ -233,11 +237,12 @@ trait AggregateCodegenSupport } val 
codeToEvalAggFuncs = generateEvalCodeForAggFuncs( - ctx, input, inputAttrs, boundUpdateExprs, aggNames, aggCodeBlocks) + ctx, input, inputAttrs, boundUpdateExprs, aggNames, aggCodeBlocks, subExprs) s""" |// do aggregate |// common sub-expressions |$initBlock + |$effectiveCodes |// evaluate aggregate functions and update aggregation buffers |$codeToEvalAggFuncs """.stripMargin @@ -252,21 +257,19 @@ trait AggregateCodegenSupport inputAttrs: Seq[Attribute], boundUpdateExprs: Seq[Seq[Expression]], aggNames: Seq[String], - aggCodeBlocks: Seq[Block]): String = { - val evaluated = boundUpdateExprs.flatten.map { e => - evaluateRequiredVariables(inputAttrs, input, e.references) - }.mkString("", "\n", "\n") + aggCodeBlocks: Seq[Block], + subExprs: SubExprCodes): String = { val aggCodes = if (conf.codegenSplitAggregateFunc && aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) { val maybeSplitCodes = splitAggregateExpressions( - ctx, aggNames, boundUpdateExprs, aggCodeBlocks) + ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states) maybeSplitCodes.getOrElse(aggCodeBlocks.map(_.code)) } else { aggCodeBlocks.map(_.code) } - evaluated + aggCodes.zip(aggregateExpressions.map(ae => (ae.mode, ae.filter))).map { + aggCodes.zip(aggregateExpressions.map(ae => (ae.mode, ae.filter))).map { case (aggCode, (Partial | Complete, Some(condition))) => // Note: wrap in "do { } while(false);", so the generated checks can jump out // with "continue;" @@ -294,49 +297,59 @@ trait AggregateCodegenSupport ctx: CodegenContext, aggNames: Seq[String], aggBufferUpdatingExprs: Seq[Seq[Expression]], - aggCodeBlocks: Seq[Block]): Option[Seq[String]] = { - val inputVars = aggBufferUpdatingExprs.map { aggExprsForOneFunc => - val inputVarsForOneFunc = aggExprsForOneFunc.map( - CodeGenerator.getLocalInputVariableValues(ctx, _)._1.toSet).reduce(_ ++ _).toSeq - val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVarsForOneFunc) + aggCodeBlocks: Seq[Block], + subExprs: 
Map[ExpressionEquals, SubExprEliminationState]): Option[Seq[String]] = { + val exprValsInSubExprs = subExprs.flatMap { case (_, s) => + s.eval.value :: s.eval.isNull :: Nil + } + if (exprValsInSubExprs.exists(_.isInstanceOf[SimpleExprValue])) { + // `SimpleExprValue`s cannot be used as an input variable for split functions, so + // we give up splitting functions if it exists in `subExprs`. + None + } else { + val inputVars = aggBufferUpdatingExprs.map { aggExprsForOneFunc => + val inputVarsForOneFunc = aggExprsForOneFunc.map( + CodeGenerator.getLocalInputVariableValues(ctx, _, subExprs)._1.toSet).reduce(_ ++ _).toSeq + val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVarsForOneFunc) - // Checks if a parameter length for the `aggExprsForOneFunc` does not go over the JVM limit - if (CodeGenerator.isValidParamLength(paramLength)) { - Some(inputVarsForOneFunc) - } else { - None + // Checks if a parameter length for the `aggExprsForOneFunc` does not go over the JVM limit + if (CodeGenerator.isValidParamLength(paramLength)) { + Some(inputVarsForOneFunc) + } else { + None + } } - } - // Checks if all the aggregate code can be split into pieces. - // If the parameter length of at lease one `aggExprsForOneFunc` goes over the limit, - // we totally give up splitting aggregate code. - if (inputVars.forall(_.isDefined)) { - val splitCodes = inputVars.flatten.zipWithIndex.map { case (args, i) => - val doAggFunc = ctx.freshName(s"doAggregate_${aggNames(i)}") - val argList = args.map { v => - s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}" - }.mkString(", ") - val doAggFuncName = ctx.addNewFunction(doAggFunc, - s""" - |private void $doAggFunc($argList) throws java.io.IOException { - | ${aggCodeBlocks(i)} - |} - """.stripMargin) + // Checks if all the aggregate code can be split into pieces. + // If the parameter length of at lease one `aggExprsForOneFunc` goes over the limit, + // we totally give up splitting aggregate code. 
+ if (inputVars.forall(_.isDefined)) { + val splitCodes = inputVars.flatten.zipWithIndex.map { case (args, i) => + val doAggFunc = ctx.freshName(s"doAggregate_${aggNames(i)}") + val argList = args.map { v => + s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}" + }.mkString(", ") + val doAggFuncName = ctx.addNewFunction(doAggFunc, + s""" + |private void $doAggFunc($argList) throws java.io.IOException { + | ${aggCodeBlocks(i)} + |} + """.stripMargin) - val inputVariables = args.map(_.variableName).mkString(", ") - s"$doAggFuncName($inputVariables);" - } - Some(splitCodes) - } else { - val errMsg = "Failed to split aggregate code into small functions because the parameter " + - "length of at least one split function went over the JVM limit: " + - CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH - if (Utils.isTesting) { - throw new IllegalStateException(errMsg) + val inputVariables = args.map(_.variableName).mkString(", ") + s"$doAggFuncName($inputVariables);" + } + Some(splitCodes) } else { - logInfo(errMsg) - None + val errMsg = "Failed to split aggregate code into small functions because the parameter " + + "length of at least one split function went over the JVM limit: " + + CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH + if (Utils.isTesting) { + throw new IllegalStateException(errMsg) + } else { + logInfo(errMsg) + None + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 699d5f00cfa90..b62b80a8aafdd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -731,9 +731,13 @@ case class HashAggregateExec( val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc => bindReferences(updateExprsForOneFunc, inputAttrs) } - val initBlock = 
ctx.subexpressionElimination(boundUpdateExprs.flatten) + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) + val initBlock = ctx.conditionalSubexpressionElimination(boundUpdateExprs.flatten, false) + val effectiveCodes = ctx.evaluateSubExprEliminationState(subExprs.states.values) val unsafeRowBufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc => - boundUpdateExprsForOneFunc.map(_.genCode(ctx)) + ctx.withSubExprEliminationExprs(subExprs.states) { + boundUpdateExprsForOneFunc.map(_.genCode(ctx)) + } } val aggCodeBlocks = updateExprs.indices.map { i => @@ -758,10 +762,11 @@ case class HashAggregateExec( } val codeToEvalAggFuncs = generateEvalCodeForAggFuncs( - ctx, input, inputAttrs, boundUpdateExprs, aggNames, aggCodeBlocks) + ctx, input, inputAttrs, boundUpdateExprs, aggNames, aggCodeBlocks, subExprs) s""" |// common sub-expressions |$initBlock + |$effectiveCodes |// evaluate aggregate functions and update aggregation buffers |$codeToEvalAggFuncs """.stripMargin @@ -774,9 +779,13 @@ case class HashAggregateExec( val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc => bindReferences(updateExprsForOneFunc, inputAttrs) } - val initBlock = ctx.subexpressionElimination(boundUpdateExprs.flatten) + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) + val initBlock = ctx.conditionalSubexpressionElimination(boundUpdateExprs.flatten, false) + val effectiveCodes = ctx.evaluateSubExprEliminationState(subExprs.states.values) val fastRowEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc => - boundUpdateExprsForOneFunc.map(_.genCode(ctx)) + ctx.withSubExprEliminationExprs(subExprs.states) { + boundUpdateExprsForOneFunc.map(_.genCode(ctx)) + } } val aggCodeBlocks = fastRowEvals.zipWithIndex.map { case (fastRowEvalsForOneFunc, i) => @@ -800,7 +809,7 @@ case class HashAggregateExec( } val codeToEvalAggFuncs = generateEvalCodeForAggFuncs( - ctx, input, inputAttrs, 
boundUpdateExprs, aggNames, aggCodeBlocks) + ctx, input, inputAttrs, boundUpdateExprs, aggNames, aggCodeBlocks, subExprs) // If vectorized fast hash map is on, we first generate code to update row // in vectorized fast hash map, if the previous loop up hit vectorized fast hash map. @@ -809,6 +818,7 @@ case class HashAggregateExec( |if ($fastRowBuffer != null) { | // common sub-expressions | $initBlock + | $effectiveCodes | // evaluate aggregate functions and update aggregation buffers | $codeToEvalAggFuncs |} else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 2c9b61cebcfe7..c2bf013d48129 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -67,12 +67,25 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { val exprs = bindReferences[Expression](projectList, child.output) - initBlock += ctx.subexpressionElimination(exprs: _*) - val resultVars = exprs.map(_.genCode(ctx)) + val (subExprsCode, resultVars, localValInputs) = if (conf.subexpressionEliminationEnabled) { + // subexpression elimination + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs) + initBlock += ctx.conditionalSubexpressionElimination(exprs, false) + val genVars = ctx.withSubExprEliminationExprs(subExprs.states) { + exprs.map(_.genCode(ctx)) + } + (ctx.evaluateSubExprEliminationState(subExprs.states.values), genVars, + subExprs.exprCodesNeedEvaluate) + } else { + ("", exprs.map(_.genCode(ctx)), Seq.empty) + } // Evaluation of non-deterministic expressions can't be deferred. 
val nonDeterministicAttrs = projectList.filterNot(_.deterministic).map(_.toAttribute) s""" + |// common sub-expressions + |${evaluateVariables(localValInputs)} + |$subExprsCode |${evaluateRequiredVariables(output, resultVars, AttributeSet(nonDeterministicAttrs))} |${consume(ctx, resultVars)} """.stripMargin @@ -166,8 +179,8 @@ trait GeneratePredicateHelper extends PredicateHelper { // TODO: revisit this. We can consider reordering predicates as well. val generatedIsNotNullChecks = new Array[Boolean](notNullPreds.length) val extraIsNotNullAttrs = mutable.Set[Attribute]() - initBlock += - ctx.subexpressionElimination(otherPreds.map(BindReferences.bindReference(_, inputAttrs)): _*) + initBlock += ctx.conditionalSubexpressionElimination( + otherPreds.map(BindReferences.bindReference(_, inputAttrs))) val generated = otherPreds.map { c => val nullChecks = c.references.map { r => val idx = notNullPreds.indexWhere { n => n.asInstanceOf[IsNotNull].child.semanticEquals(r)} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinCodegenSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinCodegenSupport.scala index 29ed51c6247f2..43fa0cdb5316e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinCodegenSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinCodegenSupport.scala @@ -56,7 +56,7 @@ trait JoinCodegenSupport extends CodegenSupport with BaseJoinExec { // filter the output via condition ctx.currentVars = streamVars2 ++ buildVars val bondExpr = BindReferences.bindReference(expr, streamPlan.output ++ buildPlan.output) - val initBlock = ctx.subexpressionElimination(Seq(bondExpr)) + val initBlock = ctx.conditionalSubexpressionElimination(Seq(bondExpr)) val ev = bondExpr.genCode(ctx) val skipRow = s"${ev.isNull} || !${ev.value}" s""" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 72ed5d5b70e4d..ac710c3229647 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -770,4 +770,26 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession } } } + + test("Give up splitting subexpression code if a parameter length goes over the limit") { + withSQLConf( + SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC.key -> "false", + SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> "1", + "spark.sql.CodeGenerator.validParamLength" -> "0") { + withTable("t") { + val expectedErrMsg = "Failed to split subexpression code into small functions" + Seq( + // Test case without keys + "SELECT AVG(a + b), SUM(a + b + c) FROM VALUES((1, 1, 1)) t(a, b, c)", + // Test case with keys + "SELECT k, AVG(a + b), SUM(a + b + c) FROM VALUES((1, 1, 1, 1)) t(k, a, b, c) " + + "GROUP BY k").foreach { query => + val e = intercept[IllegalStateException] { + sql(query).collect + } + assert(e.getMessage.contains(expectedErrMsg)) + } + } + } + } } From fbfa00b6e7fca682030a448c2d64e03cff595369 Mon Sep 17 00:00:00 2001 From: Kun Wan Date: Wed, 10 May 2023 23:24:28 +0800 Subject: [PATCH 09/10] bug fix --- .../spark/sql/catalyst/expressions/codegen/CodeGenerator.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 7945fcf51005e..bbe7b1cccd87f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1307,7 +1307,7 @@ class CodegenContext extends Logging { } def genReusedCode(stats:
ExpressionStats, eval: ExprCode): ExprCode = { - val (inputVars, _) = getLocalInputVariableValues(this, stats.expr) + val (inputVars, _) = getLocalInputVariableValues(this, stats.expr, subExprEliminationExprs) val (initialized, isNull, value) = (stats.initialized.get, stats.isNull.get, stats.value.get) val validParamLength = isValidParamLength(calculateParamLengthFromExprValues(inputVars)) if(!stats.addedFunction && validParamLength) { From 06238be1eaa152cc26f5a12485133c8bebd70de2 Mon Sep 17 00:00:00 2001 From: Kun Wan Date: Thu, 11 May 2023 19:02:02 +0800 Subject: [PATCH 10/10] Consider common expression local input variables. Do not reuse the expression if CodegenContext changed. --- .../expressions/codegen/CodeGenerator.scala | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index bbe7b1cccd87f..a534c6f4f071c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1310,8 +1310,8 @@ class CodegenContext extends Logging { val (inputVars, _) = getLocalInputVariableValues(this, stats.expr, subExprEliminationExprs) val (initialized, isNull, value) = (stats.initialized.get, stats.isNull.get, stats.value.get) val validParamLength = isValidParamLength(calculateParamLengthFromExprValues(inputVars)) - if(!stats.addedFunction && validParamLength) { - // Generate the code for this expression tree and wrap it in a function. + if (!stats.addedFunction && validParamLength) { + // Wrap the expression code in a function. 
val argList = inputVars.map(v => s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}") val fn = @@ -1329,8 +1329,11 @@ class CodegenContext extends Logging { stats.params = Some(inputVars.map(_.javaType)) stats.addedFunction = true } - // input vars changed, e.g. some input vars now are GlobalValue. - if (inputVars.map(_.javaType) != stats.params.get) { + if (!classFunctions.values.map(_.keys).flatten.toSet.contains(stats.funcName.get)) { + // The CodegenContext has changed, all the corresponding variables will also not be available + eval + } else if (inputVars.map(_.javaType) != stats.params.get) { + // input vars changed, e.g. some input vars now are GlobalValue. eval } else { val code = @@ -1345,7 +1348,7 @@ class CodegenContext extends Logging { | $isNull = ${eval.isNull}; | $value = ${eval.value}; |} - """.stripMargin + """.stripMargin } ExprCode(code, isNull, value) }