From cc622445cb09f50742ffcfe045e3bb802404046f Mon Sep 17 00:00:00 2001 From: huaxingao Date: Sun, 19 Jun 2022 22:53:52 -0700 Subject: [PATCH 01/10] Support runtime V2 filtering --- .../read/SupportsRuntimeV2Filtering.java | 60 +++++++++++++++ .../sql/catalyst/expressions/literals.scala | 1 + .../sql/connector/catalog/InMemoryTable.scala | 12 ++- .../catalog/InMemoryTableWithV2Filter.scala | 77 +++++++++++++++++++ .../InMemoryTableWithV2FilterCatalog.scala | 46 +++++++++++ .../datasources/v2/BatchScanExec.scala | 23 ++++-- .../dynamicpruning/PartitionPruning.scala | 9 ++- .../sql/DynamicPartitionPruningSuite.scala | 20 ++++- 8 files changed, 238 insertions(+), 10 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2FilterCatalog.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java new file mode 100644 index 0000000000000..f23708dc68de5 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.read; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.filter.Predicate; + +/** + * A mix-in interface for {@link Scan}. Data sources can implement this interface if they can + * filter initially planned {@link InputPartition}s using predicates Spark infers at runtime. + *

+ * Note that Spark will push runtime filters only if they are beneficial. + * + * @since 3.4.0 + */ +@Experimental +public interface SupportsRuntimeV2Filtering extends Scan { + /** + * Returns attributes this scan can be filtered by at runtime. + *

+ * Spark will call {@link #filter(Predicate[])} if it can derive a runtime + * predicate for any of the filter attributes. + */ + NamedReference[] filterAttributes(); + + /** + * Filters this scan using runtime predicates. + *

+ * The provided expressions must be interpreted as a set of predicates that are ANDed together. + * Implementations may use the predicates to prune initially planned {@link InputPartition}s. + *

+ * If the scan also implements {@link SupportsReportPartitioning}, it must preserve + * the originally reported partitioning during runtime filtering. While applying runtime predicates, + * the scan may detect that some {@link InputPartition}s have no matching data. It can omit + * such partitions entirely only if it does not report a specific partitioning. Otherwise, + * the scan can replace the initially planned {@link InputPartition}s that have no matching + * data with empty {@link InputPartition}s but must preserve the overall number of partitions. + *

+ * Note that Spark will call {@link Scan#toBatch()} again after filtering the scan at runtime. + * + * @param predicates data source V2 predicates used to filter the scan at runtime + */ + void filter(Predicate[] predicates); +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index a8c877a29de8a..d662b83eaf015 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -68,6 +68,7 @@ object Literal { case b: Byte => Literal(b, ByteType) case s: Short => Literal(s, ShortType) case s: String => Literal(UTF8String.fromString(s), StringType) + case s: UTF8String => Literal(s, StringType) case c: Char => Literal(UTF8String.fromString(c.toString), StringType) case ac: Array[Char] => Literal(UTF8String.fromString(String.valueOf(ac)), StringType) case b: Boolean => Literal(b, BooleanType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala index 995c57c062e8a..3255dee0a16b0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala @@ -268,12 +268,11 @@ class InMemoryTable( case class InMemoryStats(sizeInBytes: OptionalLong, numRows: OptionalLong) extends Statistics - case class InMemoryBatchScan( + abstract class BatchScanBaseClass( var data: Seq[InputPartition], readSchema: StructType, tableSchema: StructType) - extends Scan with Batch with SupportsRuntimeFiltering with SupportsReportStatistics - with SupportsReportPartitioning { + extends Scan with Batch with SupportsReportStatistics with SupportsReportPartitioning { override def toBatch: Batch = this @@ -308,6 +307,13 @@ class InMemoryTable( val nonMetadataColumns = readSchema.filterNot(f => metadataColumns.contains(f.name)) new BufferedRowsReaderFactory(metadataColumns, nonMetadataColumns, tableSchema) } + } + + case class InMemoryBatchScan( + var _data: Seq[InputPartition], + readSchema: StructType, + tableSchema: StructType) + extends BatchScanBaseClass (_data, readSchema, tableSchema) with SupportsRuntimeFiltering { override def filterAttributes(): Array[NamedReference] = { val scanFields = readSchema.fields.map(_.name).toSet diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala new file mode 100644 index 0000000000000..896c2919c1476 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog + +import java.util + +import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue, NamedReference, Transform} +import org.apache.spark.sql.connector.expressions.filter.Predicate +import org.apache.spark.sql.connector.read.{InputPartition, Scan, ScanBuilder, SupportsRuntimeV2Filtering} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class InMemoryTableWithV2Filter( + name: String, + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]) + extends InMemoryTable(name, schema, partitioning, properties) { + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + new InMemoryV2FilterScanBuilder(schema) + } + + class InMemoryV2FilterScanBuilder(tableSchema: StructType) + extends InMemoryScanBuilder(tableSchema) { + override def build: Scan = + InMemoryV2FilterBatchScan(data.map(_.asInstanceOf[InputPartition]), schema, tableSchema) + } + + case class InMemoryV2FilterBatchScan( + var _data: Seq[InputPartition], + readSchema: StructType, + tableSchema: StructType) + extends BatchScanBaseClass (_data, readSchema, tableSchema) with SupportsRuntimeV2Filtering { + + override def filterAttributes(): Array[NamedReference] = { + val scanFields = readSchema.fields.map(_.name).toSet + partitioning.flatMap(_.references) + .filter(ref => scanFields.contains(ref.fieldNames.mkString("."))) + } + + override def filter(filters: Array[Predicate]): Unit = { + if (partitioning.length == 1 && partitioning.head.references().length == 1) { + val ref = partitioning.head.references().head + filters.foreach { + case p : Predicate if p.name().equals("IN") => + if (p.children().length > 1) { + val filterRef = p.children()(0).asInstanceOf[FieldReference].references.head + if (filterRef.toString.equals(ref.toString)) { + val matchingKeys = + p.children().drop(1).map(_.asInstanceOf[LiteralValue[_]].value.toString).toSet + data = data.filter(partition => { + val key = partition.asInstanceOf[BufferedRows].keyString + matchingKeys.contains(key) + }) + } + } + } + } + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2FilterCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2FilterCatalog.scala new file mode 100644 index 0000000000000..08c1f65db290c --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2FilterCatalog.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog + +import java.util + +import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.types.StructType + +class InMemoryTableWithV2FilterCatalog extends InMemoryTableCatalog { + import CatalogV2Implicits._ + + override def createTable( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): Table = { + if (tables.containsKey(ident)) { + throw new TableAlreadyExistsException(ident) + } + + InMemoryTableCatalog.maybeSimulateFailedTableCreation(properties) + + val tableName = s"$name.${ident.quoted}" + val table = new InMemoryTableWithV2Filter(tableName, schema, partitions, properties) + tables.put(ident, table) + namespaces.putIfAbsent(ident.namespace.toList, Map()) + table + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index ba969eb6ff1a3..f48bb375cd9f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -27,8 +27,10 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, SinglePartition} import org.apache.spark.sql.catalyst.util.InternalRowSet import org.apache.spark.sql.catalyst.util.truncatedString -import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, PartitionReaderFactory, Scan, SupportsRuntimeFiltering} +import org.apache.spark.sql.connector.expressions.filter.Predicate +import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, PartitionReaderFactory, Scan, SupportsRuntimeFiltering, SupportsRuntimeV2Filtering} import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.sources.Filter /** * Physical plan node for scanning a batch of data from a data source v2. @@ -56,16 +58,27 @@ case class BatchScanExec( @transient private lazy val filteredPartitions: Seq[Seq[InputPartition]] = { val dataSourceFilters = runtimeFilters.flatMap { - case DynamicPruningExpression(e) => DataSourceStrategy.translateRuntimeFilter(e) + case DynamicPruningExpression(e) => + if (scan.isInstanceOf[SupportsRuntimeFiltering]) { + DataSourceStrategy.translateRuntimeFilter(e) + } else if (scan.isInstanceOf[SupportsRuntimeV2Filtering]) { + DataSourceV2Strategy.translateRuntimeFilterV2(e) + } else { + None + } case _ => None } if (dataSourceFilters.nonEmpty) { val originalPartitioning = outputPartitioning - // the cast is safe as runtime filters are only assigned if the scan can be filtered - val filterableScan = scan.asInstanceOf[SupportsRuntimeFiltering] - filterableScan.filter(dataSourceFilters.toArray) + if (scan.isInstanceOf[SupportsRuntimeFiltering]) { + val filterableScan = scan.asInstanceOf[SupportsRuntimeFiltering] + filterableScan.filter(dataSourceFilters.map(_.asInstanceOf[Filter]).toArray) + } else if (scan.isInstanceOf[SupportsRuntimeV2Filtering]) { + val filterableScan = scan.asInstanceOf[SupportsRuntimeV2Filtering] + filterableScan.filter(dataSourceFilters.map(_.asInstanceOf[Predicate]).toArray) + } // call toBatch again to get filtered partitions val newPartitions = scan.toBatch.planInputPartitions() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala index 61a243ddb3368..0ad8a8ca9d064 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.optimizer.JoinSelectionHelper import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.connector.read.SupportsRuntimeFiltering +import org.apache.spark.sql.connector.read.{SupportsRuntimeFiltering, SupportsRuntimeV2Filtering} import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation @@ -85,6 +85,13 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper with Join } else { None } + case (resExp, r @ DataSourceV2ScanRelation(_, scan: SupportsRuntimeV2Filtering, _, _)) => + val filterAttrs = V2ExpressionUtils.resolveRefs[Attribute](scan.filterAttributes, r) + if (resExp.references.subsetOf(AttributeSet(filterAttrs))) { + Some(r) + } else { + None + } case _ => None } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala index 366120fb66c1a..fd213d120b6a4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala @@ -22,7 +22,7 @@ import org.scalatest.GivenWhenThen import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Expression} import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._ import org.apache.spark.sql.catalyst.plans.ExistenceJoin -import org.apache.spark.sql.connector.catalog.InMemoryTableCatalog +import org.apache.spark.sql.connector.catalog.{InMemoryTableCatalog, InMemoryTableWithV2FilterCatalog} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive._ import org.apache.spark.sql.execution.datasources.v2.BatchScanExec @@ -1805,3 +1805,21 @@ class DynamicPartitionPruningV2SuiteAEOff extends DynamicPartitionPruningV2Suite class DynamicPartitionPruningV2SuiteAEOn extends DynamicPartitionPruningV2Suite with EnableAdaptiveExecutionSuite + +abstract class DynamicPartitionPruningV2FilterSuite + extends DynamicPartitionPruningDataSourceSuiteBase { + override protected def runAnalyzeColumnCommands: Boolean = false + + override protected def initState(): Unit = { + spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableWithV2FilterCatalog].getName) + spark.conf.set("spark.sql.defaultCatalog", "testcat") + } +} + +class DynamicPartitionPruningV2FilterSuiteAEOff + extends DynamicPartitionPruningV2FilterSuite + with DisableAdaptiveExecutionSuite + +class DynamicPartitionPruningV2FilterSuiteAEOn + extends DynamicPartitionPruningV2FilterSuite + with EnableAdaptiveExecutionSuite From 25413ff540ebb893bd078b7d05a7f2661d9af761 Mon Sep 17 00:00:00 2001 From: huaxingao Date: Sun, 19 Jun 2022 23:31:32 -0700 Subject: [PATCH 02/10] fix style --- .../connector/read/SupportsRuntimeV2Filtering.java | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java index f23708dc68de5..e276fe0764e10 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java @@ -46,11 +46,12 @@ public interface SupportsRuntimeV2Filtering extends Scan { * Implementations may use the predicates to prune initially planned {@link InputPartition}s. *

* If the scan also implements {@link SupportsReportPartitioning}, it must preserve - * the originally reported partitioning during runtime filtering. While applying runtime predicates, - * the scan may detect that some {@link InputPartition}s have no matching data. It can omit - * such partitions entirely only if it does not report a specific partitioning. Otherwise, - * the scan can replace the initially planned {@link InputPartition}s that have no matching - * data with empty {@link InputPartition}s but must preserve the overall number of partitions. + * the originally reported partitioning during runtime filtering. While applying runtime + * predicates, the scan may detect that some {@link InputPartition}s have no matching data. It + * can omit such partitions entirely only if it does not report a specific partitioning. + * Otherwise, the scan can replace the initially planned {@link InputPartition}s that have no + * matching data with empty {@link InputPartition}s but must preserve the overall number of + * partitions. *

* Note that Spark will call {@link Scan#toBatch()} again after filtering the scan at runtime. * From 66b437081572ffae9d7ef7e5c26557729bd7297a Mon Sep 17 00:00:00 2001 From: huaxingao Date: Tue, 19 Jul 2022 22:26:12 -0700 Subject: [PATCH 03/10] use pattern matching instead of if-else --- .../datasources/v2/BatchScanExec.scala | 25 +++++++------- .../datasources/v2/DataSourceV2Strategy.scala | 33 ++++++++++++++++--- .../dynamicpruning/PartitionPruning.scala | 2 +- 3 files changed, 41 insertions(+), 19 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index f48bb375cd9f6..db15ab390fb27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -59,25 +59,24 @@ case class BatchScanExec( @transient private lazy val filteredPartitions: Seq[Seq[InputPartition]] = { val dataSourceFilters = runtimeFilters.flatMap { case DynamicPruningExpression(e) => - if (scan.isInstanceOf[SupportsRuntimeFiltering]) { - DataSourceStrategy.translateRuntimeFilter(e) - } else if (scan.isInstanceOf[SupportsRuntimeV2Filtering]) { - DataSourceV2Strategy.translateRuntimeFilterV2(e) - } else { - None + scan match { + case _: SupportsRuntimeFiltering => + DataSourceStrategy.translateRuntimeFilter(e) + case _: SupportsRuntimeV2Filtering => + DataSourceV2Strategy.translateRuntimeFilterV2(e) + case _ => None } case _ => None } if (dataSourceFilters.nonEmpty) { val originalPartitioning = outputPartitioning - - if (scan.isInstanceOf[SupportsRuntimeFiltering]) { - val filterableScan = scan.asInstanceOf[SupportsRuntimeFiltering] - filterableScan.filter(dataSourceFilters.map(_.asInstanceOf[Filter]).toArray) - } else if (scan.isInstanceOf[SupportsRuntimeV2Filtering]) { - val filterableScan = scan.asInstanceOf[SupportsRuntimeV2Filtering] - filterableScan.filter(dataSourceFilters.map(_.asInstanceOf[Predicate]).toArray) + scan match { + case s: SupportsRuntimeFiltering => + s.filter(dataSourceFilters.map(_.asInstanceOf[Filter]).toArray) + case s: SupportsRuntimeV2Filtering => + s.filter(dataSourceFilters.map(_.asInstanceOf[Predicate]).toArray) + case _ => } // call toBatch again to get filtered partitions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 16c6b331d1093..c42703c7253c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -20,25 +20,26 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.JavaConverters._ import scala.collection.mutable +import org.apache.spark.internal.Logging import org.apache.spark.sql.{SparkSession, Strategy} import org.apache.spark.sql.catalyst.analysis.{ResolvedIdentifier, ResolvedNamespace, ResolvedPartitionSpec, ResolvedTable} import org.apache.spark.sql.catalyst.catalog.CatalogUtils import org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruning, Expression, NamedExpression, Not, Or, PredicateHelper, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruning, Expression, Literal, NamedExpression, Not, Or, PredicateHelper, SubqueryExpression} import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{toPrettySQL, ResolveDefaultColumns, V2ExpressionBuilder} import org.apache.spark.sql.connector.catalog.{Identifier, StagingTableCatalog, SupportsDelete, SupportsNamespaces, SupportsPartitionManagement, SupportsWrite, Table, TableCapability, TableCatalog, TruncatableTable} import org.apache.spark.sql.connector.catalog.index.SupportsIndex -import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue} import org.apache.spark.sql.connector.expressions.filter.{And => V2And, Not => V2Not, Or => V2Or, Predicate} import org.apache.spark.sql.connector.read.LocalScan import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream} import org.apache.spark.sql.connector.write.V1Write import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} -import org.apache.spark.sql.execution.{FilterExec, LeafExecNode, LocalTableScanExec, ProjectExec, RowDataSourceScanExec, SparkPlan} -import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.execution.{FilterExec, InSubqueryExec, LeafExecNode, LocalTableScanExec, ProjectExec, RowDataSourceScanExec, SparkPlan} +import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumnAndNestedColumn} import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec} import org.apache.spark.sql.internal.StaticSQLConf.WAREHOUSE_PATH import org.apache.spark.sql.sources.{BaseRelation, TableScan} @@ -498,7 +499,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat } } -private[sql] object DataSourceV2Strategy { +private[sql] object DataSourceV2Strategy extends Logging { private def translateLeafNodeFilterV2(predicate: Expression): Option[Predicate] = { predicate match { @@ -582,6 +583,28 @@ private[sql] object DataSourceV2Strategy { throw new IllegalStateException("Failed to rebuild Expression for filter: " + predicate)) } } + + /** + * Translates a runtime filter into a data source v2 Predicate. + * + * Runtime filters usually contain a subquery that must be evaluated before the translation. + * If the underlying subquery hasn't completed yet, this method will throw an exception. + */ + protected[sql] def translateRuntimeFilterV2(expr: Expression): Option[Predicate] = expr match { + case in @ InSubqueryExec(PushableColumnAndNestedColumn(name), _, _, _, _, _) => + val values = in.values().getOrElse { + throw new IllegalStateException(s"Can't translate $in to v2 Predicate, no subquery result") + } + val literals = values.map { value => + val literal = Literal(value) + LiteralValue(literal.value, literal.dataType) + } + Some(new Predicate("IN", FieldReference(name) +: literals)) + + case other => + logWarning(s"Can't translate $other to source filter, unsupported expression") + None + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala index 0ad8a8ca9d064..269df2ea442ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala @@ -85,7 +85,7 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper with Join } else { None } - case (resExp, r @ DataSourceV2ScanRelation(_, scan: SupportsRuntimeV2Filtering, _, _)) => + case (resExp, r @ DataSourceV2ScanRelation(_, scan: SupportsRuntimeV2Filtering, _, _, _)) => val filterAttrs = V2ExpressionUtils.resolveRefs[Attribute](scan.filterAttributes, r) if (resExp.references.subsetOf(AttributeSet(filterAttrs))) { Some(r) From 61570010a2a34d29a2b8cf706d09d4330167f63b Mon Sep 17 00:00:00 2001 From: huaxingao Date: Wed, 20 Jul 2022 14:10:07 -0700 Subject: [PATCH 04/10] address comments --- .../execution/datasources/v2/DataSourceV2Strategy.scala | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index c42703c7253c1..907f3be102c9c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.{SparkSession, Strategy} import org.apache.spark.sql.catalyst.analysis.{ResolvedIdentifier, ResolvedNamespace, ResolvedPartitionSpec, ResolvedTable} import org.apache.spark.sql.catalyst.catalog.CatalogUtils import org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruning, Expression, Literal, NamedExpression, Not, Or, PredicateHelper, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruning, Expression, NamedExpression, Not, Or, PredicateHelper, SubqueryExpression} import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ @@ -595,10 +595,7 @@ private[sql] object DataSourceV2Strategy extends Logging { val values = in.values().getOrElse { throw new IllegalStateException(s"Can't translate $in to v2 Predicate, no subquery result") } - val literals = values.map { value => - val literal = Literal(value) - LiteralValue(literal.value, literal.dataType) - } + val literals = values.map(LiteralValue(_, in.child.dataType)) Some(new Predicate("IN", FieldReference(name) +: literals)) case other => From 1aadfc6d3a3bad8bc53d3c722668f1080d2eec8b Mon Sep 17 00:00:00 2001 From: huaxingao Date: Sat, 23 Jul 2022 11:25:36 -0700 Subject: [PATCH 05/10] address comments --- .../read/SupportsRuntimeV2Filtering.java | 11 ++++++++ .../datasources/v2/BatchScanExec.scala | 4 +-- .../dynamicpruning/PartitionPruning.scala | 27 ++++++++++--------- 3 files changed, 28 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java index e276fe0764e10..de7e2670e793b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java @@ -20,10 +20,21 @@ import org.apache.spark.annotation.Experimental; import org.apache.spark.sql.connector.expressions.NamedReference; import org.apache.spark.sql.connector.expressions.filter.Predicate; +import org.apache.spark.sql.sources.Filter; /** * A mix-in interface for {@link Scan}. Data sources can implement this interface if they can * filter initially planned {@link InputPartition}s using predicates Spark infers at runtime. + * + * This interface is very similar to {@link SupportsRuntimeFiltering} except it uses + * data source V2 {@link Predicate} instead of data source V1 {@link Filter}. + * + * {@link SupportsRuntimeV2Filtering is preferred over {@link SupportsRuntimeFiltering} + * and only one of them should be implemented by the Data sources. However, if both of + * the interfaces are implemented, Spark will filter the planned + * {@link InputPartition}s first using {@link SupportsRuntimeV2Filtering#filter} + * and then using {@link SupportsRuntimeFiltering#filter}. + * *

* Note that Spark will push runtime filters only if they are beneficial. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index db15ab390fb27..cc03e4ff93888 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -72,10 +72,10 @@ case class BatchScanExec( if (dataSourceFilters.nonEmpty) { val originalPartitioning = outputPartitioning scan match { - case s: SupportsRuntimeFiltering => - s.filter(dataSourceFilters.map(_.asInstanceOf[Filter]).toArray) case s: SupportsRuntimeV2Filtering => s.filter(dataSourceFilters.map(_.asInstanceOf[Predicate]).toArray) + case s: SupportsRuntimeFiltering => + s.filter(dataSourceFilters.map(_.asInstanceOf[Filter]).toArray) case _ => } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala index 269df2ea442ae..2981bd36159ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.optimizer.JoinSelectionHelper import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.expressions.NamedReference import org.apache.spark.sql.connector.read.{SupportsRuntimeFiltering, SupportsRuntimeV2Filtering} import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} @@ -79,23 +80,25 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper with Join None } case (resExp, r @ DataSourceV2ScanRelation(_, scan: SupportsRuntimeFiltering, _, _, _)) => - val filterAttrs = V2ExpressionUtils.resolveRefs[Attribute](scan.filterAttributes, r) - if (resExp.references.subsetOf(AttributeSet(filterAttrs))) { - Some(r) - } else { - None - } + getFilterableTableScan(resExp, scan.filterAttributes, r) case (resExp, r @ DataSourceV2ScanRelation(_, scan: SupportsRuntimeV2Filtering, _, _, _)) => - val filterAttrs = V2ExpressionUtils.resolveRefs[Attribute](scan.filterAttributes, r) - if (resExp.references.subsetOf(AttributeSet(filterAttrs))) { - Some(r) - } else { - None - } + getFilterableTableScan(resExp, scan.filterAttributes, r) case _ => None } } + private def getFilterableTableScan( + resExp: Expression, + refs: Seq[NamedReference], + plan: LogicalPlan): Option[LogicalPlan] = { + val filterAttrs = V2ExpressionUtils.resolveRefs[Attribute](refs, plan) + if (resExp.references.subsetOf(AttributeSet(filterAttrs))) { + Some(plan) + } else { + None + } + } + /** * Insert a dynamic partition pruning predicate on one side of the join using the filter on the * other side of the join. From 277222fe05cdabe04bfc1ba1d2d0cc908c1371ea Mon Sep 17 00:00:00 2001 From: huaxingao Date: Sat, 23 Jul 2022 14:49:41 -0700 Subject: [PATCH 06/10] fix java doc build failure --- .../spark/sql/connector/read/SupportsRuntimeV2Filtering.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java index de7e2670e793b..268afea43ee23 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java @@ -29,7 +29,7 @@ * This interface is very similar to {@link SupportsRuntimeFiltering} except it uses * data source V2 {@link Predicate} instead of data source V1 {@link Filter}. * - * {@link SupportsRuntimeV2Filtering is preferred over {@link SupportsRuntimeFiltering} + * {@link SupportsRuntimeV2Filtering} is preferred over {@link SupportsRuntimeFiltering} * and only one of them should be implemented by the Data sources. However, if both of * the interfaces are implemented, Spark will filter the planned * {@link InputPartition}s first using {@link SupportsRuntimeV2Filtering#filter} From 6221609e79c51140d22eeedb52bd06ab763b8702 Mon Sep 17 00:00:00 2001 From: huaxingao Date: Tue, 26 Jul 2022 14:39:09 -0700 Subject: [PATCH 07/10] address comments --- .../read/SupportsRuntimeFiltering.java | 30 +++++++++++- .../read/SupportsRuntimeV2Filtering.java | 7 +-- .../sql/errors/QueryExecutionErrors.scala | 5 ++ .../spark/sql/util/PredicateUtils.scala | 49 +++++++++++++++++++ .../datasources/DataSourceStrategy.scala | 21 +------- .../datasources/v2/BatchScanExec.scala | 25 +++------- .../dynamicpruning/PartitionPruning.scala | 24 +++------ 7 files changed, 98 insertions(+), 63 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/util/PredicateUtils.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeFiltering.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeFiltering.java index 65d029dc309b5..000ba56a82f36 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeFiltering.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeFiltering.java @@ -19,7 +19,9 @@ import org.apache.spark.annotation.Experimental; import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.filter.Predicate; import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.util.PredicateUtils; /** * A mix-in interface for {@link Scan}. Data sources can implement this interface if they can @@ -30,7 +32,7 @@ * @since 3.2.0 */ @Experimental -public interface SupportsRuntimeFiltering extends Scan { +public interface SupportsRuntimeFiltering extends Scan, SupportsRuntimeV2Filtering { /** * Returns attributes this scan can be filtered by at runtime. *

@@ -57,4 +59,30 @@ public interface SupportsRuntimeFiltering extends Scan { * @param filters data source filters used to filter the scan at runtime */ void filter(Filter[] filters); + + /** + * Filters this scan using runtime predicates. + *

+ * The provided expressions must be interpreted as a set of predicates that are ANDed together. + * Implementations may use the predicates to prune initially planned {@link InputPartition}s. + *

+ * If the scan also implements {@link SupportsReportPartitioning}, it must preserve + * the originally reported partitioning during runtime filtering. While applying runtime + * predicates, the scan may detect that some {@link InputPartition}s have no matching data. It + * can omit such partitions entirely only if it does not report a specific partitioning. + * Otherwise, the scan can replace the initially planned {@link InputPartition}s that have no + * matching data with empty {@link InputPartition}s but must preserve the overall number of + * partitions. + *

+ * Note that Spark will call {@link Scan#toBatch()} again after filtering the scan at runtime. + * + * @param predicates data source V2 predicates used to filter the scan at runtime + */ + default void filter(Predicate[] predicates) { + Filter[] filters = new Filter[predicates.length]; + for (int i = 0; i < predicates.length; i++) { + filters[i] = PredicateUtils.toV1(predicates[i]).get(); + } + this.filter(filters); + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java index 268afea43ee23..7c238bde969b2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java @@ -25,15 +25,10 @@ /** * A mix-in interface for {@link Scan}. Data sources can implement this interface if they can * filter initially planned {@link InputPartition}s using predicates Spark infers at runtime. - * * This interface is very similar to {@link SupportsRuntimeFiltering} except it uses * data source V2 {@link Predicate} instead of data source V1 {@link Filter}. - * * {@link SupportsRuntimeV2Filtering} is preferred over {@link SupportsRuntimeFiltering} - * and only one of them should be implemented by the Data sources. However, if both of - * the interfaces are implemented, Spark will filter the planned - * {@link InputPartition}s first using {@link SupportsRuntimeV2Filtering#filter} - * and then using {@link SupportsRuntimeFiltering#filter}. + * and only one of them should be implemented by the data sources. * *

* Note that Spark will push runtime filters only if they are beneficial. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 64e6283c0e346..abdb4490df515 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -2039,4 +2039,9 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { toSQLId(funcName), pattern)) } + + def unsupportedPredicateToFilterConversionError(predicateType: String): Throwable = { + new UnsupportedOperationException(s"conversion from data source v2 Predicate to data " + + s"source v1 Filter is not supported for this Predicate: ${predicateType}") + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/PredicateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/PredicateUtils.scala new file mode 100644 index 0000000000000..78bd33da76edc --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/PredicateUtils.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.util + +import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.connector.expressions.{LiteralValue, NamedReference} +import org.apache.spark.sql.connector.expressions.filter.Predicate +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.sources.{Filter, In} + +private[sql] object PredicateUtils { + + def toV1(predicate: Predicate): Option[Filter] = { + predicate.name() match { + // Todo: add conversion for other V2 Predicate + case "IN" => + val attribute = predicate.children()(0) + .asInstanceOf[NamedReference].fieldNames().mkString(".") + val values = predicate.children().drop(1) + if (values.length > 0) { + val dataType = values(0).asInstanceOf[LiteralValue[_]].dataType + assert(values.forall(_.asInstanceOf[LiteralValue[_]].dataType.sameType(dataType))) + val inValues = values.map(v => + CatalystTypeConverters.convertToScala(v.asInstanceOf[LiteralValue[_]].value, dataType)) + Some(In(attribute, inValues)) + } else { + Some(In(attribute, Array.empty[Any])) + } + + case _ => + throw QueryExecutionErrors.unsupportedPredicateToFilterConversionError(predicate.name()) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 7f30300a39c17..c9e6dd9630466 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -44,7 +44,7 @@ import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, NullOrdering, SortDirection, SortOrder => V2SortOrder, SortValue} import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc} import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan} +import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.v2.PushedDownOperators import org.apache.spark.sql.execution.streaming.StreamingRelation @@ -652,25 +652,6 @@ object DataSourceStrategy } } - /** - * Translates a runtime filter into a data source filter. - * - * Runtime filters usually contain a subquery that must be evaluated before the translation. - * If the underlying subquery hasn't completed yet, this method will throw an exception. - */ - protected[sql] def translateRuntimeFilter(expr: Expression): Option[Filter] = expr match { - case in @ InSubqueryExec(e @ PushableColumnAndNestedColumn(name), _, _, _, _, _) => - val values = in.values().getOrElse { - throw new IllegalStateException(s"Can't translate $in to source filter, no subquery result") - } - val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType) - Some(sources.In(name, values.map(toScala))) - - case other => - logWarning(s"Can't translate $other to source filter, unsupported expression") - None - } - /** * Selects Catalyst predicate [[Expression]]s which are convertible into data source [[Filter]]s * and can be handled by `relation`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index cc03e4ff93888..f1c43b8f60c96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -27,10 +27,7 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, SinglePartition} import org.apache.spark.sql.catalyst.util.InternalRowSet import org.apache.spark.sql.catalyst.util.truncatedString -import org.apache.spark.sql.connector.expressions.filter.Predicate -import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, PartitionReaderFactory, Scan, SupportsRuntimeFiltering, SupportsRuntimeV2Filtering} -import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, PartitionReaderFactory, Scan, SupportsRuntimeV2Filtering} /** * Physical plan node for scanning a batch of data from a data source v2. @@ -58,26 +55,16 @@ case class BatchScanExec( @transient private lazy val filteredPartitions: Seq[Seq[InputPartition]] = { val dataSourceFilters = runtimeFilters.flatMap { - case DynamicPruningExpression(e) => - scan match { - case _: SupportsRuntimeFiltering => - DataSourceStrategy.translateRuntimeFilter(e) - case _: SupportsRuntimeV2Filtering => - DataSourceV2Strategy.translateRuntimeFilterV2(e) - case _ => None - } + case DynamicPruningExpression(e) => DataSourceV2Strategy.translateRuntimeFilterV2(e) case _ => None } if (dataSourceFilters.nonEmpty) { val originalPartitioning = outputPartitioning - scan match { - case s: SupportsRuntimeV2Filtering => - s.filter(dataSourceFilters.map(_.asInstanceOf[Predicate]).toArray) - case s: SupportsRuntimeFiltering => - s.filter(dataSourceFilters.map(_.asInstanceOf[Filter]).toArray) - case _ => - } + + // the cast is safe as runtime filters are only assigned if the scan can be filtered + val filterableScan = scan.asInstanceOf[SupportsRuntimeV2Filtering] + filterableScan.filter(dataSourceFilters.toArray) // call toBatch again to get filtered partitions val newPartitions = scan.toBatch.planInputPartitions() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala index 2981bd36159ff..60ecc4b635e57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala @@ -23,8 +23,7 @@ import org.apache.spark.sql.catalyst.optimizer.JoinSelectionHelper import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.connector.expressions.NamedReference -import org.apache.spark.sql.connector.read.{SupportsRuntimeFiltering, SupportsRuntimeV2Filtering} +import org.apache.spark.sql.connector.read.SupportsRuntimeV2Filtering import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation @@ -79,26 +78,17 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper with Join } else { None } - case (resExp, r @ DataSourceV2ScanRelation(_, scan: SupportsRuntimeFiltering, _, _, _)) => - getFilterableTableScan(resExp, scan.filterAttributes, r) case (resExp, r @ DataSourceV2ScanRelation(_, scan: SupportsRuntimeV2Filtering, _, _, _)) => - getFilterableTableScan(resExp, scan.filterAttributes, r) + val filterAttrs = V2ExpressionUtils.resolveRefs[Attribute](scan.filterAttributes, r) + if (resExp.references.subsetOf(AttributeSet(filterAttrs))) { + Some(r) + } else { + None + } case _ => None } } - private def getFilterableTableScan( - resExp: Expression, - refs: Seq[NamedReference], - plan: LogicalPlan): Option[LogicalPlan] = { - val filterAttrs = V2ExpressionUtils.resolveRefs[Attribute](refs, plan) - if (resExp.references.subsetOf(AttributeSet(filterAttrs))) { - Some(plan) - } else { - None - } - } - /** * Insert a dynamic partition pruning predicate on one side of the join using the filter on the * other side of the join. From 9e0799f5c7473cd4a4c173fee5b9bc87d587c1cb Mon Sep 17 00:00:00 2001 From: huaxingao Date: Wed, 27 Jul 2022 08:32:47 -0700 Subject: [PATCH 08/10] address comments --- .../read/SupportsRuntimeFiltering.java | 37 ++++++++----------- .../sql/errors/QueryExecutionErrors.scala | 5 --- .../connector}/PredicateUtils.scala | 8 ++-- 3 files changed, 18 insertions(+), 32 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/{util => internal/connector}/PredicateUtils.scala (87%) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeFiltering.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeFiltering.java index 000ba56a82f36..8495aebc3eba4 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeFiltering.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeFiltering.java @@ -17,11 +17,16 @@ package org.apache.spark.sql.connector.read; +import java.util.ArrayList; +import java.util.List; + import org.apache.spark.annotation.Experimental; import org.apache.spark.sql.connector.expressions.NamedReference; import org.apache.spark.sql.connector.expressions.filter.Predicate; import org.apache.spark.sql.sources.Filter; -import org.apache.spark.sql.util.PredicateUtils; +import org.apache.spark.sql.internal.connector.PredicateUtils; + +import scala.Option; /** * A mix-in interface for {@link Scan}. Data sources can implement this interface if they can @@ -60,29 +65,17 @@ public interface SupportsRuntimeFiltering extends Scan, SupportsRuntimeV2Filteri */ void filter(Filter[] filters); - /** - * Filters this scan using runtime predicates. - *

- * The provided expressions must be interpreted as a set of predicates that are ANDed together. - * Implementations may use the predicates to prune initially planned {@link InputPartition}s. - *

- * If the scan also implements {@link SupportsReportPartitioning}, it must preserve - * the originally reported partitioning during runtime filtering. While applying runtime - * predicates, the scan may detect that some {@link InputPartition}s have no matching data. It - * can omit such partitions entirely only if it does not report a specific partitioning. - * Otherwise, the scan can replace the initially planned {@link InputPartition}s that have no - * matching data with empty {@link InputPartition}s but must preserve the overall number of - * partitions. - *

- * Note that Spark will call {@link Scan#toBatch()} again after filtering the scan at runtime. - * - * @param predicates data source V2 predicates used to filter the scan at runtime - */ default void filter(Predicate[] predicates) { - Filter[] filters = new Filter[predicates.length]; + List filterList = new ArrayList(); + for (int i = 0; i < predicates.length; i++) { - filters[i] = PredicateUtils.toV1(predicates[i]).get(); + Option filter = PredicateUtils.toV1(predicates[i]); + if (filter.nonEmpty()) { + filterList.add((Filter)filter.get()); + } } - this.filter(filters); + + Filter[] filters = new Filter[filterList.size()]; + this.filter(filterList.toArray(filters)); } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index abdb4490df515..64e6283c0e346 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -2039,9 +2039,4 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { toSQLId(funcName), pattern)) } - - def unsupportedPredicateToFilterConversionError(predicateType: String): Throwable = { - new UnsupportedOperationException(s"conversion from data source v2 Predicate to data " + - s"source v1 Filter is not supported for this Predicate: ${predicateType}") - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/PredicateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala similarity index 87% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/util/PredicateUtils.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala index 78bd33da76edc..cd4ed49aadd63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/PredicateUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala @@ -15,12 +15,11 @@ * limitations under the License. */ -package org.apache.spark.sql.util +package org.apache.spark.sql.internal.connector import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.connector.expressions.{LiteralValue, NamedReference} -import org.apache.spark.sql.connector.expressions.filter.Predicate -import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.connector.expressions.filter. Predicate import org.apache.spark.sql.sources.{Filter, In} private[sql] object PredicateUtils { @@ -42,8 +41,7 @@ private[sql] object PredicateUtils { Some(In(attribute, Array.empty[Any])) } - case _ => - throw QueryExecutionErrors.unsupportedPredicateToFilterConversionError(predicate.name()) + case _ => None } } } From d6eb9316f8fecf7e721c42336ae50dd7ec959f85 Mon Sep 17 00:00:00 2001 From: huaxingao Date: Wed, 27 Jul 2022 21:04:10 -0700 Subject: [PATCH 09/10] address comments --- .../spark/sql/internal/connector/PredicateUtils.scala | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala index cd4ed49aadd63..cc98d97b8109c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala @@ -19,21 +19,24 @@ package org.apache.spark.sql.internal.connector import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.connector.expressions.{LiteralValue, NamedReference} -import org.apache.spark.sql.connector.expressions.filter. Predicate +import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.sources.{Filter, In} private[sql] object PredicateUtils { def toV1(predicate: Predicate): Option[Filter] = { predicate.name() match { - // Todo: add conversion for other V2 Predicate - case "IN" => + // TODO: add conversion for other V2 Predicate + case "IN" if (predicate.children()(0).isInstanceOf[NamedReference]) => val attribute = predicate.children()(0) .asInstanceOf[NamedReference].fieldNames().mkString(".") val values = predicate.children().drop(1) if (values.length > 0) { + if (!values.forall(_.isInstanceOf[LiteralValue[_]])) return None val dataType = values(0).asInstanceOf[LiteralValue[_]].dataType - assert(values.forall(_.asInstanceOf[LiteralValue[_]].dataType.sameType(dataType))) + if (!values.forall(_.asInstanceOf[LiteralValue[_]].dataType.sameType(dataType))) { + return None + } val inValues = values.map(v => CatalystTypeConverters.convertToScala(v.asInstanceOf[LiteralValue[_]].value, dataType)) Some(In(attribute, inValues)) From d07d61b0fa33d4ce72c029368384f705ad2ecd68 Mon Sep 17 00:00:00 2001 From: huaxingao Date: Wed, 27 Jul 2022 22:11:54 -0700 Subject: [PATCH 10/10] address comments --- .../apache/spark/sql/internal/connector/PredicateUtils.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala index cc98d97b8109c..ace6b30d4ccec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala @@ -27,9 +27,8 @@ private[sql] object PredicateUtils { def toV1(predicate: Predicate): Option[Filter] = { predicate.name() match { // TODO: add conversion for other V2 Predicate - case "IN" if (predicate.children()(0).isInstanceOf[NamedReference]) => - val attribute = predicate.children()(0) - .asInstanceOf[NamedReference].fieldNames().mkString(".") + case "IN" if predicate.children()(0).isInstanceOf[NamedReference] => + val attribute = predicate.children()(0).toString val values = predicate.children().drop(1) if (values.length > 0) { if (!values.forall(_.isInstanceOf[LiteralValue[_]])) return None