From 028397179f0719a2b0ee20fc43dac3da485b4c2c Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 18 Sep 2019 10:25:34 +0900 Subject: [PATCH 1/4] [SPARK-29128][SQL] Split predicate code in OR expressions --- .../sql/catalyst/catalog/interface.scala | 2 +- .../sql/catalyst/expressions/Expression.scala | 6 ++ .../expressions/codegen/CodeGenerator.scala | 79 +++++++++++-------- .../sql/catalyst/expressions/predicates.scala | 14 +++- 4 files changed, 66 insertions(+), 35 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index f653bf41c162..c7757ebdb139 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -625,7 +625,7 @@ case class HiveTableRelation( storage = CatalogStorageFormat.empty, createTime = -1 ), - dataCols = dataCols.zipWithIndex.map { + dataCols = dataCols.zipWithIndex.map sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala{ case (attr, index) => attr.withExprId(ExprId(index)) }, partitionCols = partitionCols.zipWithIndex.map { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 4632957e7afd..f69a30b16989 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -156,6 +156,12 @@ abstract class Expression extends TreeNode[Expression] { private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = { // TODO: support whole stage codegen too + // + // NOTE: We could use `CodeGenerator.defineIndependentFunction` here for the code path + // of the whole stage codegen. 
But, we don't do so now because the performance changes that + // we don't expect might occur in many queries. Therefore, we currently apply + // this split function to specific performance-sensitive places only, + // e.g., common subexpression elimination for the whole stage codegen and OR expressions. val splitThreshold = SQLConf.get.methodSplitThreshold if (eval.code.length > splitThreshold && ctx.INPUT_ROW != null && ctx.currentVars == null) { val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index acd3858431e6..25248e7c90de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1022,6 +1022,45 @@ class CodegenContext extends Logging { genCodes } + /** + * Defines an independent function for a given expression and returns a caller-side code + * for the function as `ExprCode`. 
+ */ + def defineIndependentFunction( + expr: Expression, + ev: ExprCode, + funcNameOption: Option[String] = None, + inputVarsOption: Option[Seq[VariableValue]] = None): ExprCode = { + val inputVars = inputVarsOption.getOrElse(getLocalInputVariableValues(this, expr).toSeq) + if (isValidParamLength(calculateParamLengthFromExprValues(inputVars))) { + val (isNull, setIsNull) = if (!ev.isNull.isInstanceOf[LiteralValue]) { + val globalIsNull = addMutableState(JAVA_BOOLEAN, "globalIsNull") + (JavaCode.isNullGlobal(globalIsNull), s"$globalIsNull = ${ev.isNull};") + } else { + (ev.isNull, "") + } + + val fnName = freshName(funcNameOption.getOrElse(expr.prettyName)) + val argList = inputVars.map(v => s"${v.javaType.getName} ${v.variableName}") + val returnType = javaType(expr.dataType) + val funcFullName = addNewFunction(fnName, + s""" + |private $returnType $fnName(${argList.mkString(", ")}) { + | ${ev.code} + | $setIsNull + | return ${ev.value}; + |} + """.stripMargin) + + val newValue = freshName("value") + val inputVariables = inputVars.map(_.variableName).mkString(", ") + val code = code"$returnType $newValue = $funcFullName($inputVariables);" + ExprCode(code, isNull, JavaCode.variable(newValue, expr.dataType)) + } else { + ev + } + } + /** * Checks and sets up the state and codegen for subexpression elimination. 
This finds the * common subexpressions, generates the code snippets that evaluate those expressions and @@ -1057,39 +1096,15 @@ class CodegenContext extends Logging { } if (inputVarsForAllFuncs.map(calculateParamLengthFromExprValues).forall(isValidParamLength)) { commonExprs.zipWithIndex.map { case (exprs, i) => - val expr = exprs.head - val eval = commonExprVals(i) - - val isNullLiteral = eval.isNull match { - case TrueLiteral | FalseLiteral => true - case _ => false - } - val (isNull, isNullEvalCode) = if (!isNullLiteral) { - val v = addMutableState(JAVA_BOOLEAN, "subExprIsNull") - (JavaCode.isNullGlobal(v), s"$v = ${eval.isNull};") - } else { - (eval.isNull, "") - } - - // Generate the code for this expression tree and wrap it in a function. - val fnName = freshName("subExpr") - val inputVars = inputVarsForAllFuncs(i) - val argList = inputVars.map(v => s"${v.javaType.getName} ${v.variableName}") - val returnType = javaType(expr.dataType) - val fn = - s""" - |private $returnType $fnName(${argList.mkString(", ")}) { - | ${eval.code} - | $isNullEvalCode - | return ${eval.value}; - |} - """.stripMargin - - val value = freshName("subExprValue") - val state = SubExprEliminationState(isNull, JavaCode.variable(value, expr.dataType)) + val funcEval = defineIndependentFunction( + expr = exprs.head, + ev = commonExprVals(i), + funcNameOption = Some("subExprValue"), + inputVarsOption = Some(inputVarsForAllFuncs(i)) + ) + val state = SubExprEliminationState(funcEval.isNull, funcEval.value) exprs.foreach(localSubExprEliminationExprs.put(_, state)) - val inputVariables = inputVars.map(_.variableName).mkString(", ") - s"$returnType $value = ${addNewFunction(fnName, fn)}($inputVariables);" + funcEval.code.toString } } else { val errMsg = "Failed to split subexpression code into small functions because the " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index bcd442ad3cc3..99c4d2abf1dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -629,9 +629,19 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P } } + // Splits predicate code if the generated code of `expr` is too long + private def genSplitCode(ctx: CodegenContext, expr: Expression): ExprCode = { + val eval = expr.genCode(ctx) + if (eval.code.length > SQLConf.get.methodSplitThreshold) { + ctx.defineIndependentFunction(this, eval) + } else { + eval + } + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val eval1 = left.genCode(ctx) - val eval2 = right.genCode(ctx) + val eval1 = genSplitCode(ctx, left) + val eval2 = genSplitCode(ctx, right) // The result should be `true`, if any of them is `true` whenever the other is null or not. 
if (!left.nullable && !right.nullable) { From 996949497f0b221dd6ed13233952376e33324d2a Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 19 Sep 2019 12:51:46 +0900 Subject: [PATCH 2/4] Rename defineIndependentFunction to defineSingleSplitFunction --- .../apache/spark/sql/catalyst/expressions/Expression.scala | 2 +- .../sql/catalyst/expressions/codegen/CodeGenerator.scala | 4 ++-- .../apache/spark/sql/catalyst/expressions/predicates.scala | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index f69a30b16989..4ccd01ad34ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -157,7 +157,7 @@ abstract class Expression extends TreeNode[Expression] { private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = { // TODO: support whole stage codegen too // - // NOTE: We could use `CodeGenerator.defineIndependentFunction` here for the code path + // NOTE: We could use `CodeGenerator.defineSingleSplitFunction` here for the code path // of the whole stage codegen. But, we don't do so now because the performance changes that // we don't expect might occur in many queries. 
Therefore, we currently apply // this split function to specific performance-sensitive places only, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 25248e7c90de..598a658f8a15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1026,7 +1026,7 @@ class CodegenContext extends Logging { * Defines an independent function for a given expression and returns a caller-side code * for the function as `ExprCode`. */ - def defineIndependentFunction( + def defineSingleSplitFunction( expr: Expression, ev: ExprCode, funcNameOption: Option[String] = None, @@ -1096,7 +1096,7 @@ class CodegenContext extends Logging { } if (inputVarsForAllFuncs.map(calculateParamLengthFromExprValues).forall(isValidParamLength)) { commonExprs.zipWithIndex.map { case (exprs, i) => - val funcEval = defineIndependentFunction( + val funcEval = defineSingleSplitFunction( expr = exprs.head, ev = commonExprVals(i), funcNameOption = Some("subExprValue"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 99c4d2abf1dc..2112b983df18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -633,7 +633,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P private def genSplitCode(ctx: CodegenContext, expr: Expression): ExprCode = { val eval = expr.genCode(ctx) if (eval.code.length > SQLConf.get.methodSplitThreshold) { - ctx.defineIndependentFunction(this, eval) + 
ctx.defineSingleSplitFunction(this, eval) } else { eval } From c5b203a8ff89dadddc8e5758ef80634c8ddfc243 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sat, 21 Sep 2019 08:49:27 +0900 Subject: [PATCH 3/4] Reword scaladoc of defineSingleSplitFunction --- .../spark/sql/catalyst/expressions/codegen/CodeGenerator.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 598a658f8a15..e6d533fcf845 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1023,7 +1023,7 @@ class CodegenContext extends Logging { } /** - * Defines an independent function for a given expression and returns a caller-side code + * Defines an individual function for a given expression and returns a caller-side code * for the function as `ExprCode`. 
*/ def defineSingleSplitFunction( From 543c0167dab23ece2e4db232c0fd7d4c9e5eeb8e Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sat, 26 Oct 2019 21:16:53 +0900 Subject: [PATCH 4/4] Revert stray edit in interface.scala and drop the modified-q3 code-size blacklist --- .../org/apache/spark/sql/catalyst/catalog/interface.scala | 2 +- .../test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala | 7 +------ 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index c7757ebdb139..f653bf41c162 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -625,7 +625,7 @@ case class HiveTableRelation( storage = CatalogStorageFormat.empty, createTime = -1 ), - dataCols = dataCols.zipWithIndex.map sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala{ + dataCols = dataCols.zipWithIndex.map { case (attr, index) => attr.withExprId(ExprId(index)) }, partitionCols = partitionCols.zipWithIndex.map { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala index aacb625d7921..9c16a6ab1f8a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala @@ -82,11 +82,6 @@ class TPCDSQuerySuite extends BenchmarkQueryTest with TPCDSSchema { "q3", "q7", "q10", "q19", "q27", "q34", "q42", "q43", "q46", "q52", "q53", "q55", "q59", "q63", "q65", "q68", "q73", "q79", "q89", "q98", "ss_max") - // List up the known queries having too large code in a generated function. 
- // A JIRA file for `modified-q3` is as follows; - // [SPARK-29128] Split predicate code in OR expressions - val blackListForMethodCodeSizeCheck = Set("modified-q3") - modifiedTPCDSQueries.foreach { name => val queryString = resourceToString(s"tpcds-modifiedQueries/$name.sql", classLoader = Thread.currentThread().getContextClassLoader) @@ -94,7 +89,7 @@ class TPCDSQuerySuite extends BenchmarkQueryTest with TPCDSSchema { test(testName) { // check the plans can be properly generated val plan = sql(queryString).queryExecution.executedPlan - checkGeneratedCode(plan, !blackListForMethodCodeSizeCheck.contains(testName)) + checkGeneratedCode(plan) } } }