From efb22422c64b5c3cb5d56222ad59eb2ea89e20ef Mon Sep 17 00:00:00 2001 From: Sunitha Kambhampati Date: Tue, 5 Feb 2019 15:01:55 -0800 Subject: [PATCH 01/12] Add optimizer rule to evaluate deterministic literal udf once Add tests and fix existing test to cover whatever codepath it was covering earlier --- .../sql/catalyst/optimizer/Optimizer.scala | 15 +++++++++++ .../spark/sql/FileBasedDataSourceSuite.scala | 2 +- .../scala/org/apache/spark/sql/UDFSuite.scala | 25 ++++++++++++++++++- .../sql/hive/orc/HiveOrcSourceSuite.scala | 3 ++- 4 files changed, 42 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 5b59ac7d2a9b..b3324d7fc2ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -78,6 +78,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) CombineLimits, CombineUnions, // Constant folding and strength reduction + DeterministicLiteralUDF, TransposeWindow, NullPropagation, ConstantPropagation, @@ -1721,3 +1722,17 @@ object OptimizeLimitZero extends Rule[LogicalPlan] { empty(ll) } } + +/** + * If the UDF is deterministic and if the children are all literal, we can replace the udf + * with the output of the udf serialized + */ +object DeterministicLiteralUDF extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + case udf@ScalaUDF(_, dataType, children, _, _, _, _, udfDeterministic) + if udf.deterministic && children.forall(_.isInstanceOf[Literal]) => { + val res = udf.eval(null) + Literal(res, dataType) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index dd11b5c50398..ca0191a202ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -399,7 +399,7 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo .contains(errorMessage(format))) msg = intercept[AnalysisException] { - spark.udf.register("testType", () => new NullData()) + spark.udf.register("testType", udf(() => new NullData()).asNondeterministic()) sql("select testType()").write.format(format).mode("overwrite").save(tempDir) }.getMessage assert(msg.toLowerCase(Locale.ROOT) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index f2a71bd628bd..a2bafad78baf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql import java.math.BigDecimal import org.apache.spark.sql.api.java._ -import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.catalyst.expressions.ScalaUDF +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, ExplainCommand} @@ -151,6 +152,28 @@ class UDFSuite extends QueryTest with SharedSQLContext { assert(sql("SELECT 
strLenScala('test', 1)").head().getInt(0) === 5) } + test("Deterministic and literal UDF optimization") { + def udfNodesCount(plan: LogicalPlan): Int = { + plan.expressions.head.children.collect({ + case f: ScalaUDF => f + }).length + } + + // Non deterministic + val foo = udf(() => Math.random()) + spark.udf.register("random0", foo.asNondeterministic()) + val plan = sql("SELECT random0()").queryExecution.optimizedPlan + assert(udfNodesCount(plan) == 1) + + // udf is deterministic and args are literal + val foo2 = udf((x: String, i: Int) => x.length + i) + spark.udf.register("mystrlen", foo2) + assert(foo2.deterministic) + assert(sql("SELECT mystrlen('abc', 1)").head().getInt(0) === 4) + val plan2 = sql("SELECT mystrlen('abc', 1)").queryExecution.optimizedPlan + assert(udfNodesCount(plan2) == 0) + } + test("UDF in a WHERE") { withTempView("integerData") { spark.udf.register("oneArgFilter", (n: Int) => { n > 80 }) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala index 6bcb2225e66d..8f6f79f1e168 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala @@ -22,6 +22,7 @@ import java.io.File import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.TestingUDT.{IntervalData, IntervalUDT} import org.apache.spark.sql.execution.datasources.orc.OrcSuite +import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.types._ @@ -124,7 +125,7 @@ class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { assert(msg.contains("ORC data source does not support null data type.")) msg = intercept[AnalysisException] { - spark.udf.register("testType", () => new IntervalData()) + spark.udf.register("testType", udf(() => new IntervalData()).asNondeterministic()) sql("select testType()").write.mode("overwrite").orc(orcDir) }.getMessage assert(msg.contains("ORC data source does not support calendarinterval data type.")) From 3e83db953befd05c5bef583831f5db6a7b8d9986 Mon Sep 17 00:00:00 2001 From: Sunitha Kambhampati Date: Thu, 9 May 2019 16:18:51 -0700 Subject: [PATCH 02/12] Add a test --- .../test/scala/org/apache/spark/sql/UDFSuite.scala | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index a2bafad78baf..5cd74ff4f5fe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -169,9 +169,18 @@ class UDFSuite extends QueryTest with SharedSQLContext { val foo2 = udf((x: String, i: Int) => x.length + i) spark.udf.register("mystrlen", foo2) assert(foo2.deterministic) - assert(sql("SELECT mystrlen('abc', 1)").head().getInt(0) === 4) + assert(sql("SELECT mystrlen('abc', 1)").head().getInt(0) == 4) val plan2 = sql("SELECT mystrlen('abc', 1)").queryExecution.optimizedPlan assert(udfNodesCount(plan2) == 0) + + // udf is deterministic and args are not literal + withTempView("temp1") { + val df = sparkContext.parallelize( + (1 to 10).map(i => i.toString)).toDF("i1") + df.createOrReplaceTempView("temp1") + val plan = sql("SELECT mystrlen(i1, 1) FROM temp1").queryExecution.optimizedPlan + assert(udfNodesCount(plan) == 1) + 
} } test("UDF in a WHERE") { From 0cfa744d4f9ab729b5d75d35136a99fed83f8670 Mon Sep 17 00:00:00 2001 From: Sunitha Kambhampati Date: Mon, 13 May 2019 11:05:33 -0700 Subject: [PATCH 03/12] Add nested testcase --- sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 5cd74ff4f5fe..5ed4bfc6664f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -172,6 +172,8 @@ class UDFSuite extends QueryTest with SharedSQLContext { assert(sql("SELECT mystrlen('abc', 1)").head().getInt(0) == 4) val plan2 = sql("SELECT mystrlen('abc', 1)").queryExecution.optimizedPlan assert(udfNodesCount(plan2) == 0) + val plan3 = sql("SELECT mystrlen('abc', mystrlen('c', 1))").queryExecution.optimizedPlan + assert(udfNodesCount(plan3) == 0) // udf is deterministic and args are not literal withTempView("temp1") { From 6d87778653db7958047f2e0623c4bcf2d741b8f2 Mon Sep 17 00:00:00 2001 From: Sunitha Kambhampati Date: Mon, 13 May 2019 14:40:32 -0700 Subject: [PATCH 04/12] fix the avro test --- .../src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 49aa21884f8b..31aff085b4dc 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -37,6 +37,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.TestingUDT.{IntervalData, NullData, NullUDT} import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ @@ -1003,7 +1004,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { assert(msg.contains("Cannot save interval data type into external storage.")) msg = intercept[AnalysisException] { - spark.udf.register("testType", () => new IntervalData()) + spark.udf.register("testType", udf(() => new IntervalData()).asNondeterministic()) sql("select testType()").write.format("avro").mode("overwrite").save(tempDir) }.getMessage assert(msg.toLowerCase(Locale.ROOT) From 76512141b65584be7661a2699d13ed65e621488c Mon Sep 17 00:00:00 2001 From: Sunitha Kambhampati Date: Tue, 14 May 2019 10:42:04 -0700 Subject: [PATCH 05/12] Fix ALSSuite tests to capture the error message that is in the Exception --- .../apache/spark/ml/recommendation/ALSSuite.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 5ba39284f63b..f71e8c1fa1de 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -231,42 +231,42 @@ class ALSSuite extends MLTest with DefaultReadWriteTest with Logging { val e: SparkException = intercept[SparkException] { df.select(checkedCast(lit(1231000000000L))).collect() } - assert(e.getMessage.contains(msg)) + 
assert(e.getCause().getMessage.contains(msg)) } withClue("Invalid Decimal: out of range") { val e: SparkException = intercept[SparkException] { df.select(checkedCast(lit(1231000000000.0).cast(DecimalType(15, 2)))).collect() } - assert(e.getMessage.contains(msg)) + assert(e.getCause().getMessage.contains(msg)) } withClue("Invalid Decimal: fractional part") { val e: SparkException = intercept[SparkException] { df.select(checkedCast(lit(123.1).cast(DecimalType(15, 2)))).collect() } - assert(e.getMessage.contains(msg)) + assert(e.getCause().getMessage.contains(msg)) } withClue("Invalid Double: out of range") { val e: SparkException = intercept[SparkException] { df.select(checkedCast(lit(1231000000000.0))).collect() } - assert(e.getMessage.contains(msg)) + assert(e.getCause().getMessage.contains(msg)) } withClue("Invalid Double: fractional part") { val e: SparkException = intercept[SparkException] { df.select(checkedCast(lit(123.1))).collect() } - assert(e.getMessage.contains(msg)) + assert(e.getCause().getMessage.contains(msg)) } withClue("Invalid Type") { val e: SparkException = intercept[SparkException] { df.select(checkedCast(lit("123.1"))).collect() } - assert(e.getMessage.contains("was not numeric")) + assert(e.getCause().getMessage.contains("was not numeric")) } } From 98154a2ef11405cc3d77717abc36fa8055a30c83 Mon Sep 17 00:00:00 2001 From: Sunitha Kambhampati Date: Thu, 16 May 2019 15:58:51 -0700 Subject: [PATCH 06/12] Add comments --- .../test/scala/org/apache/spark/sql/avro/AvroSuite.scala | 5 +++++ .../org/apache/spark/sql/FileBasedDataSourceSuite.scala | 6 ++++++ .../org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala | 6 ++++++ 3 files changed, 17 insertions(+) diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 31aff085b4dc..a70ee12f3ee8 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -1004,6 +1004,11 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { assert(msg.contains("Cannot save interval data type into external storage.")) msg = intercept[AnalysisException] { + // Here the IntervalData and corresponding UDT does not implement the serialize method + // so evaluation of udf will throw error. We are testing the error codepath for datasource + // SPARK-27692 optimizes evaluation of deterministic UDF with literal inputs and will + // evaluate the UDF in the optimizer. To bypass this optimization and to test + // the error codepath for datasource, mark the udf as non deterministic spark.udf.register("testType", udf(() => new IntervalData()).asNondeterministic()) sql("select testType()").write.format("avro").mode("overwrite").save(tempDir) }.getMessage diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index ca0191a202ad..3c63a585a7d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -399,6 +399,12 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo .contains(errorMessage(format))) msg = intercept[AnalysisException] { + // Here the NullData and corresponding UDT does not implement the serialize + // method so evaluation of udf will throw error. 
We are testing the error codepath + // for datasource. + // SPARK-27692 optimizes evaluation of deterministic UDF that has literal inputs and + // will evaluate the UDF in the optimizer. To bypass this optimization and to test + // the error codepath for datasource, mark the udf as non deterministic spark.udf.register("testType", udf(() => new NullData()).asNondeterministic()) sql("select testType()").write.format(format).mode("overwrite").save(tempDir) }.getMessage diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala index 8f6f79f1e168..f0e2f12ef5ba 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala @@ -125,6 +125,12 @@ class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { assert(msg.contains("ORC data source does not support null data type.")) msg = intercept[AnalysisException] { + // Here the IntervalData and corresponding UDT does not implement the serialize method + // so evaluation of udf will throw error. We are testing the error codepath + // for datasource. + // SPARK-27692 optimizes evaluation of deterministic UDF that has literal inputs and will + // evaluate the UDF in the optimizer. To bypass this optimization and to test + // the error codepath for datasource, mark the udf as non deterministic spark.udf.register("testType", udf(() => new IntervalData()).asNondeterministic()) sql("select testType()").write.mode("overwrite").orc(orcDir) }.getMessage From b344b6c94985cf106bbd90faefa9bd10ecb42b53 Mon Sep 17 00:00:00 2001 From: Sunitha Kambhampati Date: Thu, 16 May 2019 18:49:37 -0700 Subject: [PATCH 07/12] Added the optimization under a conf property that defaults to false --- .../org/apache/spark/sql/avro/AvroSuite.scala | 7 +-- .../spark/ml/recommendation/ALSSuite.scala | 12 ++--- .../sql/catalyst/optimizer/Optimizer.scala | 15 +++--- .../apache/spark/sql/internal/SQLConf.scala | 12 +++++ .../spark/sql/FileBasedDataSourceSuite.scala | 8 +--- .../scala/org/apache/spark/sql/UDFSuite.scala | 46 ++++++++++--------- .../sql/hive/orc/HiveOrcSourceSuite.scala | 8 +--- 7 files changed, 54 insertions(+), 54 deletions(-) diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index a70ee12f3ee8..ba824eb7ae68 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -1004,12 +1004,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { assert(msg.contains("Cannot save interval data type into external storage.")) msg = intercept[AnalysisException] { - // Here the IntervalData and corresponding UDT does not implement the serialize method - // so evaluation of udf will throw error. We are testing the error codepath for datasource - // SPARK-27692 optimizes evaluation of deterministic UDF with literal inputs and will - // evaluate the UDF in the optimizer. 
To bypass this optimization and to test - // the error codepath for datasource, mark the udf as non deterministic - spark.udf.register("testType", udf(() => new IntervalData()).asNondeterministic()) + spark.udf.register("testType", udf(() => new IntervalData())) sql("select testType()").write.format("avro").mode("overwrite").save(tempDir) }.getMessage assert(msg.toLowerCase(Locale.ROOT) diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index f71e8c1fa1de..5ba39284f63b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -231,42 +231,42 @@ class ALSSuite extends MLTest with DefaultReadWriteTest with Logging { val e: SparkException = intercept[SparkException] { df.select(checkedCast(lit(1231000000000L))).collect() } - assert(e.getCause().getMessage.contains(msg)) + assert(e.getMessage.contains(msg)) } withClue("Invalid Decimal: out of range") { val e: SparkException = intercept[SparkException] { df.select(checkedCast(lit(1231000000000.0).cast(DecimalType(15, 2)))).collect() } - assert(e.getCause().getMessage.contains(msg)) + assert(e.getMessage.contains(msg)) } withClue("Invalid Decimal: fractional part") { val e: SparkException = intercept[SparkException] { df.select(checkedCast(lit(123.1).cast(DecimalType(15, 2)))).collect() } - assert(e.getCause().getMessage.contains(msg)) + assert(e.getMessage.contains(msg)) } withClue("Invalid Double: out of range") { val e: SparkException = intercept[SparkException] { df.select(checkedCast(lit(1231000000000.0))).collect() } - assert(e.getCause().getMessage.contains(msg)) + assert(e.getMessage.contains(msg)) } withClue("Invalid Double: fractional part") { val e: SparkException = intercept[SparkException] { df.select(checkedCast(lit(123.1))).collect() } - assert(e.getCause().getMessage.contains(msg)) + assert(e.getMessage.contains(msg)) } withClue("Invalid Type") { val e: SparkException = intercept[SparkException] { df.select(checkedCast(lit("123.1"))).collect() } - assert(e.getCause().getMessage.contains("was not numeric")) + assert(e.getMessage.contains("was not numeric")) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index b3324d7fc2ca..74b6e43b8800 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1728,11 +1728,14 @@ object OptimizeLimitZero extends Rule[LogicalPlan] { * with the output of the udf serialized */ object DeterministicLiteralUDF extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case udf@ScalaUDF(_, dataType, children, _, _, _, _, udfDeterministic) - if udf.deterministic && children.forall(_.isInstanceOf[Literal]) => { - val res = udf.eval(null) - Literal(res, dataType) + def apply(plan: LogicalPlan): LogicalPlan = + if (!SQLConf.get.deterministicUdfFoldEnabled) { + plan + } else plan transformAllExpressions { + case udf @ ScalaUDF(_, dataType, children, _, _, _, _, udfDeterministic) + if udf.deterministic && children.forall(_.isInstanceOf[Literal]) => { + val res = udf.eval(null) + Literal(res, dataType) + } } - } } diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 71c830207701..16e59b3c3ac0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -150,6 +150,16 @@ object SQLConf { } } + val UDF_DETERMINISTIC_FOLD_ENABLED = buildConf("spark.deterministic.udf.folding.enabled") + .doc("When true, it will enable the optimization for a UDF that is deterministic and the " + + "inputs are all literals. When your inputs to the UDF are all literal and UDF is " + + "deterministic, we can optimize this to evaluate the UDF once and use the output " + + "instead of evaluating the UDF each time for every row in the query." + + "Ensure that your UDFs are correctly setup with respect to whether they are " + + "deterministic or not, before enabling this.") + .booleanConf + .createWithDefault(false) + val OPTIMIZER_EXCLUDED_RULES = buildConf("spark.sql.optimizer.excludedRules") .doc("Configures a list of rules to be disabled in the optimizer, in which the rules are " + "specified by their rule names and separated by comma. It is not guaranteed that all the " + @@ -2139,6 +2149,8 @@ class SQLConf extends Serializable with Logging { def maxNestedViewDepth: Int = getConf(SQLConf.MAX_NESTED_VIEW_DEPTH) + def deterministicUdfFoldEnabled: Boolean = getConf(UDF_DETERMINISTIC_FOLD_ENABLED) + def starSchemaDetection: Boolean = getConf(STARSCHEMA_DETECTION) def starSchemaFTRatio: Double = getConf(STARSCHEMA_FACT_TABLE_RATIO) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 3c63a585a7d2..4ed24bca3fce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -399,13 +399,7 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo .contains(errorMessage(format))) msg = intercept[AnalysisException] { - // Here the NullData and corresponding UDT does not implement the serialize - // method so evaluation of udf will throw error. We are testing the error codepath - // for datasource. - // SPARK-27692 optimizes evaluation of deterministic UDF that has literal inputs and - // will evaluate the UDF in the optimizer. 
To bypass this optimization and to test - // the error codepath for datasource, mark the udf as non deterministic - spark.udf.register("testType", udf(() => new NullData()).asNondeterministic()) + spark.udf.register("testType", udf(() => new NullData())) sql("select testType()").write.format(format).mode("overwrite").save(tempDir) }.getMessage assert(msg.toLowerCase(Locale.ROOT) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 5ed4bfc6664f..816149762e3e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -159,29 +159,31 @@ class UDFSuite extends QueryTest with SharedSQLContext { }).length } - // Non deterministic - val foo = udf(() => Math.random()) - spark.udf.register("random0", foo.asNondeterministic()) - val plan = sql("SELECT random0()").queryExecution.optimizedPlan - assert(udfNodesCount(plan) == 1) - - // udf is deterministic and args are literal - val foo2 = udf((x: String, i: Int) => x.length + i) - spark.udf.register("mystrlen", foo2) - assert(foo2.deterministic) - assert(sql("SELECT mystrlen('abc', 1)").head().getInt(0) == 4) - val plan2 = sql("SELECT mystrlen('abc', 1)").queryExecution.optimizedPlan - assert(udfNodesCount(plan2) == 0) - val plan3 = sql("SELECT mystrlen('abc', mystrlen('c', 1))").queryExecution.optimizedPlan - assert(udfNodesCount(plan3) == 0) - - // udf is deterministic and args are not literal - withTempView("temp1") { - val df = sparkContext.parallelize( - (1 to 10).map(i => i.toString)).toDF("i1") - df.createOrReplaceTempView("temp1") - val plan = sql("SELECT mystrlen(i1, 1) FROM temp1").queryExecution.optimizedPlan + withSQLConf(SQLConf.UDF_DETERMINISTIC_FOLD_ENABLED.key -> "true") { + // Non deterministic + val foo = udf(() => Math.random()) + spark.udf.register("random0", foo.asNondeterministic()) + val plan = sql("SELECT random0()").queryExecution.optimizedPlan assert(udfNodesCount(plan) == 1) + + // udf is deterministic and args are literal + val foo2 = udf((x: String, i: Int) => x.length + i) + spark.udf.register("mystrlen", foo2) + assert(foo2.deterministic) + assert(sql("SELECT mystrlen('abc', 1)").head().getInt(0) == 4) + val plan2 = sql("SELECT mystrlen('abc', 1)").queryExecution.optimizedPlan + assert(udfNodesCount(plan2) == 0) + val plan3 = sql("SELECT mystrlen('abc', mystrlen('c', 1))").queryExecution.optimizedPlan + assert(udfNodesCount(plan3) == 0) + + // udf is deterministic and args are not literal + withTempView("temp1") { + val df = sparkContext.parallelize( + (1 to 10).map(i => i.toString)).toDF("i1") + df.createOrReplaceTempView("temp1") + val plan = sql("SELECT mystrlen(i1, 1) FROM temp1").queryExecution.optimizedPlan + assert(udfNodesCount(plan) == 1) + } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala index f0e2f12ef5ba..8888e4873bae 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala @@ -125,13 +125,7 @@ class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { assert(msg.contains("ORC data source does not support null data type.")) msg = intercept[AnalysisException] { - // Here the IntervalData and corresponding UDT does not implement the serialize method - // so evaluation of udf will throw 
error. We are testing the error codepath - // for datasource. - // SPARK-27692 optimizes evaluation of deterministic UDF that has literal inputs and will - // evaluate the UDF in the optimizer. To bypass this optimization and to test - // the error codepath for datasource, mark the udf as non deterministic - spark.udf.register("testType", udf(() => new IntervalData()).asNondeterministic()) + spark.udf.register("testType", udf(() => new IntervalData())) sql("select testType()").write.mode("overwrite").orc(orcDir) }.getMessage assert(msg.contains("ORC data source does not support calendarinterval data type.")) From b348cd5833018a5beae90f702eb7592f166388e9 Mon Sep 17 00:00:00 2001 From: Sunitha Kambhampati Date: Thu, 16 May 2019 19:39:16 -0700 Subject: [PATCH 08/12] small changes to config name and remove diffs to other tests --- .../src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala | 3 +-- .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 4 ++-- .../scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala | 2 +- sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala | 2 +- .../org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala | 3 +-- 5 files changed, 6 insertions(+), 8 deletions(-) diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index ba824eb7ae68..49aa21884f8b 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -37,7 +37,6 @@ import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.TestingUDT.{IntervalData, NullData, NullUDT} import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ @@ -1004,7 +1003,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { assert(msg.contains("Cannot save interval data type into external storage.")) msg = intercept[AnalysisException] { - spark.udf.register("testType", udf(() => new IntervalData())) + spark.udf.register("testType", () => new IntervalData()) sql("select testType()").write.format("avro").mode("overwrite").save(tempDir) }.getMessage assert(msg.toLowerCase(Locale.ROOT) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 16e59b3c3ac0..c64a1e35c40f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -150,7 +150,7 @@ object SQLConf { } } - val UDF_DETERMINISTIC_FOLD_ENABLED = buildConf("spark.deterministic.udf.folding.enabled") + val DETERMINISTIC_UDF_FOLD_ENABLED = buildConf("spark.deterministic.udf.folding.enabled") .doc("When true, it will enable the optimization for a UDF that is deterministic and the " + "inputs are all literals. 
When your inputs to the UDF are all literal and UDF is " + "deterministic, we can optimize this to evaluate the UDF once and use the output " + @@ -2149,7 +2149,7 @@ class SQLConf extends Serializable with Logging { def maxNestedViewDepth: Int = getConf(SQLConf.MAX_NESTED_VIEW_DEPTH) - def deterministicUdfFoldEnabled: Boolean = getConf(UDF_DETERMINISTIC_FOLD_ENABLED) + def deterministicUdfFoldEnabled: Boolean = getConf(DETERMINISTIC_UDF_FOLD_ENABLED) def starSchemaDetection: Boolean = getConf(STARSCHEMA_DETECTION) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 4ed24bca3fce..dd11b5c50398 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -399,7 +399,7 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo .contains(errorMessage(format))) msg = intercept[AnalysisException] { - spark.udf.register("testType", udf(() => new NullData())) + spark.udf.register("testType", () => new NullData()) sql("select testType()").write.format(format).mode("overwrite").save(tempDir) }.getMessage assert(msg.toLowerCase(Locale.ROOT) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 816149762e3e..216a20dee1a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -159,7 +159,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { }).length } - withSQLConf(SQLConf.UDF_DETERMINISTIC_FOLD_ENABLED.key -> "true") { + withSQLConf(SQLConf.DETERMINISTIC_UDF_FOLD_ENABLED.key -> "true") { // Non deterministic val foo = udf(() => Math.random()) spark.udf.register("random0", foo.asNondeterministic()) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala index 8888e4873bae..6bcb2225e66d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala @@ -22,7 +22,6 @@ import java.io.File import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.TestingUDT.{IntervalData, IntervalUDT} import org.apache.spark.sql.execution.datasources.orc.OrcSuite -import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.types._ @@ -125,7 +124,7 @@ class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { assert(msg.contains("ORC data source does not support null data type.")) msg = intercept[AnalysisException] { - spark.udf.register("testType", udf(() => new IntervalData())) + spark.udf.register("testType", () => new IntervalData()) sql("select testType()").write.mode("overwrite").orc(orcDir) }.getMessage assert(msg.contains("ORC data source does not support calendarinterval data type.")) From 7d54787e5d0f89f3fb715a40a0d9e49d7681d933 Mon Sep 17 00:00:00 2001 From: Sunitha Kambhampati Date: Mon, 20 May 2019 16:03:47 -0700 Subject: [PATCH 09/12] Move rule and test to new file respectively - review comment --- .../optimizer/DeterministicLiteralUDF.scala | 40 ++++++++++++ .../sql/catalyst/optimizer/Optimizer.scala | 17 ----- 
.../spark/sql/UDFOptimizationsSuite.scala | 63 +++++++++++++++++++ .../scala/org/apache/spark/sql/UDFSuite.scala | 38 +---------- 4 files changed, 104 insertions(+), 54 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DeterministicLiteralUDF.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/UDFOptimizationsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DeterministicLiteralUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DeterministicLiteralUDF.scala new file mode 100644 index 000000000000..f2aa27dd33db --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DeterministicLiteralUDF.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions.{Literal, ScalaUDF} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf + +/** + * If the UDF is deterministic and if the children are all literal, we can replace the udf + * with the output of the udf serialized + */ +object DeterministicLiteralUDF extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = + if (!SQLConf.get.deterministicUdfFoldEnabled) { + plan + } else plan transformAllExpressions { + case udf @ ScalaUDF(_, dataType, children, _, _, _, _, udfDeterministic) + if udf.deterministic && children.forall(_.isInstanceOf[Literal]) => { + val res = udf.eval(null) + Literal(res, dataType) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 74b6e43b8800..e1e22e5e9ff9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1722,20 +1722,3 @@ object OptimizeLimitZero extends Rule[LogicalPlan] { empty(ll) } } - -/** - * If the UDF is deterministic and if the children are all literal, we can replace the udf - * with the output of the udf serialized - */ -object DeterministicLiteralUDF extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = - if (!SQLConf.get.deterministicUdfFoldEnabled) { - plan - } else plan transformAllExpressions { - case udf @ ScalaUDF(_, dataType, children, _, _, _, _, udfDeterministic) - if udf.deterministic && children.forall(_.isInstanceOf[Literal]) => { - val res = udf.eval(null) - Literal(res, dataType) - } - } -} diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/UDFOptimizationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFOptimizationsSuite.scala new file mode 100644 index 000000000000..8865311248cb --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFOptimizationsSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.expressions.ScalaUDF +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.functions.udf +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext + +class UDFOptimizationsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("Deterministic and literal UDF optimization") { + def udfNodesCount(plan: LogicalPlan): Int = { + plan.expressions.head.children.collect({ + case f: ScalaUDF => f + }).length + } + + withSQLConf(SQLConf.DETERMINISTIC_UDF_FOLD_ENABLED.key -> "true") { + // Non deterministic + val foo = udf(() => Math.random()) + spark.udf.register("random0", foo.asNondeterministic()) + val plan = sql("SELECT random0()").queryExecution.optimizedPlan + assert(udfNodesCount(plan) == 1) + + // udf is deterministic and args are literal + val foo2 = udf((x: String, i: Int) => x.length + i) + spark.udf.register("mystrlen", foo2) + assert(foo2.deterministic) + assert(sql("SELECT mystrlen('abc', 1)").head().getInt(0) == 4) + val plan2 = sql("SELECT mystrlen('abc', 1)").queryExecution.optimizedPlan + assert(udfNodesCount(plan2) == 0) + val plan3 = sql("SELECT mystrlen('abc', mystrlen('c', 1))").queryExecution.optimizedPlan + assert(udfNodesCount(plan3) == 0) + + // udf is deterministic and args are not literal + withTempView("temp1") { + val df = sparkContext.parallelize( + (1 to 10).map(i => i.toString)).toDF("i1") + df.createOrReplaceTempView("temp1") + val plan = sql("SELECT mystrlen(i1, 1) FROM temp1").queryExecution.optimizedPlan + assert(udfNodesCount(plan) == 1) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 216a20dee1a7..f2a71bd628bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -20,8 +20,7 @@ package org.apache.spark.sql import java.math.BigDecimal import org.apache.spark.sql.api.java._ -import org.apache.spark.sql.catalyst.expressions.ScalaUDF -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.columnar.InMemoryRelation import 
org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, ExplainCommand} @@ -152,41 +151,6 @@ class UDFSuite extends QueryTest with SharedSQLContext { assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) } - test("Deterministic and literal UDF optimization") { - def udfNodesCount(plan: LogicalPlan): Int = { - plan.expressions.head.children.collect({ - case f: ScalaUDF => f - }).length - } - - withSQLConf(SQLConf.DETERMINISTIC_UDF_FOLD_ENABLED.key -> "true") { - // Non deterministic - val foo = udf(() => Math.random()) - spark.udf.register("random0", foo.asNondeterministic()) - val plan = sql("SELECT random0()").queryExecution.optimizedPlan - assert(udfNodesCount(plan) == 1) - - // udf is deterministic and args are literal - val foo2 = udf((x: String, i: Int) => x.length + i) - spark.udf.register("mystrlen", foo2) - assert(foo2.deterministic) - assert(sql("SELECT mystrlen('abc', 1)").head().getInt(0) == 4) - val plan2 = sql("SELECT mystrlen('abc', 1)").queryExecution.optimizedPlan - assert(udfNodesCount(plan2) == 0) - val plan3 = sql("SELECT mystrlen('abc', mystrlen('c', 1))").queryExecution.optimizedPlan - assert(udfNodesCount(plan3) == 0) - - // udf is deterministic and args are not literal - withTempView("temp1") { - val df = sparkContext.parallelize( - (1 to 10).map(i => i.toString)).toDF("i1") - df.createOrReplaceTempView("temp1") - val plan = sql("SELECT mystrlen(i1, 1) FROM temp1").queryExecution.optimizedPlan - assert(udfNodesCount(plan) == 1) - } - } - } - test("UDF in a WHERE") { withTempView("integerData") { spark.udf.register("oneArgFilter", (n: Int) => { n > 80 }) From ee5fa4ed49ad26a4f0a84034c14621b266f2a3f1 Mon Sep 17 00:00:00 2001 From: Sunitha Kambhampati Date: Tue, 28 May 2019 14:39:07 -0700 Subject: [PATCH 10/12] rename --- .../optimizer/DeterministicLiteralUDF.scala | 2 +- .../apache/spark/sql/internal/SQLConf.scala | 26 ++++++++++--------- ...te.scala => DeterministicLiteralUDF.scala} | 4 +-- 3 files changed, 17 insertions(+), 15 deletions(-) rename sql/core/src/test/scala/org/apache/spark/sql/{UDFOptimizationsSuite.scala => DeterministicLiteralUDF.scala} (93%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DeterministicLiteralUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DeterministicLiteralUDF.scala index f2aa27dd33db..bfeac2c1cecd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DeterministicLiteralUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DeterministicLiteralUDF.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.internal.SQLConf */ object DeterministicLiteralUDF extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = - if (!SQLConf.get.deterministicUdfFoldEnabled) { + if (!SQLConf.get.deterministicLiteralUdfFoldEnabled) { plan } else plan transformAllExpressions { case udf @ ScalaUDF(_, dataType, children, _, _, _, _, udfDeterministic) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index c64a1e35c40f..ea100e668531 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -150,16 +150,6 @@ object SQLConf { } } - val DETERMINISTIC_UDF_FOLD_ENABLED = buildConf("spark.deterministic.udf.folding.enabled") - .doc("When true, it will enable the optimization for a UDF 
that is deterministic and the " + - "inputs are all literals. When your inputs to the UDF are all literal and UDF is " + - "deterministic, we can optimize this to evaluate the UDF once and use the output " + - "instead of evaluating the UDF each time for every row in the query." + - "Ensure that your UDFs are correctly setup with respect to whether they are " + - "deterministic or not, before enabling this.") - .booleanConf - .createWithDefault(false) - val OPTIMIZER_EXCLUDED_RULES = buildConf("spark.sql.optimizer.excludedRules") .doc("Configures a list of rules to be disabled in the optimizer, in which the rules are " + "specified by their rule names and separated by comma. It is not guaranteed that all the " + @@ -1787,6 +1777,17 @@ object SQLConf { .doc("When true, the upcast will be loose and allows string to atomic types.") .booleanConf .createWithDefault(false) + + val DETERMINISTIC_LITERAL_UDF_FOLD_ENABLED = + buildConf("spark.sql.deterministic.literal.udf.folding.enabled") + .doc("When true, it will enable the optimization for a UDF that is deterministic and the " + + "inputs are all literals. When your inputs to the UDF are all literal and UDF is " + + "deterministic, we can optimize this to evaluate the UDF once and use the output " + + "instead of evaluating the UDF each time for every row in the query." + + "Ensure that your UDFs are correctly setup with respect to whether they are " + + "deterministic or not, before enabling this.") + .booleanConf + .createWithDefault(false) } /** @@ -2149,8 +2150,6 @@ class SQLConf extends Serializable with Logging { def maxNestedViewDepth: Int = getConf(SQLConf.MAX_NESTED_VIEW_DEPTH) - def deterministicUdfFoldEnabled: Boolean = getConf(DETERMINISTIC_UDF_FOLD_ENABLED) - def starSchemaDetection: Boolean = getConf(STARSCHEMA_DETECTION) def starSchemaFTRatio: Double = getConf(STARSCHEMA_FACT_TABLE_RATIO) @@ -2247,6 +2246,9 @@ class SQLConf extends Serializable with Logging { def defaultV2Catalog: Option[String] = getConf(DEFAULT_V2_CATALOG) + def deterministicLiteralUdfFoldEnabled: Boolean = getConf(DETERMINISTIC_LITERAL_UDF_FOLD_ENABLED) + + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. 
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFOptimizationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DeterministicLiteralUDF.scala similarity index 93% rename from sql/core/src/test/scala/org/apache/spark/sql/UDFOptimizationsSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/DeterministicLiteralUDF.scala index 8865311248cb..7af53d8c5b9e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFOptimizationsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DeterministicLiteralUDF.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.functions.udf import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext -class UDFOptimizationsSuite extends QueryTest with SharedSQLContext { +class DeterministicLiteralUDFSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("Deterministic and literal UDF optimization") { @@ -33,7 +33,7 @@ class UDFOptimizationsSuite extends QueryTest with SharedSQLContext { }).length } - withSQLConf(SQLConf.DETERMINISTIC_UDF_FOLD_ENABLED.key -> "true") { + withSQLConf(SQLConf.DETERMINISTIC_LITERAL_UDF_FOLD_ENABLED.key -> "true") { // Non deterministic val foo = udf(() => Math.random()) spark.udf.register("random0", foo.asNondeterministic()) From 7b98cad8fead165dd4c7d8ee64a4ba0d1e4826d3 Mon Sep 17 00:00:00 2001 From: Sunitha Kambhampati Date: Thu, 30 May 2019 14:36:48 -0700 Subject: [PATCH 11/12] renames.. and add more tests --- ...a => DeterministicLiteralUDFFolding.scala} | 9 ++- .../sql/catalyst/optimizer/Optimizer.scala | 2 +- .../apache/spark/sql/internal/SQLConf.scala | 5 +- .../spark/sql/DeterministicLiteralUDF.scala | 63 ----------------- .../DeterministicLiteralUDFFoldingSuite.scala | 67 +++++++++++++++++++ 5 files changed, 75 insertions(+), 71 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/{DeterministicLiteralUDF.scala => DeterministicLiteralUDFFolding.scala} (82%) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DeterministicLiteralUDF.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DeterministicLiteralUDFFoldingSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DeterministicLiteralUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DeterministicLiteralUDFFolding.scala similarity index 82% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DeterministicLiteralUDF.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DeterministicLiteralUDFFolding.scala index bfeac2c1cecd..0bb08da8ba9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DeterministicLiteralUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DeterministicLiteralUDFFolding.scala @@ -26,15 +26,14 @@ import org.apache.spark.sql.internal.SQLConf * If the UDF is deterministic and if the children are all literal, we can replace the udf * with the output of the udf serialized */ -object DeterministicLiteralUDF extends Rule[LogicalPlan] { +object DeterministicLiteralUDFFolding extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = - if (!SQLConf.get.deterministicLiteralUdfFoldEnabled) { + if (!SQLConf.get.deterministicLiteralUdfFoldingEnabled) { plan } else plan transformAllExpressions { - case udf @ ScalaUDF(_, dataType, children, _, _, _, _, udfDeterministic) - if udf.deterministic && 
children.forall(_.isInstanceOf[Literal]) => { + case udf @ ScalaUDF(_, dataType, children, _, _, _, _, _) + if udf.deterministic && children.forall(_.isInstanceOf[Literal]) => val res = udf.eval(null) Literal(res, dataType) - } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index e1e22e5e9ff9..db0f35c02e15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -78,7 +78,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) CombineLimits, CombineUnions, // Constant folding and strength reduction - DeterministicLiteralUDF, + DeterministicLiteralUDFFolding, TransposeWindow, NullPropagation, ConstantPropagation, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index ea100e668531..7a1bb66e436d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1778,7 +1778,7 @@ object SQLConf { .booleanConf .createWithDefault(false) - val DETERMINISTIC_LITERAL_UDF_FOLD_ENABLED = + val DETERMINISTIC_LITERAL_UDF_FOLDING_ENABLED = buildConf("spark.sql.deterministic.literal.udf.folding.enabled") .doc("When true, it will enable the optimization for a UDF that is deterministic and the " + "inputs are all literals. When your inputs to the UDF are all literal and UDF is " + @@ -2246,7 +2246,8 @@ class SQLConf extends Serializable with Logging { def defaultV2Catalog: Option[String] = getConf(DEFAULT_V2_CATALOG) - def deterministicLiteralUdfFoldEnabled: Boolean = getConf(DETERMINISTIC_LITERAL_UDF_FOLD_ENABLED) + def deterministicLiteralUdfFoldingEnabled: Boolean = + getConf(DETERMINISTIC_LITERAL_UDF_FOLDING_ENABLED) /** ********************** SQLConf functionality methods ************ */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DeterministicLiteralUDF.scala b/sql/core/src/test/scala/org/apache/spark/sql/DeterministicLiteralUDF.scala deleted file mode 100644 index 7af53d8c5b9e..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/DeterministicLiteralUDF.scala +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql - -import org.apache.spark.sql.catalyst.expressions.ScalaUDF -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.functions.udf -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SharedSQLContext - -class DeterministicLiteralUDFSuite extends QueryTest with SharedSQLContext { - import testImplicits._ - - test("Deterministic and literal UDF optimization") { - def udfNodesCount(plan: LogicalPlan): Int = { - plan.expressions.head.children.collect({ - case f: ScalaUDF => f - }).length - } - - withSQLConf(SQLConf.DETERMINISTIC_LITERAL_UDF_FOLD_ENABLED.key -> "true") { - // Non deterministic - val foo = udf(() => Math.random()) - spark.udf.register("random0", foo.asNondeterministic()) - val plan = sql("SELECT random0()").queryExecution.optimizedPlan - assert(udfNodesCount(plan) == 1) - - // udf is deterministic and args are literal - val foo2 = udf((x: String, i: Int) => x.length + i) - spark.udf.register("mystrlen", foo2) - assert(foo2.deterministic) - assert(sql("SELECT mystrlen('abc', 1)").head().getInt(0) == 4) - val plan2 = sql("SELECT mystrlen('abc', 1)").queryExecution.optimizedPlan - assert(udfNodesCount(plan2) == 0) - val plan3 = sql("SELECT mystrlen('abc', mystrlen('c', 1))").queryExecution.optimizedPlan - assert(udfNodesCount(plan3) == 0) - - // udf is deterministic and args are not literal - withTempView("temp1") { - val df = sparkContext.parallelize( - (1 to 10).map(i => i.toString)).toDF("i1") - df.createOrReplaceTempView("temp1") - val plan = sql("SELECT mystrlen(i1, 1) FROM temp1").queryExecution.optimizedPlan - assert(udfNodesCount(plan) == 1) - } - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DeterministicLiteralUDFFoldingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DeterministicLiteralUDFFoldingSuite.scala new file mode 100644 index 000000000000..65e6c4cf6ed4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DeterministicLiteralUDFFoldingSuite.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.expressions.ScalaUDF +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.functions.udf +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext + +class DeterministicLiteralUDFFoldingSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("Deterministic and literal UDF optimization") { + def udfNodesCount(plan: LogicalPlan): Int = { + plan.expressions.head.children.collect({ + case f: ScalaUDF => f + }).length + } + + val foo = udf(() => Math.random()).asNondeterministic() + spark.udf.register("random0", foo) + assert(!foo.deterministic) + val foo2 = udf((x: String, i: Int) => x.length + i) + spark.udf.register("mystrlen", foo2) + assert(foo2.deterministic) + + Seq(("true", (1, 0, 0, 1)), ("false", (1, 1, 1, 1))).foreach { case (flag, expectedCounts) => + withSQLConf(SQLConf.DETERMINISTIC_LITERAL_UDF_FOLDING_ENABLED.key -> flag) { + // Non deterministic + val plan = sql("SELECT random0()").queryExecution.optimizedPlan + assert(udfNodesCount(plan) == expectedCounts._1) + + // udf is deterministic and args are literal + assert(sql("SELECT mystrlen('abc', 1)").head().getInt(0) == 4) + val plan2 = sql("SELECT mystrlen('abc', 1)").queryExecution.optimizedPlan + assert(udfNodesCount(plan2) == expectedCounts._2) + val plan3 = sql("SELECT mystrlen('abc', mystrlen('c', 1))").queryExecution.optimizedPlan + assert(udfNodesCount(plan3) == expectedCounts._3) + + // udf is deterministic and args are not literal + withTempView("temp1") { + val df = sparkContext.parallelize( + (1 to 10).map(i => i.toString)).toDF("i1") + df.createOrReplaceTempView("temp1") + val plan = sql("SELECT mystrlen(i1, 1) FROM temp1").queryExecution.optimizedPlan + assert(udfNodesCount(plan) == expectedCounts._4) + } + } + } + } +} From 93241b30eeb071d575142b26db36c40cad5b93b6 Mon Sep 17 00:00:00 2001 From: Sunitha Kambhampati Date: Tue, 11 Jun 2019 16:02:37 -0700 Subject: [PATCH 12/12] add join testcase --- .../DeterministicLiteralUDFFoldingSuite.scala | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DeterministicLiteralUDFFoldingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DeterministicLiteralUDFFoldingSuite.scala index 65e6c4cf6ed4..d9a3ed024792 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DeterministicLiteralUDFFoldingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DeterministicLiteralUDFFoldingSuite.scala @@ -64,4 +64,29 @@ class DeterministicLiteralUDFFoldingSuite extends QueryTest with SharedSQLContex } } } + + test("udf folding rule in join") { + withTempView("temp1") { + val df = sparkContext.parallelize((1 to 5).map(i => i.toString)).toDF("i1") + df.createOrReplaceTempView("temp1") + val foo = udf((x: String, i: Int) => x.length + i) + spark.udf.register("mystrlen1", foo) + assert(foo.deterministic) + + val query = "SELECT mystrlen1(i1, 1) FROM temp1, " + + "(SELECT mystrlen1('abc', mystrlen1('c', 1)) AS ref) WHERE mystrlen1(i1, ref) > 1" + assert(sql(query).count() == 5) + + withSQLConf(SQLConf.DETERMINISTIC_LITERAL_UDF_FOLDING_ENABLED.key -> "true") { + val exception = intercept[AnalysisException] { + sql(query).count() + } + assert(exception.message.startsWith("Detected implicit cartesian product")) + + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + assert(sql(query).count() == 5) + } + } + } + } }
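
Note (illustration only, not part of the patch series): the sketch below shows the user-visible effect of the DeterministicLiteralUDFFolding rule added above, mirroring the plan checks in DeterministicLiteralUDFFoldingSuite. It assumes a Spark build with this series applied, so that the conf key "spark.sql.deterministic.literal.udf.folding.enabled" exists and defaults to false; the object name DeterministicLiteralUDFFoldingDemo, the helper udfCount, and the registered UDF name strLenPlus are invented for the example and do not appear in the patches.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.ScalaUDF
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.functions.udf

object DeterministicLiteralUDFFoldingDemo {
  // Count the ScalaUDF expressions that survive optimization.
  private def udfCount(plan: LogicalPlan): Int =
    plan.expressions.map(_.collect { case u: ScalaUDF => u }.size).sum

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("DeterministicLiteralUDFFoldingDemo")
      .getOrCreate()

    // A deterministic UDF whose arguments are all literals is a folding candidate.
    val strLenPlus = udf((s: String, i: Int) => s.length + i)
    spark.udf.register("strLenPlus", strLenPlus)

    def optimizedUdfCount(query: String): Int =
      udfCount(spark.sql(query).queryExecution.optimizedPlan)

    // Folding disabled (the default in this series): the ScalaUDF node stays in the plan.
    spark.conf.set("spark.sql.deterministic.literal.udf.folding.enabled", "false")
    val offCount = optimizedUdfCount("SELECT strLenPlus('abc', 1)")
    println(s"folding off: $offCount ScalaUDF node(s) in the optimized plan")  // expected: 1

    // Folding enabled: the optimizer evaluates the call once and inlines Literal(4),
    // so no ScalaUDF node is left for the executors to run per row.
    spark.conf.set("spark.sql.deterministic.literal.udf.folding.enabled", "true")
    val onCount = optimizedUdfCount("SELECT strLenPlus('abc', 1)")
    println(s"folding on:  $onCount ScalaUDF node(s) in the optimized plan")   // expected: 0

    spark.stop()
  }
}

With the flag off, the deterministic call over literals is re-evaluated for every row; with it on, the rule replaces it in the optimized plan with the literal result, which is the same behavior the suite asserts via its udfNodesCount checks.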