Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.read;

import org.apache.spark.annotation.Experimental;
import org.apache.spark.sql.connector.expressions.NamedReference;
import org.apache.spark.sql.connector.expressions.filter.Predicate;

/**
* A mix-in interface for {@link Scan}. Data sources can implement this interface if they can
* filter initially planned {@link InputPartition}s using predicates Spark infers at runtime.
* <p>
* Note that Spark will push runtime filters only if they are beneficial.
*
* @since 3.4.0
*/
@Experimental
public interface SupportsRuntimeV2Filtering extends Scan {
/**
* Returns attributes this scan can be filtered by at runtime.
* <p>
* Spark will call {@link #filter(Predicate[])} if it can derive a runtime
* predicate for any of the filter attributes.
*/
NamedReference[] filterAttributes();

/**
* Filters this scan using runtime predicates.
* <p>
* The provided expressions must be interpreted as a set of predicates that are ANDed together.
* Implementations may use the predicates to prune initially planned {@link InputPartition}s.
* <p>
* If the scan also implements {@link SupportsReportPartitioning}, it must preserve
* the originally reported partitioning during runtime filtering. While applying runtime
* predicates, the scan may detect that some {@link InputPartition}s have no matching data. It
* can omit such partitions entirely only if it does not report a specific partitioning.
* Otherwise, the scan can replace the initially planned {@link InputPartition}s that have no
* matching data with empty {@link InputPartition}s but must preserve the overall number of
* partitions.
* <p>
* Note that Spark will call {@link Scan#toBatch()} again after filtering the scan at runtime.
*
* @param predicates data source V2 predicates used to filter the scan at runtime
*/
void filter(Predicate[] predicates);
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ object Literal {
case b: Byte => Literal(b, ByteType)
case s: Short => Literal(s, ShortType)
case s: String => Literal(UTF8String.fromString(s), StringType)
case s: UTF8String => Literal(s, StringType)
case c: Char => Literal(UTF8String.fromString(c.toString), StringType)
case ac: Array[Char] => Literal(UTF8String.fromString(String.valueOf(ac)), StringType)
case b: Boolean => Literal(b, BooleanType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,12 +268,11 @@ class InMemoryTable(

case class InMemoryStats(sizeInBytes: OptionalLong, numRows: OptionalLong) extends Statistics

case class InMemoryBatchScan(
abstract class BatchScanBaseClass(
var data: Seq[InputPartition],
readSchema: StructType,
tableSchema: StructType)
extends Scan with Batch with SupportsRuntimeFiltering with SupportsReportStatistics
with SupportsReportPartitioning {
extends Scan with Batch with SupportsReportStatistics with SupportsReportPartitioning {

override def toBatch: Batch = this

Expand Down Expand Up @@ -308,6 +307,13 @@ class InMemoryTable(
val nonMetadataColumns = readSchema.filterNot(f => metadataColumns.contains(f.name))
new BufferedRowsReaderFactory(metadataColumns, nonMetadataColumns, tableSchema)
}
}

case class InMemoryBatchScan(
var _data: Seq[InputPartition],
readSchema: StructType,
tableSchema: StructType)
extends BatchScanBaseClass (_data, readSchema, tableSchema) with SupportsRuntimeFiltering {

override def filterAttributes(): Array[NamedReference] = {
val scanFields = readSchema.fields.map(_.name).toSet
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.catalog

import java.util

import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue, NamedReference, Transform}
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.{InputPartition, Scan, ScanBuilder, SupportsRuntimeV2Filtering}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

class InMemoryTableWithV2Filter(
name: String,
schema: StructType,
partitioning: Array[Transform],
properties: util.Map[String, String])
extends InMemoryTable(name, schema, partitioning, properties) {

override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
new InMemoryV2FilterScanBuilder(schema)
}

class InMemoryV2FilterScanBuilder(tableSchema: StructType)
extends InMemoryScanBuilder(tableSchema) {
override def build: Scan =
InMemoryV2FilterBatchScan(data.map(_.asInstanceOf[InputPartition]), schema, tableSchema)
}

case class InMemoryV2FilterBatchScan(
var _data: Seq[InputPartition],
readSchema: StructType,
tableSchema: StructType)
extends BatchScanBaseClass (_data, readSchema, tableSchema) with SupportsRuntimeV2Filtering {

override def filterAttributes(): Array[NamedReference] = {
val scanFields = readSchema.fields.map(_.name).toSet
partitioning.flatMap(_.references)
.filter(ref => scanFields.contains(ref.fieldNames.mkString(".")))
}

override def filter(filters: Array[Predicate]): Unit = {
if (partitioning.length == 1 && partitioning.head.references().length == 1) {
val ref = partitioning.head.references().head
filters.foreach {
case p : Predicate if p.name().equals("IN") =>

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

feels like some unapply method to extract what you want is more preferable

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Predicate is a java class. I don't think unapply can be used

if (p.children().length > 1) {
val filterRef = p.children()(0).asInstanceOf[FieldReference].references.head
if (filterRef.toString.equals(ref.toString)) {
val matchingKeys =
p.children().drop(1).map(_.asInstanceOf[LiteralValue[_]].value.toString).toSet
data = data.filter(partition => {
val key = partition.asInstanceOf[BufferedRows].keyString
matchingKeys.contains(key)
})
}
}
}
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.catalog

import java.util

import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.types.StructType

class InMemoryTableWithV2FilterCatalog extends InMemoryTableCatalog {
import CatalogV2Implicits._

override def createTable(
ident: Identifier,
schema: StructType,
partitions: Array[Transform],
properties: util.Map[String, String]): Table = {
if (tables.containsKey(ident)) {
throw new TableAlreadyExistsException(ident)
}

InMemoryTableCatalog.maybeSimulateFailedTableCreation(properties)

val tableName = s"$name.${ident.quoted}"
val table = new InMemoryTableWithV2Filter(tableName, schema, partitions, properties)
tables.put(ident, table)
namespaces.putIfAbsent(ident.namespace.toList, Map())
table
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, SinglePartition}
import org.apache.spark.sql.catalyst.util.InternalRowSet
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, PartitionReaderFactory, Scan, SupportsRuntimeFiltering}
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, PartitionReaderFactory, Scan, SupportsRuntimeFiltering, SupportsRuntimeV2Filtering}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.sources.Filter

/**
* Physical plan node for scanning a batch of data from a data source v2.
Expand Down Expand Up @@ -56,16 +58,26 @@ case class BatchScanExec(

@transient private lazy val filteredPartitions: Seq[Seq[InputPartition]] = {
val dataSourceFilters = runtimeFilters.flatMap {
case DynamicPruningExpression(e) => DataSourceStrategy.translateRuntimeFilter(e)
case DynamicPruningExpression(e) =>
scan match {
case _: SupportsRuntimeFiltering =>
DataSourceStrategy.translateRuntimeFilter(e)
case _: SupportsRuntimeV2Filtering =>

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we make SupportsRuntimeV2Filtering have higher priority over SupportsRuntimeFiltering? Also we need to document the behavior if a source implements both of them

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't seem to me that a data source would implement both SupportsRuntimeV2Filtering and SupportsRuntimeFiltering?

DataSourceV2Strategy.translateRuntimeFilterV2(e)
case _ => None
}
case _ => None
}

if (dataSourceFilters.nonEmpty) {
val originalPartitioning = outputPartitioning

// the cast is safe as runtime filters are only assigned if the scan can be filtered
val filterableScan = scan.asInstanceOf[SupportsRuntimeFiltering]
filterableScan.filter(dataSourceFilters.toArray)
scan match {
case s: SupportsRuntimeFiltering =>
s.filter(dataSourceFilters.map(_.asInstanceOf[Filter]).toArray)
case s: SupportsRuntimeV2Filtering =>
s.filter(dataSourceFilters.map(_.asInstanceOf[Predicate]).toArray)
case _ =>
}

// call toBatch again to get filtered partitions
val newPartitions = scan.toBatch.planInputPartitions()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,26 @@ package org.apache.spark.sql.execution.datasources.v2
import scala.collection.JavaConverters._
import scala.collection.mutable

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{SparkSession, Strategy}
import org.apache.spark.sql.catalyst.analysis.{ResolvedIdentifier, ResolvedNamespace, ResolvedPartitionSpec, ResolvedTable}
import org.apache.spark.sql.catalyst.catalog.CatalogUtils
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruning, Expression, NamedExpression, Not, Or, PredicateHelper, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruning, Expression, Literal, NamedExpression, Not, Or, PredicateHelper, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.{toPrettySQL, ResolveDefaultColumns, V2ExpressionBuilder}
import org.apache.spark.sql.connector.catalog.{Identifier, StagingTableCatalog, SupportsDelete, SupportsNamespaces, SupportsPartitionManagement, SupportsWrite, Table, TableCapability, TableCatalog, TruncatableTable}
import org.apache.spark.sql.connector.catalog.index.SupportsIndex
import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue}
import org.apache.spark.sql.connector.expressions.filter.{And => V2And, Not => V2Not, Or => V2Or, Predicate}
import org.apache.spark.sql.connector.read.LocalScan
import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream}
import org.apache.spark.sql.connector.write.V1Write
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.{FilterExec, LeafExecNode, LocalTableScanExec, ProjectExec, RowDataSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.execution.{FilterExec, InSubqueryExec, LeafExecNode, LocalTableScanExec, ProjectExec, RowDataSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumnAndNestedColumn}
import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec}
import org.apache.spark.sql.internal.StaticSQLConf.WAREHOUSE_PATH
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
Expand Down Expand Up @@ -498,7 +499,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
}
}

private[sql] object DataSourceV2Strategy {
private[sql] object DataSourceV2Strategy extends Logging {

private def translateLeafNodeFilterV2(predicate: Expression): Option[Predicate] = {
predicate match {
Expand Down Expand Up @@ -582,6 +583,28 @@ private[sql] object DataSourceV2Strategy {
throw new IllegalStateException("Failed to rebuild Expression for filter: " + predicate))
}
}

/**
* Translates a runtime filter into a data source v2 Predicate.
*
* Runtime filters usually contain a subquery that must be evaluated before the translation.
* If the underlying subquery hasn't completed yet, this method will throw an exception.
*/
protected[sql] def translateRuntimeFilterV2(expr: Expression): Option[Predicate] = expr match {
case in @ InSubqueryExec(PushableColumnAndNestedColumn(name), _, _, _, _, _) =>
val values = in.values().getOrElse {
throw new IllegalStateException(s"Can't translate $in to v2 Predicate, no subquery result")
}
val literals = values.map { value =>
val literal = Literal(value)
LiteralValue(literal.value, literal.dataType)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need to infer the data type by creating a catalyst Literal. The type must be in.child.dataType

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. Thanks

}
Some(new Predicate("IN", FieldReference(name) +: literals))

case other =>
logWarning(s"Can't translate $other to source filter, unsupported expression")
None
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.optimizer.JoinSelectionHelper
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.read.SupportsRuntimeFiltering
import org.apache.spark.sql.connector.read.{SupportsRuntimeFiltering, SupportsRuntimeV2Filtering}
import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
Expand Down Expand Up @@ -85,6 +85,13 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper with Join
} else {
None
}
case (resExp, r @ DataSourceV2ScanRelation(_, scan: SupportsRuntimeV2Filtering, _, _, _)) =>
Comment thread
cloud-fan marked this conversation as resolved.
val filterAttrs = V2ExpressionUtils.resolveRefs[Attribute](scan.filterAttributes, r)
if (resExp.references.subsetOf(AttributeSet(filterAttrs))) {
Some(r)
} else {
None
}
case _ => None
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.scalatest.GivenWhenThen
import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Expression}
import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._
import org.apache.spark.sql.catalyst.plans.ExistenceJoin
import org.apache.spark.sql.connector.catalog.InMemoryTableCatalog
import org.apache.spark.sql.connector.catalog.{InMemoryTableCatalog, InMemoryTableWithV2FilterCatalog}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive._
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
Expand Down Expand Up @@ -1805,3 +1805,21 @@ class DynamicPartitionPruningV2SuiteAEOff extends DynamicPartitionPruningV2Suite

class DynamicPartitionPruningV2SuiteAEOn extends DynamicPartitionPruningV2Suite
with EnableAdaptiveExecutionSuite

abstract class DynamicPartitionPruningV2FilterSuite
extends DynamicPartitionPruningDataSourceSuiteBase {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we extend DynamicPartitionPruningV2Suite here? then we can save the override protected def runAnalyzeColumnCommands: Boolean = false, and catalog configs will be overwritten.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good. I have a follow-up here

override protected def runAnalyzeColumnCommands: Boolean = false

override protected def initState(): Unit = {
spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableWithV2FilterCatalog].getName)
spark.conf.set("spark.sql.defaultCatalog", "testcat")
}
}

class DynamicPartitionPruningV2FilterSuiteAEOff
extends DynamicPartitionPruningV2FilterSuite
with DisableAdaptiveExecutionSuite

class DynamicPartitionPruningV2FilterSuiteAEOn
extends DynamicPartitionPruningV2FilterSuite
with EnableAdaptiveExecutionSuite