diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 5b60de505f918..3db4b7701d5b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -117,7 +117,7 @@ trait PredicateHelper { e.children.isEmpty case a: AttributeReference => true // PythonUDF will be executed by dedicated physical operator later. - // For PythonUDFs that can't be evaluated in join condition, `PullOutPythonUDFInJoinCondition` + // For PythonUDFs that can't be evaluated in join condition, `ExtractPythonUDFFromJoinCondition` // will pull them out later. case _: PythonUDF => true case e: Unevaluable => false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index d9b39a858e33a..e885d9b613786 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -183,10 +183,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) Batch("LocalRelation", fixedPoint, ConvertToLocalRelation, PropagateEmptyRelation) :+ - Batch("Extract PythonUDF From JoinCondition", Once, - PullOutPythonUDFInJoinCondition) :+ - // The following batch should be executed after batch "Join Reorder" "LocalRelation" and - // "Extract PythonUDF From JoinCondition". + // The following batch should be executed after batch "Join Reorder" and "LocalRelation". Batch("Check Cartesian Products", Once, CheckCartesianProducts) :+ Batch("RewriteSubquery", Once, @@ -225,7 +222,6 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) PullupCorrelatedPredicates.ruleName :: RewriteCorrelatedScalarSubquery.ruleName :: RewritePredicateSubquery.ruleName :: - PullOutPythonUDFInJoinCondition.ruleName :: NormalizeFloatingNumbers.ruleName :: Nil /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index 22704b2d3cff8..b65221c236bfe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -170,7 +170,7 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper { * See `ExtractPythonUDFs` for details. This rule will detect un-evaluable PythonUDF and pull them * out from join condition. */ -object PullOutPythonUDFInJoinCondition extends Rule[LogicalPlan] with PredicateHelper { +object ExtractPythonUDFFromJoinCondition extends Rule[LogicalPlan] with PredicateHelper { private def hasUnevaluablePythonUDF(expr: Expression, j: Join): Boolean = { expr.find { e => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExtractPythonUDFFromJoinConditionSuite.scala similarity index 98% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExtractPythonUDFFromJoinConditionSuite.scala index 4a25ddf3ed9e9..03afdd1f8c364 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExtractPythonUDFFromJoinConditionSuite.scala @@ -28,12 +28,12 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.internal.SQLConf._ import org.apache.spark.sql.types.{BooleanType, IntegerType} -class PullOutPythonUDFInJoinConditionSuite extends PlanTest { +class ExtractPythonUDFFromJoinConditionSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Extract PythonUDF From JoinCondition", Once, - PullOutPythonUDFInJoinCondition) :: + ExtractPythonUDFFromJoinCondition) :: Batch("Check Cartesian Products", Once, CheckCartesianProducts) :: Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index d4fc92c9a26a1..a1135f7d6e6ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.ExperimentalMethods import org.apache.spark.sql.catalyst.catalog.SessionCatalog -import org.apache.spark.sql.catalyst.optimizer.{ColumnPruning, Optimizer, PushPredicateThroughNonJoin, RemoveNoopOperators} +import org.apache.spark.sql.catalyst.optimizer._ import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions import org.apache.spark.sql.execution.datasources.SchemaPruning import org.apache.spark.sql.execution.python.{ExtractGroupingPythonUDFFromAggregate, ExtractPythonUDFFromAggregate, ExtractPythonUDFs} @@ -32,6 +32,10 @@ class SparkOptimizer( override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ Batch("Extract Python UDFs", Once, + ExtractPythonUDFFromJoinCondition, + // `ExtractPythonUDFFromJoinCondition` can convert a join to a cartesian product. + // Here, we rerun cartesian product check. + CheckCartesianProducts, ExtractPythonUDFFromAggregate, // This must be executed after `ExtractPythonUDFFromAggregate` and before `ExtractPythonUDFs`. ExtractGroupingPythonUDFFromAggregate, @@ -47,6 +51,7 @@ class SparkOptimizer( Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) override def nonExcludableRules: Seq[String] = super.nonExcludableRules :+ + ExtractPythonUDFFromJoinCondition.ruleName :+ ExtractPythonUDFFromAggregate.ruleName :+ ExtractGroupingPythonUDFFromAggregate.ruleName :+ ExtractPythonUDFs.ruleName