diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 4632957e7afd..4ccd01ad34ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -156,6 +156,12 @@ abstract class Expression extends TreeNode[Expression] { private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = { // TODO: support whole stage codegen too + // + // NOTE: We could use `CodeGenerator.defineSingleSplitFunction` here for the code path + // of the whole stage codegen. But, we don't do so now because the performance changes that + // we don't expect might occur in many queries. Therefore, we currently apply + // this split function to specific performance-sensitive places only, + // e.g., common subexpression elimination for the whole stage codegen and OR expressions. val splitThreshold = SQLConf.get.methodSplitThreshold if (eval.code.length > splitThreshold && ctx.INPUT_ROW != null && ctx.currentVars == null) { val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index acd3858431e6..e6d533fcf845 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1022,6 +1022,45 @@ class CodegenContext extends Logging { genCodes } + /** + * Defines an individual function for a given expression and returns a caller-side code + * for the function as `ExprCode`. + */ + def defineSingleSplitFunction( + expr: Expression, + ev: ExprCode, + funcNameOption: Option[String] = None, + inputVarsOption: Option[Seq[VariableValue]] = None): ExprCode = { + val inputVars = inputVarsOption.getOrElse(getLocalInputVariableValues(this, expr).toSeq) + if (isValidParamLength(calculateParamLengthFromExprValues(inputVars))) { + val (isNull, setIsNull) = if (!ev.isNull.isInstanceOf[LiteralValue]) { + val globalIsNull = addMutableState(JAVA_BOOLEAN, "globalIsNull") + (JavaCode.isNullGlobal(globalIsNull), s"$globalIsNull = ${ev.isNull};") + } else { + (ev.isNull, "") + } + + val fnName = freshName(funcNameOption.getOrElse(expr.prettyName)) + val argList = inputVars.map(v => s"${v.javaType.getName} ${v.variableName}") + val returnType = javaType(expr.dataType) + val funcFullName = addNewFunction(fnName, + s""" + |private $returnType $fnName(${argList.mkString(", ")}) { + | ${ev.code} + | $setIsNull + | return ${ev.value}; + |} + """.stripMargin) + + val newValue = freshName("value") + val inputVariables = inputVars.map(_.variableName).mkString(", ") + val code = code"$returnType $newValue = $funcFullName($inputVariables);" + ExprCode(code, isNull, JavaCode.variable(newValue, expr.dataType)) + } else { + ev + } + } + /** * Checks and sets up the state and codegen for subexpression elimination. This finds the * common subexpressions, generates the code snippets that evaluate those expressions and @@ -1057,39 +1096,15 @@ class CodegenContext extends Logging { } if (inputVarsForAllFuncs.map(calculateParamLengthFromExprValues).forall(isValidParamLength)) { commonExprs.zipWithIndex.map { case (exprs, i) => - val expr = exprs.head - val eval = commonExprVals(i) - - val isNullLiteral = eval.isNull match { - case TrueLiteral | FalseLiteral => true - case _ => false - } - val (isNull, isNullEvalCode) = if (!isNullLiteral) { - val v = addMutableState(JAVA_BOOLEAN, "subExprIsNull") - (JavaCode.isNullGlobal(v), s"$v = ${eval.isNull};") - } else { - (eval.isNull, "") - } - - // Generate the code for this expression tree and wrap it in a function. - val fnName = freshName("subExpr") - val inputVars = inputVarsForAllFuncs(i) - val argList = inputVars.map(v => s"${v.javaType.getName} ${v.variableName}") - val returnType = javaType(expr.dataType) - val fn = - s""" - |private $returnType $fnName(${argList.mkString(", ")}) { - | ${eval.code} - | $isNullEvalCode - | return ${eval.value}; - |} - """.stripMargin - - val value = freshName("subExprValue") - val state = SubExprEliminationState(isNull, JavaCode.variable(value, expr.dataType)) + val funcEval = defineSingleSplitFunction( + expr = exprs.head, + ev = commonExprVals(i), + funcNameOption = Some("subExprValue"), + inputVarsOption = Some(inputVarsForAllFuncs(i)) + ) + val state = SubExprEliminationState(funcEval.isNull, funcEval.value) exprs.foreach(localSubExprEliminationExprs.put(_, state)) - val inputVariables = inputVars.map(_.variableName).mkString(", ") - s"$returnType $value = ${addNewFunction(fnName, fn)}($inputVariables);" + funcEval.code.toString } } else { val errMsg = "Failed to split subexpression code into small functions because the " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index bcd442ad3cc3..2112b983df18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -629,9 +629,19 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P } } + // Splits predicate code if the generated code of `expr` is too long + private def genSplitCode(ctx: CodegenContext, expr: Expression): ExprCode = { + val eval = expr.genCode(ctx) + if (eval.code.length > SQLConf.get.methodSplitThreshold) { + ctx.defineSingleSplitFunction(this, eval) + } else { + eval + } + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val eval1 = left.genCode(ctx) - val eval2 = right.genCode(ctx) + val eval1 = genSplitCode(ctx, left) + val eval2 = genSplitCode(ctx, right) // The result should be `true`, if any of them is `true` whenever the other is null or not. if (!left.nullable && !right.nullable) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala index aacb625d7921..9c16a6ab1f8a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala @@ -82,11 +82,6 @@ class TPCDSQuerySuite extends BenchmarkQueryTest with TPCDSSchema { "q3", "q7", "q10", "q19", "q27", "q34", "q42", "q43", "q46", "q52", "q53", "q55", "q59", "q63", "q65", "q68", "q73", "q79", "q89", "q98", "ss_max") - // List up the known queries having too large code in a generated function. - // A JIRA file for `modified-q3` is as follows; - // [SPARK-29128] Split predicate code in OR expressions - val blackListForMethodCodeSizeCheck = Set("modified-q3") - modifiedTPCDSQueries.foreach { name => val queryString = resourceToString(s"tpcds-modifiedQueries/$name.sql", classLoader = Thread.currentThread().getContextClassLoader) @@ -94,7 +89,7 @@ class TPCDSQuerySuite extends BenchmarkQueryTest with TPCDSSchema { test(testName) { // check the plans can be properly generated val plan = sql(queryString).queryExecution.executedPlan - checkGeneratedCode(plan, !blackListForMethodCodeSizeCheck.contains(testName)) + checkGeneratedCode(plan) } } }