diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 4e790b1dd3f36..415ce46788119 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.catalyst.planning
 
-import scala.collection.mutable
-
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
@@ -35,7 +33,9 @@ trait OperationHelper {
     })
 
   protected def substitute(aliases: AttributeMap[Expression])(expr: Expression): Expression = {
-    expr.transform {
+    // Use transformUp instead of transformDown to avoid an infinite loop
+    // when there is an Alias whose exprId is the same as its child attribute's.
+    expr.transformUp {
       case a @ Alias(ref: AttributeReference, name) =>
         aliases.get(ref)
           .map(Alias(_, name)(a.exprId, a.qualifier))
@@ -142,12 +142,14 @@ object ScanOperation extends OperationHelper with PredicateHelper {
       case Filter(condition, child) =>
         collectProjectsAndFilters(child) match {
           case Some((fields, filters, other, aliases)) =>
-            // Follow CombineFilters and only keep going if the collected Filters
-            // are all deterministic and this filter doesn't have common non-deterministic
-            // expressions with lower Project.
-            if (filters.forall(_.deterministic) &&
-              !hasCommonNonDeterministic(Seq(condition), aliases)) {
-              val substitutedCondition = substitute(aliases)(condition)
+            // Follow CombineFilters and only keep going if 1) the collected Filters
+            // and this filter are all deterministic, or 2) this filter is the first
+            // collected filter and doesn't have common non-deterministic expressions
+            // with the lower Project.
+            val substitutedCondition = substitute(aliases)(condition)
+            val canCombineFilters = (filters.nonEmpty && filters.forall(_.deterministic) &&
+              substitutedCondition.deterministic) || filters.isEmpty
+            if (canCombineFilters && !hasCommonNonDeterministic(Seq(condition), aliases)) {
               Some((fields, filters ++ splitConjunctivePredicates(substitutedCondition),
                 other, aliases))
             } else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
index 8385db9f78653..f45495121a980 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
@@ -147,7 +147,8 @@ object FileSourceStrategy extends Strategy with Logging {
       //  - filters that need to be evaluated again after the scan
       val filterSet = ExpressionSet(filters)
 
-      val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, l.output)
+      val normalizedFilters = DataSourceStrategy.normalizeExprs(
+        filters.filter(_.deterministic), l.output)
 
       val partitionColumns =
         l.resolve(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
index 237717a3ad196..f242f75f39f20 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
@@ -116,8 +116,8 @@ class PrunedScanSuite extends DataSourceTest with SharedSparkSession {
   testPruning("SELECT a FROM oneToTenPruned", "a")
   testPruning("SELECT b FROM oneToTenPruned", "b")
   testPruning("SELECT a, rand() FROM oneToTenPruned WHERE a > 5", "a")
-  testPruning("SELECT a FROM oneToTenPruned WHERE rand() > 5", "a")
-  testPruning("SELECT a, rand() FROM oneToTenPruned WHERE rand() > 5", "a")
+  testPruning("SELECT a FROM oneToTenPruned WHERE rand() > 0.5", "a")
+  testPruning("SELECT a, rand() FROM oneToTenPruned WHERE rand() > 0.5", "a")
   testPruning("SELECT a, rand() FROM oneToTenPruned WHERE b > 5", "a", "b")
 
   def testPruning(sqlString: String, expectedColumns: String*): Unit = {
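Note on the transformUp change in patterns.scala: with transformDown, an Alias that reuses the exprId of its own child attribute makes the substitution rule fire on its own output repeatedly, because transformDown re-descends into the children of every rewritten node. Below is a minimal sketch of that shape, meant to be pasted into a spark-shell session with Catalyst on the classpath; the attribute name, the Add expression, and the single AttributeReference rule (the real substitute() also handles Alias nodes) are illustrative assumptions, not part of the patch:

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType

// An alias that shares the exprId of its child attribute: the shape that used to hang.
val a = AttributeReference("a", IntegerType)()
val selfAlias = Alias(a, "a")(exprId = a.exprId)
val aliases = AttributeMap(Seq(selfAlias.toAttribute -> selfAlias.child))  // maps `a` to `a`

// The bottom-up rewrite terminates: each node is visited once after its children.
// A top-down transform would rewrite `a` to Alias(a, "a"), descend into the new
// alias's child `a`, match the rule again, and recurse without bound.
val rewritten = Add(a, Literal(1)).transformUp {
  case ref: AttributeReference =>
    aliases.get(ref)
      .map(Alias(_, ref.name)(ref.exprId, ref.qualifier))
      .getOrElse(ref)
}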
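The FileSourceStrategy hunk hands only deterministic predicates to DataSourceStrategy.normalizeExprs, so a predicate such as rand() > 0.5 should stay in a Filter above the file scan rather than be considered for pushdown, while deterministic predicates remain pushdown candidates. A small end-to-end sketch for observing this in the explain output, assuming a local SparkSession; the object name, the /tmp path, the view name t, and the exact plan shape are assumptions for illustration:

import org.apache.spark.sql.SparkSession

object NonDeterministicFilterDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("non-deterministic-filter-demo")
      .getOrCreate()

    // A Parquet-backed table so that FileSourceStrategy plans the scan.
    spark.range(10).toDF("a").write.mode("overwrite").parquet("/tmp/pruned_scan_demo")
    spark.read.parquet("/tmp/pruned_scan_demo").createOrReplaceTempView("t")

    // The deterministic predicate a > 5 is a candidate for pushdown into the scan;
    // the non-deterministic rand() > 0.5 is expected to be evaluated only in a
    // Filter above it and is never passed to DataSourceStrategy.normalizeExprs.
    spark.sql("SELECT a FROM t WHERE rand() > 0.5 AND a > 5").explain(true)

    spark.stop()
  }
}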