diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 807765c1e00a..48a5fb14707c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -357,15 +357,19 @@ class CodegenContext { /** * It will count the lines of every Java function generated by whole-stage codegen, - * if there is a function of length greater than spark.sql.codegen.maxLinesPerFunction, - * it will return true. + * if there is a function of length greater than `maxLinesPerFunction`, it will return true. + * If `maxLinesPerFunction` has -1, it will always return false. */ def isTooLongGeneratedFunction: Boolean = { - classFunctions.values.exists { _.values.exists { - code => - val codeWithoutComments = CodeFormatter.stripExtraNewLinesAndComments(code) - codeWithoutComments.count(_ == '\n') > SQLConf.get.maxLinesPerFunction + val maxLinesPerFunc = SQLConf.get.maxLinesPerFunction + if (maxLinesPerFunc >= 0) { + classFunctions.values.exists { _.values.exists { code => + val codeWithoutComments = CodeFormatter.stripExtraNewLinesAndComments(code) + codeWithoutComments.count(_ == '\n') > maxLinesPerFunc + } } + } else { + false } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index a685099505ee..4c3ae81abe3f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -578,8 +578,10 @@ object SQLConf { "When the generated function exceeds this threshold, " + "the whole-stage codegen is deactivated for this subtree of the current query plan. " + "The default value 4000 is the max length of byte code JIT supported " + - "for a single function(8000) divided by 2.") + "for a single function(8000) divided by 2. Use -1 to disable this.") .intConf + .checkValue(maxLines => maxLines >= -1, "The maximum must not be a negative integer, " + + "except for -1 using to always activate whole-stage codegen") .createWithDefault(4000) val FILES_MAX_PARTITION_BYTES = buildConf("spark.sql.files.maxPartitionBytes") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index beeee6a97c8d..30ed836967b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.{Column, Dataset, Row} -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.{Add, Literal, Stack} +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec @@ -193,11 +191,15 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { } } - test("SPARK-21603 check there is not a too long generated function when threshold is Int.Max") { + test("SPARK-21603 check there is not a too long generated function when threshold is Max/-1") { withSQLConf(SQLConf.WHOLESTAGE_MAX_LINES_PER_FUNCTION.key -> Int.MaxValue.toString) { val ctx = genGroupByCodeGenContext(30) assert(ctx.isTooLongGeneratedFunction === false) } + withSQLConf(SQLConf.WHOLESTAGE_MAX_LINES_PER_FUNCTION.key -> "-1") { + val ctx = genGroupByCodeGenContext(30) + assert(ctx.isTooLongGeneratedFunction === false) + } } test("SPARK-21603 check there is a too long generated function when threshold is 0") { @@ -206,4 +208,12 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { assert(ctx.isTooLongGeneratedFunction === true) } } + + test("SPARK-21603 `maxLinesPerFunction` must not be negative") { + val errMsg = intercept[IllegalArgumentException] { + withSQLConf(SQLConf.WHOLESTAGE_MAX_LINES_PER_FUNCTION.key -> "-2") {} + }.getMessage + assert(errMsg.contains("The maximum must not be a negative integer, except for -1 using to " + + "always activate whole-stage codegen")) + } }