diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DeterministicLiteralUDFFolding.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DeterministicLiteralUDFFolding.scala
new file mode 100644
index 000000000000..0bb08da8ba9f
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DeterministicLiteralUDFFolding.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions.{Literal, ScalaUDF}
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * If a ScalaUDF is deterministic and all of its children are literals, replace the UDF
+ * call with a literal holding the result of evaluating it once at optimization time.
+ */
+object DeterministicLiteralUDFFolding extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan =
+    if (!SQLConf.get.deterministicLiteralUdfFoldingEnabled) {
+      plan
+    } else plan transformAllExpressions {
+      case udf @ ScalaUDF(_, dataType, children, _, _, _, _, _)
+          if udf.deterministic && children.forall(_.isInstanceOf[Literal]) =>
+        // All arguments are literals, so evaluating against a null input row is safe.
+        val res = udf.eval(null)
+        Literal(res, dataType)
+    }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 5b59ac7d2a9b..db0f35c02e15 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -78,6 +78,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
       CombineLimits,
       CombineUnions,
       // Constant folding and strength reduction
+      DeterministicLiteralUDFFolding,
       TransposeWindow,
       NullPropagation,
       ConstantPropagation,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 71c830207701..7a1bb66e436d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1777,6 +1777,15 @@ object SQLConf {
       .doc("When true, the upcast will be loose and allows string to atomic types.")
       .booleanConf
       .createWithDefault(false)
+
+  val DETERMINISTIC_LITERAL_UDF_FOLDING_ENABLED =
+    buildConf("spark.sql.deterministic.literal.udf.folding.enabled")
+      .doc("When true, a deterministic UDF whose inputs are all literals is evaluated " +
+        "once during optimization and replaced with a literal holding its result, " +
+        "instead of being invoked for every row in the query. Ensure that your UDFs " +
+        "are correctly marked as deterministic or non-deterministic before enabling this.")
+      .booleanConf
+      .createWithDefault(false)
 }
 
 /**
@@ -2235,6 +2244,9 @@ class SQLConf extends Serializable with Logging {
 
   def defaultV2Catalog: Option[String] = getConf(DEFAULT_V2_CATALOG)
 
+  def deterministicLiteralUdfFoldingEnabled: Boolean =
+    getConf(DETERMINISTIC_LITERAL_UDF_FOLDING_ENABLED)
+
   /** ********************** SQLConf functionality methods ************ */
 
   /** Set Spark SQL configuration properties. */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DeterministicLiteralUDFFoldingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DeterministicLiteralUDFFoldingSuite.scala
new file mode 100644
index 000000000000..d9a3ed024792
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DeterministicLiteralUDFFoldingSuite.scala
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.expressions.ScalaUDF
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.functions.udf
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSQLContext
+
+class DeterministicLiteralUDFFoldingSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
+
+  test("Deterministic and literal UDF optimization") {
+    // Count ScalaUDF nodes among the immediate children of the first output expression.
+    // Nested UDF calls below that level are intentionally not counted.
+    def udfNodesCount(plan: LogicalPlan): Int = {
+      plan.expressions.head.children.collect {
+        case f: ScalaUDF => f
+      }.length
+    }
+
+    val foo = udf(() => Math.random()).asNondeterministic()
+    spark.udf.register("random0", foo)
+    assert(!foo.deterministic)
+    val foo2 = udf((x: String, i: Int) => x.length + i)
+    spark.udf.register("mystrlen", foo2)
+    assert(foo2.deterministic)
+
+    // expectedCounts: UDF nodes left in the optimized plan for the four cases below.
+    Seq(("true", (1, 0, 0, 1)), ("false", (1, 1, 1, 1))).foreach { case (flag, expectedCounts) =>
+      withSQLConf(SQLConf.DETERMINISTIC_LITERAL_UDF_FOLDING_ENABLED.key -> flag) {
+        // Non-deterministic UDFs are never folded.
+        val plan = sql("SELECT random0()").queryExecution.optimizedPlan
+        assert(udfNodesCount(plan) == expectedCounts._1)
+
+        // UDF is deterministic and the arguments are literals.
+        assert(sql("SELECT mystrlen('abc', 1)").head().getInt(0) == 4)
+        val plan2 = sql("SELECT mystrlen('abc', 1)").queryExecution.optimizedPlan
+        assert(udfNodesCount(plan2) == expectedCounts._2)
+        val plan3 = sql("SELECT mystrlen('abc', mystrlen('c', 1))").queryExecution.optimizedPlan
+        assert(udfNodesCount(plan3) == expectedCounts._3)
+
+        // UDF is deterministic but the arguments are not literals.
+        withTempView("temp1") {
+          val df = sparkContext.parallelize(
+            (1 to 10).map(i => i.toString)).toDF("i1")
+          df.createOrReplaceTempView("temp1")
+          val plan = sql("SELECT mystrlen(i1, 1) FROM temp1").queryExecution.optimizedPlan
+          assert(udfNodesCount(plan) == expectedCounts._4)
+        }
+      }
+    }
+  }
+
+  test("udf folding rule in join") {
+    withTempView("temp1") {
+      val df = sparkContext.parallelize((1 to 5).map(i => i.toString)).toDF("i1")
+      df.createOrReplaceTempView("temp1")
+      val foo = udf((x: String, i: Int) => x.length + i)
+      spark.udf.register("mystrlen1", foo)
+      assert(foo.deterministic)
+
+      val query = "SELECT mystrlen1(i1, 1) FROM temp1, " +
+        "(SELECT mystrlen1('abc', mystrlen1('c', 1)) AS ref) WHERE mystrlen1(i1, ref) > 1"
+      assert(sql(query).count() == 5)
+
+      withSQLConf(SQLConf.DETERMINISTIC_LITERAL_UDF_FOLDING_ENABLED.key -> "true") {
+        // Once `ref` is folded to a constant, the filter no longer references both join
+        // sides, so Spark rejects the now-implicit cartesian product.
+        val exception = intercept[AnalysisException] {
+          sql(query).count()
+        }
+        assert(exception.message.startsWith("Detected implicit cartesian product"))
+
+        withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+          assert(sql(query).count() == 5)
+        }
+      }
+    }
+  }
+}
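
For reviewers, a minimal sketch of how the folding is expected to show up from spark-shell with this patch applied. The UDF name `strlen` and the printed plan shape are illustrative, not part of the patch; the conf key is the one added above.

    import org.apache.spark.sql.functions.udf

    spark.conf.set("spark.sql.deterministic.literal.udf.folding.enabled", "true")

    // Deterministic UDF (the default for udf(...)): string length plus an offset.
    val strlen = udf((s: String, i: Int) => s.length + i)
    spark.udf.register("strlen", strlen)

    // Both arguments are literals, so the optimizer evaluates the UDF once and the
    // optimized plan carries the folded result instead of a ScalaUDF node.
    spark.sql("SELECT strlen('abc', 1)").queryExecution.optimizedPlan
    // Expected shape (illustrative): Project [4 AS strlen(abc, 1)#0]

    // A non-literal argument keeps the UDF in the plan, one invocation per row.
    spark.range(3).createOrReplaceTempView("t")
    spark.sql("SELECT strlen(CAST(id AS STRING), 1) FROM t").queryExecution.optimizedPlan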