diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index 903a6fd7bd01..472b6e871e73 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -197,7 +197,7 @@ class EquivalentExpressions { expr.exists(_.isInstanceOf[LambdaVariable]) || // `PlanExpression` wraps query plan. To compare query plans of `PlanExpression` on executor, // can cause error like NPE. - (expr.isInstanceOf[PlanExpression[_]] && Utils.isInRunningSparkTask) + (expr.exists(_.isInstanceOf[PlanExpression[_]]) && Utils.isInRunningSparkTask) if (!skip && !updateExprInMap(expr, map, useCount)) { val uc = useCount.signum diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index fa3003b27578..3c96ba430003 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -16,8 +16,9 @@ */ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkFunSuite, TaskContext, TaskContextImpl} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{BinaryType, DataType, Decimal, IntegerType} @@ -419,6 +420,21 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel } } + test("SPARK-38333: PlanExpression expression should skip addExprTree function in Executor") { + try { + // suppose we are in executor + val context1 = new TaskContextImpl(0, 0, 0, 0, 0, null, null, null, cpus = 0) + TaskContext.setTaskContext(context1) + + val equivalence = new EquivalentExpressions + val expression = DynamicPruningExpression(Exists(LocalRelation())) + equivalence.addExprTree(expression) + assert(equivalence.getExprState(expression).isEmpty) + } finally { + TaskContext.unset() + } + } + test("SPARK-35886: PromotePrecision should not overwrite genCode") { val p = PromotePrecision(Literal(Decimal("10.1")))