From c3b41e243220dc39fff594fe80f3cb06521a7cec Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 13 Nov 2023 09:52:52 +0800 Subject: [PATCH] fix --- .../sql/execution/DataSourceScanExec.scala | 2 +- .../datasources/DataSourceStrategy.scala | 49 ++++++------ .../datasources/v2/FileScanBuilder.scala | 2 +- .../datasources/v2/PushDownUtils.scala | 2 +- .../org/apache/spark/sql/QueryTest.scala | 19 +++++ .../datasources/DataSourceStrategySuite.scala | 74 ++++++++++++++++++- .../datasources/FileSourceStrategySuite.scala | 42 +++++++---- .../datasources/json/JsonSuite.scala | 37 +++++++++- .../parquet/ParquetFilterSuite.scala | 3 +- 9 files changed, 184 insertions(+), 46 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index b3b2b0eab055..56719e277943 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -393,7 +393,7 @@ trait FileSourceScanLike extends DataSourceScanExec { scalarSubqueryReplaced.filterNot(_.references.exists { case FileSourceConstantMetadataAttribute(_) => true case _ => false - }).flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown)) + }).flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown, true)) } // This field may execute subquery expressions and should not be accessed during planning. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 94c2d2ffaca5..18dfae0eea8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -574,8 +574,10 @@ object DataSourceStrategy * @return a `Some[Filter]` if the input [[Expression]] is convertible, otherwise a `None`. */ protected[sql] def translateFilter( - predicate: Expression, supportNestedPredicatePushdown: Boolean): Option[Filter] = { - translateFilterWithMapping(predicate, None, supportNestedPredicatePushdown) + predicate: Expression, + nestedPredicatePushdownEnabled: Boolean, + canPartialPushDown: Boolean = false): Option[Filter] = { + translateFilterWithMapping(predicate, None, nestedPredicatePushdownEnabled, canPartialPushDown) } /** @@ -585,43 +587,44 @@ object DataSourceStrategy * @param translatedFilterToExpr An optional map from leaf node filter expressions to its * translated [[Filter]]. The map is used for rebuilding * [[Expression]] from [[Filter]]. - * @param nestedPredicatePushdownEnabled Whether nested predicate pushdown is enabled. + * @param nestedPredicatePushdownEnabled Whether nested predicate push down is enabled. + * @param canPartialPushDown Can it be translated into partial predicate. Note that + * PushDownUtils.pushFilters and DataSourceStrategy.selectFilters + * do not support partial push down. * @return a `Some[Filter]` if the input [[Expression]] is convertible, otherwise a `None`. 
*/ protected[sql] def translateFilterWithMapping( predicate: Expression, translatedFilterToExpr: Option[mutable.HashMap[sources.Filter, Expression]], - nestedPredicatePushdownEnabled: Boolean) + nestedPredicatePushdownEnabled: Boolean, + canPartialPushDown: Boolean) : Option[Filter] = { predicate match { case expressions.And(left, right) => - // See SPARK-12218 for detailed discussion - // It is not safe to just convert one side if we do not understand the - // other side. Here is an example used to explain the reason. - // Let's say we have (a = 2 AND trim(b) = 'blah') OR (c > 0) - // and we do not understand how to convert trim(b) = 'blah'. - // If we only convert a = 2, we will end up with - // (a = 2) OR (c > 0), which will generate wrong results. - // Pushing one leg of AND down is only safe to do at the top level. - // You can see ParquetFilters' createFilter for more details. - for { - leftFilter <- translateFilterWithMapping( - left, translatedFilterToExpr, nestedPredicatePushdownEnabled) - rightFilter <- translateFilterWithMapping( - right, translatedFilterToExpr, nestedPredicatePushdownEnabled) - } yield sources.And(leftFilter, rightFilter) + val translatedLeft = translateFilterWithMapping( + left, translatedFilterToExpr, nestedPredicatePushdownEnabled, canPartialPushDown) + val translatedRight = translateFilterWithMapping( + right, translatedFilterToExpr, nestedPredicatePushdownEnabled, canPartialPushDown) + if (canPartialPushDown) { + (translatedLeft ++ translatedRight).reduceOption(sources.And) + } else { + for { + leftFilter <- translatedLeft + rightFilter <- translatedRight + } yield sources.And(leftFilter, rightFilter) + } case expressions.Or(left, right) => for { leftFilter <- translateFilterWithMapping( - left, translatedFilterToExpr, nestedPredicatePushdownEnabled) + left, translatedFilterToExpr, nestedPredicatePushdownEnabled, canPartialPushDown) rightFilter <- translateFilterWithMapping( - right, translatedFilterToExpr, nestedPredicatePushdownEnabled) + right, translatedFilterToExpr, nestedPredicatePushdownEnabled, canPartialPushDown) } yield sources.Or(leftFilter, rightFilter) case expressions.Not(child) => - translateFilterWithMapping(child, translatedFilterToExpr, nestedPredicatePushdownEnabled) - .map(sources.Not) + translateFilterWithMapping(child, translatedFilterToExpr, nestedPredicatePushdownEnabled, + canPartialPushDown = false).map(sources.Not) case other => val filter = translateLeafNodeFilter(other, PushableColumn(nestedPredicatePushdownEnabled)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala index 447a36fe622c..7df3850d9f08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala @@ -77,7 +77,7 @@ abstract class FileScanBuilder( this.dataFilters = dataFilters val translatedFilters = mutable.ArrayBuffer.empty[sources.Filter] for (filterExpr <- dataFilters) { - val translated = DataSourceStrategy.translateFilter(filterExpr, true) + val translated = DataSourceStrategy.translateFilter(filterExpr, true, true) if (translated.nonEmpty) { translatedFilters += translated.get } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala 
index 3de4692c83b0..4ba899173e1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -54,7 +54,7 @@ object PushDownUtils { for (filterExpr <- filters) { val translated = DataSourceStrategy.translateFilterWithMapping(filterExpr, Some(translatedFilterToExpr), - nestedPredicatePushdownEnabled = true) + nestedPredicatePushdownEnabled = true, canPartialPushDown = false) if (translated.isEmpty) { untranslatableExprs += filterExpr } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index f5ba655e3e85..75f0a880a43b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -25,10 +25,12 @@ import scala.jdk.CollectionConverters._ import org.scalatest.Assertions import org.apache.spark.sql.catalyst.ExtendedAnalysisException +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionSet} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.columnar.InMemoryRelation +import org.apache.spark.sql.functions.expr import org.apache.spark.storage.StorageLevel import org.apache.spark.util.ArrayImplicits._ @@ -252,6 +254,23 @@ abstract class QueryTest extends PlanTest { Pattern.quote( s"${cs.getClassName}.${cs.getMethodName}(${cs.getFileName}:${cs.getLineNumber + lines})") } + + /** + * Returns a set with all the filters present in the physical plan. + */ + def getPhysicalFilters(df: DataFrame): ExpressionSet = { + ExpressionSet( + df.queryExecution.executedPlan.collect { + case execution.FilterExec(f, _) => splitConjunctivePredicates(f) + }.flatten) + } + + /** + * Returns a resolved expression for `str` in the context of `df`. 
+ */ + def resolve(df: DataFrame, str: String): Expression = { + df.select(expr(str)).queryExecution.analyzed.expressions.head.children.head + } } object QueryTest extends Assertions { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala index 2b9ec97bace1..c0e29fd53be8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.sources +import org.apache.spark.sql.sources.AlwaysFalse import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} @@ -277,6 +278,71 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession { // Functions such as 'Abs' are not supported EqualTo(Abs(attrInt), 6), IsNotNull(attrInt))), None) + + // (aint > 1 AND Abs(aint) < 10) OR (aint > 50 AND aint < 100) + testTranslateFilter(Or( + And( + GreaterThan(attrInt, 1), + // Functions such as 'Abs' are not supported + LessThan(Abs(attrInt), 10) + ), + And( + GreaterThan(attrInt, 50), + LessThan(attrInt, 100))), + Some(sources.Or( + sources.GreaterThan(intColName, 1), + sources.And( + sources.GreaterThan(intColName, 50), + sources.LessThan(intColName, 100)))), + canPartialPushDown = true) + + // (aint > 1 AND aint < 10) AND (Abs(aint) = 6 OR aint IS NOT NULL) + testTranslateFilter(And( + And( + GreaterThan(attrInt, 1), + LessThan(attrInt, 10) + ), + And( + // Functions such as 'Abs' are not supported + EqualTo(Abs(attrInt), 6), + IsNotNull(attrInt))), + Some(sources.And( + sources.And( + sources.GreaterThan(intColName, 1), + sources.LessThan(intColName, 10)), + sources.IsNotNull(intColName))), + canPartialPushDown = true) + + // (aint > 1 OR aint < 10) AND (Abs(aint) = 6 OR aint IS NOT NULL) + testTranslateFilter(And( + Or( + GreaterThan(attrInt, 1), + LessThan(attrInt, 10) + ), + Or( + // Functions such as 'Abs' are not supported + EqualTo(Abs(attrInt), 6), + IsNotNull(attrInt))), + Some( + sources.Or( + sources.GreaterThan(intColName, 1), + sources.LessThan(intColName, 10))), + canPartialPushDown = true) + + // Not((cint > 1) OR (Abs(cint) = 6 AND cint < 1)) + testTranslateFilter( + Not(Or( + GreaterThan(attrInt, 1), + And(EqualTo(Abs(attrInt), 6), LessThanOrEqual(attrInt, 1)))), + None, + canPartialPushDown = true) + + // (1 > 1 AND Abs(cint) > 1) OR (2 > 2 AND Abs(cint) > 2) + testTranslateFilter( + Or(And(1 > 1, EqualTo(Abs(attrInt), 1)), + And(2 > 2, EqualTo(Abs(attrInt), 2))), + Some(sources.Or(AlwaysFalse, AlwaysFalse)), + canPartialPushDown = true) }} test("SPARK-26865 DataSourceV2Strategy should push normalized filters") { @@ -319,14 +385,16 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession { * Translate the given Catalyst [[Expression]] into data source [[sources.Filter]] * then verify against the given [[sources.Filter]]. 
*/ - def testTranslateFilter(catalystFilter: Expression, result: Option[sources.Filter]): Unit = { + def testTranslateFilter( + catalystFilter: Expression, + result: Option[sources.Filter], + canPartialPushDown: Boolean = false): Unit = { assertResult(result) { - DataSourceStrategy.translateFilter(catalystFilter, true) + DataSourceStrategy.translateFilter(catalystFilter, true, canPartialPushDown) } } test("SPARK-41636: selectFilters returns predicates in deterministic order") { - val predicates = Seq(EqualTo($"id", 1), EqualTo($"id", 2), EqualTo($"id", 3), EqualTo($"id", 4), EqualTo($"id", 5), EqualTo($"id", 6)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 91182f6473d7..8a9cf57a2121 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -30,11 +30,9 @@ import org.apache.spark.paths.SparkPath.{fromUrlString => sp} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.BucketSpec -import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionSet} import org.apache.spark.sql.catalyst.util import org.apache.spark.sql.execution.{DataSourceScanExec, FileSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation -import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.test.SharedSparkSession @@ -625,6 +623,33 @@ class FileSourceStrategySuite extends QueryTest with SharedSparkSession { } } + test("SPARK-44493: Push partial predicates are supported") { + def getPushedFilters(df: DataFrame): Option[String] = { + df.queryExecution.executedPlan.collectFirst { + case f: FileSourceScanExec => f.metadata.get("PushedFilters") + }.flatten + } + + val table = + createTable( + files = Seq( + "p1=1/file1" -> 10, + "p1=2/file2" -> 10)) + + // 'Abs' are not supported push down + val df = table.where("(c1 > 0 AND Abs(c2) > 1) OR (c2 > 1)") + assert(getPhysicalFilters(df) contains resolve(df, + "(c1 > 0 AND Abs(c2) > 1) OR (c2 > 1)")) + assert(getPushedFilters(df).contains("[Or(GreaterThan(c1,0),GreaterThan(c2,1))]")) + + assert(getPushedFilters( + table.where("Not(c1 <=> 0 AND Abs(c2) > 1) OR (c2 > 1)")).contains("[]")) + + assert(getPushedFilters( + table.where("Not((c1 > 0) OR (c2 > 1))")) + .contains("[IsNotNull(c1), IsNotNull(c2), LessThanOrEqual(c1,0), LessThanOrEqual(c2,1)]")) + } + // Helpers for checking the arguments passed to the FileFormat. protected val checkPartitionSchema = @@ -646,19 +671,6 @@ class FileSourceStrategySuite extends QueryTest with SharedSparkSession { } } - /** Returns a resolved expression for `str` in the context of `df`. */ - def resolve(df: DataFrame, str: String): Expression = { - df.select(expr(str)).queryExecution.analyzed.expressions.head.children.head - } - - /** Returns a set with all the filters present in the physical plan. */ - def getPhysicalFilters(df: DataFrame): ExpressionSet = { - ExpressionSet( - df.queryExecution.executedPlan.collect { - case execution.FilterExec(f, _) => splitConjunctivePredicates(f) - }.flatten) - } - /** Plans the query and calls the provided validation function with the planned partitioning. 
*/ def checkScan(df: DataFrame)(func: Seq[FilePartition] => Unit): Unit = { func(getFileScanRDD(df).filePartitions) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index c7e4db2aa33e..56ad63047728 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -33,18 +33,21 @@ import org.apache.hadoop.io.compress.GzipCodec import org.apache.spark.{SparkConf, SparkException, SparkFileNotFoundException, SparkRuntimeException, SparkUpgradeException, TestUtils} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{functions => F, _} +import org.apache.spark.sql.catalyst.expressions.{EqualTo, Literal, StringTrim} import org.apache.spark.sql.catalyst.json._ import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils, HadoopCompressionCodec} import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLType import org.apache.spark.sql.execution.ExternalRDD import org.apache.spark.sql.execution.datasources.{CommonFileDataSourceSuite, DataSource, InMemoryFileIndex, NoopCache} -import org.apache.spark.sql.execution.datasources.v2.json.JsonScanBuilder +import org.apache.spark.sql.execution.datasources.v2.BatchScanExec +import org.apache.spark.sql.execution.datasources.v2.json.{JsonScan, JsonScanBuilder} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ import org.apache.spark.sql.types.StructType.fromDDL import org.apache.spark.sql.types.TestUDT.{MyDenseVector, MyDenseVectorUDT} import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils @@ -3744,6 +3747,38 @@ class JsonV2Suite extends JsonSuite { } } } + + test("SPARK-44493: Push partial predicates are supported") { + import testImplicits._ + withTempPath { path => + Seq( + """{"NAME": "fred", "THEID": 1}""", + s"""{"NAME": "mary", "THEID": 2}""", + s"""{"NAME": "joe 'foo' \\"bar\\"", "THEID": 3}""").toDF("data") + .repartition(1) + .write.text(path.getAbsolutePath) + withSQLConf(SQLConf.JSON_FILTER_PUSHDOWN_ENABLED.key -> "true") { + Seq("PERMISSIVE", "DROPMALFORMED", "FAILFAST").foreach { mode => + val df = spark.read + .option("mode", mode) + .schema("NAME string, THEID integer") + .json(path.getAbsolutePath) + .where($"THEID" > 0 && + Column(EqualTo(StringTrim($"NAME".expr), Literal(UTF8String.fromString("mary")))) || + ($"THEID" > 10)) + + assert(getPhysicalFilters(df) contains resolve(df, + "(THEID > 0 AND TRIM(NAME) = 'mary') OR (THEID > 10)")) + + val pushedFilters = df.queryExecution.executedPlan.collect { + case BatchScanExec(_, j: JsonScan, _, _, _, _) => j.pushedFilters + } + assert(pushedFilters.flatten.contains( + sources.Or(sources.GreaterThan("THEID", 0), sources.GreaterThan("THEID", 10)))) + } + } + } + } } class JsonLegacyTimeParserSuite extends JsonSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 4ed5297ff4ea..2f775dbbb789 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -2254,7 +2254,8 @@ class ParquetV2FilterSuite extends ParquetFilterSuite { case PhysicalOperation(_, filters, DataSourceV2ScanRelation(_, scan: ParquetScan, _, _, _)) => assert(filters.nonEmpty, "No filter is analyzed from the given query") - val sourceFilters = filters.flatMap(DataSourceStrategy.translateFilter(_, true)).toArray + val sourceFilters = + filters.flatMap(DataSourceStrategy.translateFilter(_, true, true)).toArray val pushedFilters = scan.pushedFilters assert(pushedFilters.nonEmpty, "No filter is pushed down") val schema = new SparkToParquetSchemaConverter(conf).convert(df.schema)
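The core of this patch is the new `canPartialPushDown` flag threaded through `DataSourceStrategy.translateFilterWithMapping`: when the flag is set, an `And` whose legs do not all translate keeps whichever legs did translate, an `Or` still requires both legs, and a `Not` switches the flag off for its child. Below is a minimal, self-contained sketch of that recursion; the `Expr`/`Filter` ADTs and every name in it are hypothetical stand-ins rather than Spark classes, and only the branching logic mirrors the patched method.

    // Minimal sketch of partial predicate translation (hypothetical types, not Spark code).
    object PartialPushDownSketch extends App {
      sealed trait Expr
      case class Leaf(name: String, translatable: Boolean) extends Expr
      case class And(left: Expr, right: Expr) extends Expr
      case class Or(left: Expr, right: Expr) extends Expr
      case class Not(child: Expr) extends Expr

      sealed trait Filter
      case class Pushed(name: String) extends Filter
      case class AndF(left: Filter, right: Filter) extends Filter
      case class OrF(left: Filter, right: Filter) extends Filter
      case class NotF(child: Filter) extends Filter

      def translate(e: Expr, canPartialPushDown: Boolean): Option[Filter] = e match {
        case Leaf(name, ok) =>
          if (ok) Some(Pushed(name)) else None
        case And(l, r) =>
          val lf = translate(l, canPartialPushDown)
          val rf = translate(r, canPartialPushDown)
          if (canPartialPushDown) {
            // Dropping an untranslatable conjunct only weakens the filter, which is safe
            // as long as the full predicate is still evaluated after the scan.
            (lf ++ rf).reduceOption(AndF)
          } else {
            for { a <- lf; b <- rf } yield AndF(a, b)
          }
        case Or(l, r) =>
          // Both legs of an Or must translate; dropping one would strengthen the filter.
          for {
            a <- translate(l, canPartialPushDown)
            b <- translate(r, canPartialPushDown)
          } yield OrF(a, b)
        case Not(child) =>
          // Negating a weakened child would be unsound, so partial push down stops here.
          translate(child, canPartialPushDown = false).map(NotF)
      }

      // Mirrors the new DataSourceStrategySuite case:
      //   (aint > 1 AND Abs(aint) < 10) OR (aint > 50 AND aint < 100)
      val example = Or(
        And(Leaf("aint > 1", translatable = true), Leaf("Abs(aint) < 10", translatable = false)),
        And(Leaf("aint > 50", translatable = true), Leaf("aint < 100", translatable = true)))
      assert(translate(example, canPartialPushDown = true) ==
        Some(OrF(Pushed("aint > 1"), AndF(Pushed("aint > 50"), Pushed("aint < 100")))))
      println(translate(example, canPartialPushDown = true))
    }

Because the translated filter is only implied by the original predicate, the patch enables the flag only on code paths that keep re-evaluating the full predicate after the scan (FileSourceScanLike and FileScanBuilder), while PushDownUtils.pushFilters stays at canPartialPushDown = false, matching the new scaladoc note that pushFilters and selectFilters do not support partial push down.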
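The reason `Not` resets the flag is that a partially translated child is only implied by the original child, so negating it can wrongly reject rows. A tiny illustration with plain Booleans (not Spark code), using the shape `NOT(a > 1 AND f(b))` where `f(b)` stands for an untranslatable term:

    object NotResetsPartialPushDown extends App {
      // Consider a row with a = 5 and f(b) = false.
      val a = 5
      val fb = false
      val originalNot = !(a > 1 && fb) // true: the query must keep this row
      val naivePushed = !(a > 1)       // false: negation of only the translatable leg
      // A scan that skipped rows based on naivePushed would drop a row the query needs.
      assert(originalNot && !naivePushed)
      println(s"original NOT: $originalNot, naively pushed NOT: $naivePushed")
    }

This is also why the new FileSourceStrategySuite case expects an empty PushedFilters list for `Not(c1 <=> 0 AND Abs(c2) > 1) OR (c2 > 1)`: under the negation the untranslatable `Abs` term cannot simply be dropped, and an `Or` needs both legs, so no sound filter can be derived.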
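For completeness, an end-to-end way to observe the behavior outside the test suites, assuming this patch is applied; the local master, the temporary path, and the Parquet format are illustrative choices, and the exact PushedFilters rendering (including any inferred IsNotNull entries) can vary by format and optimizer settings:

    import org.apache.spark.sql.SparkSession

    object PartialPushDownDemo extends App {
      val spark = SparkSession.builder()
        .master("local[*]")
        .appName("partial-pushdown-demo")
        .getOrCreate()
      import spark.implicits._

      val path = "/tmp/partial_pushdown_demo"
      Seq((1, 5), (0, 2), (3, 0)).toDF("c1", "c2")
        .write.mode("overwrite").parquet(path)

      // Same predicate shape as the new FileSourceStrategySuite case: abs(c2) > 1 is not
      // translatable, so the scan should advertise a weakened filter such as
      // Or(GreaterThan(c1,0),GreaterThan(c2,1)), while a FilterExec above the scan still
      // applies the full predicate.
      val df = spark.read.parquet(path)
        .where("(c1 > 0 AND abs(c2) > 1) OR (c2 > 1)")
      df.explain(true) // inspect PushedFilters in the scan node
      df.show()

      spark.stop()
    }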