diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
index 62782f6051b8..efcf607b5890 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
@@ -99,7 +99,7 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
     require(filterApplicationSideExp.dataType == filterCreationSideExp.dataType)
     val actualFilterKeyExpr = mayWrapWithHash(filterCreationSideExp)
     val alias = Alias(actualFilterKeyExpr, actualFilterKeyExpr.toString)()
-    val aggregate = Aggregate(Seq(alias), Seq(alias), filterCreationSidePlan)
+    val aggregate = ColumnPruning(Aggregate(Seq(alias), Seq(alias), filterCreationSidePlan))
     if (!canBroadcastBySize(aggregate, conf)) {
       // Skip the InSubquery filter if the size of `aggregate` is beyond broadcast join threshold,
       // i.e., the semi-join will be a shuffled join, which is not worthwhile.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala
index 0e016e19a628..fda442eeef03 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql
 
 import org.apache.spark.sql.catalyst.expressions.{Alias, BloomFilterMightContain, Literal}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate}
-import org.apache.spark.sql.catalyst.optimizer.MergeScalarSubqueries
+import org.apache.spark.sql.catalyst.optimizer.{ColumnPruning, MergeScalarSubqueries}
import org.apache.spark.sql.catalyst.plans.LeftSemi
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, LogicalPlan}
 import org.apache.spark.sql.execution.{ReusedSubqueryExec, SubqueryExec}
@@ -257,6 +257,11 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp
       val normalizedDisabled = normalizePlan(normalizeExprIds(planDisabled))
       ensureLeftSemiJoinExists(planEnabled)
       assert(normalizedEnabled != normalizedDisabled)
+      val agg = planEnabled.collect {
+        case Join(_, agg: Aggregate, LeftSemi, _, _) => agg
+      }
+      assert(agg.size == 1)
+      assert(agg.head.fastEquals(ColumnPruning(agg.head)))
     } else {
       comparePlans(planDisabled, planEnabled)
     }