Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class AvroRowReaderSuite

val df = spark.read.format("avro").load(dir.getCanonicalPath)
val fileScan = df.queryExecution.executedPlan collectFirst {
case BatchScanExec(_, f: AvroScan) => f
case BatchScanExec(_, f: AvroScan, _) => f
}
val filePath = fileScan.get.fileIndex.inputFiles(0)
val fileSize = new File(new URI(filePath)).length
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2162,7 +2162,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper {
}.isEmpty)

val fileScan = df.queryExecution.executedPlan collectFirst {
case BatchScanExec(_, f: AvroScan) => f
case BatchScanExec(_, f: AvroScan, _) => f
}
assert(fileScan.nonEmpty)
assert(fileScan.get.partitionFilters.nonEmpty)
Expand Down Expand Up @@ -2195,7 +2195,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper {
assert(filterCondition.isDefined)

val fileScan = df.queryExecution.executedPlan collectFirst {
case BatchScanExec(_, f: AvroScan) => f
case BatchScanExec(_, f: AvroScan, _) => f
}
assert(fileScan.nonEmpty)
assert(fileScan.get.partitionFilters.isEmpty)
Expand Down Expand Up @@ -2276,7 +2276,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper {
.where("value = 'a'")

val fileScan = df.queryExecution.executedPlan collectFirst {
case BatchScanExec(_, f: AvroScan) => f
case BatchScanExec(_, f: AvroScan, _) => f
}
assert(fileScan.nonEmpty)
if (filtersPushdown) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ default String description() {
* exception, data sources must overwrite this method to provide an implementation, if the
* {@link Table} that creates this scan returns {@link TableCapability#BATCH_READ} support in its
* {@link Table#capabilities()}.
* <p>
* If the scan supports runtime filtering and implements {@link SupportsRuntimeFiltering},

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@holdenk @sunchao @viirya, added this comment to clarify the behavior change as discussed.

* this method may be called multiple times. Therefore, implementations can cache some state
* to avoid planning the job twice.
*
* @throws UnsupportedOperationException
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.read;

import org.apache.spark.annotation.Experimental;
import org.apache.spark.sql.connector.expressions.NamedReference;
import org.apache.spark.sql.sources.Filter;

/**
* A mix-in interface for {@link Scan}. Data sources can implement this interface if they can
* filter initially planned {@link InputPartition}s using predicates Spark infers at runtime.
* <p>
* Note that Spark will push runtime filters only if they are beneficial.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do we determine if it is beneficial?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Answering my own question this is done using the PartitionPruning rule.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct, Spark has a cost-based model in PartitionPruning.

*
* @since 3.2.0
*/
@Experimental
public interface SupportsRuntimeFiltering extends Scan {
/**
* Returns attributes this scan can be filtered by at runtime.
* <p>
* Spark will call {@link #filter(Filter[])} if it can derive a runtime
* predicate for any of the filter attributes.
*/
NamedReference[] filterAttributes();

/**
* Filters this scan using runtime filters.
* <p>
* The provided expressions must be interpreted as a set of filters that are ANDed together.
* Implementations may use the filters to prune initially planned {@link InputPartition}s.
* <p>
* If the scan also implements {@link SupportsReportPartitioning}, it must preserve

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can remove this requirement once SupportsReportPartitioning is changed.

* the originally reported partitioning during runtime filtering. While applying runtime filters,
* the scan may detect that some {@link InputPartition}s have no matching data. It can omit
* such partitions entirely only if it does not report a specific partitioning. Otherwise,
* the scan can replace the initially planned {@link InputPartition}s that have no matching
* data with empty {@link InputPartition}s but must preserve the overall number of partitions.
* <p>
* Note that Spark will call {@link Scan#toBatch()} again after filtering the scan at runtime.
*
* @param filters data source filters used to filter the scan at runtime
*/
void filter(Filter[] filters);
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.connector.catalog
import java.time.{Instant, ZoneId}
import java.time.temporal.ChronoUnit
import java.util
import java.util.OptionalLong

import scala.collection.JavaConverters._
import scala.collection.mutable
Expand Down Expand Up @@ -245,21 +246,58 @@ class InMemoryTable(
}
}

class InMemoryBatchScan(
data: Array[InputPartition],
case class InMemoryStats(sizeInBytes: OptionalLong, numRows: OptionalLong) extends Statistics

case class InMemoryBatchScan(
var data: Seq[InputPartition],
readSchema: StructType,
tableSchema: StructType) extends Scan with Batch {
override def readSchema(): StructType = readSchema
tableSchema: StructType)
extends Scan with Batch with SupportsRuntimeFiltering with SupportsReportStatistics {

override def toBatch: Batch = this

override def planInputPartitions(): Array[InputPartition] = data
override def estimateStatistics(): Statistics = {

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to implement stats as tests rely on them.

if (data.isEmpty) {
return InMemoryStats(OptionalLong.of(0L), OptionalLong.of(0L))
}

val inputPartitions = data.map(_.asInstanceOf[BufferedRows])
val numRows = inputPartitions.map(_.rows.size).sum
// we assume an average object header is 12 bytes
Comment thread
cloud-fan marked this conversation as resolved.
val objectHeaderSizeInBytes = 12L
val rowSizeInBytes = objectHeaderSizeInBytes + schema.defaultSize
val sizeInBytes = numRows * rowSizeInBytes
InMemoryStats(OptionalLong.of(sizeInBytes), OptionalLong.of(numRows))
}

override def planInputPartitions(): Array[InputPartition] = data.toArray

override def createReaderFactory(): PartitionReaderFactory = {
val metadataColumns = readSchema.map(_.name).filter(metadataColumnNames.contains)
val nonMetadataColumns = readSchema.filterNot(f => metadataColumns.contains(f.name))
new BufferedRowsReaderFactory(metadataColumns, nonMetadataColumns, tableSchema)
}

override def filterAttributes(): Array[NamedReference] = {
val scanFields = readSchema.fields.map(_.name).toSet
partitioning.flatMap(_.references)
.filter(ref => scanFields.contains(ref.fieldNames.mkString(".")))
}

override def filter(filters: Array[Filter]): Unit = {
if (partitioning.length == 1) {
Comment thread
aokolnychyi marked this conversation as resolved.
Outdated
filters.foreach {
case In(attrName, values) if attrName == partitioning.head.name =>
val matchingKeys = values.map(_.toString).toSet
data = data.filter(partition => {
val key = partition.asInstanceOf[BufferedRows].key
matchingKeys.contains(key)
})

case _ => // skip
}
}
}
}

override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
import org.apache.spark.sql.connector.catalog.SupportsRead
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.execution.streaming.StreamingRelation
import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy
Expand Down Expand Up @@ -631,6 +631,25 @@ object DataSourceStrategy
}
}

/**
* Translates a runtime filter into a data source filter.
*
* Runtime filters usually contain a subquery that must be evaluated before the translation.
* If the underlying subquery hasn't completed yet, this method will throw an exception.
*/
protected[sql] def translateRuntimeFilter(expr: Expression): Option[Filter] = expr match {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Super minor suggestion, we could maybe add some scaladoc here on the expected behaviour, I note that in one case it rases an exception and in the other case it returns None to indicate it isn't able to handle the filter and I think clarifying that could be good.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will add.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added. Let me know if it is descriptive enough, @holdenk.

case in @ InSubqueryExec(e @ PushableColumnAndNestedColumn(name), _, _, _) =>
val values = in.values().getOrElse {
throw new IllegalStateException(s"Can't translate $in to source filter, no subquery result")
}
val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType)
Some(sources.In(name, values.map(toScala)))

case other =>
logWarning(s"Can't translate $other to source filter, unsupported expression")
None
}

/**
* Selects Catalyst predicate [[Expression]]s which are convertible into data source [[Filter]]s
* and can be handled by `relation`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,38 +17,96 @@

package org.apache.spark.sql.execution.datasources.v2

import com.google.common.base.Objects

import org.apache.spark.SparkException
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory, Scan}
import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory, Scan, SupportsRuntimeFiltering}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy

/**
* Physical plan node for scanning a batch of data from a data source v2.
*/
case class BatchScanExec(
output: Seq[AttributeReference],
@transient scan: Scan) extends DataSourceV2ScanExecBase {
@transient scan: Scan,
runtimeFilters: Seq[Expression]) extends DataSourceV2ScanExecBase {

@transient lazy val batch = scan.toBatch

// TODO: unify the equal/hashCode implementation for all data source v2 query plans.
override def equals(other: Any): Boolean = other match {
case other: BatchScanExec => this.batch == other.batch
case _ => false
case other: BatchScanExec =>
this.batch == other.batch && this.runtimeFilters == other.runtimeFilters
case _ =>
false
}

override def hashCode(): Int = batch.hashCode()
override def hashCode(): Int = Objects.hashCode(batch, runtimeFilters)

@transient override lazy val partitions: Seq[InputPartition] = batch.planInputPartitions()

@transient private lazy val filteredPartitions: Seq[InputPartition] = {
val dataSourceFilters = runtimeFilters.flatMap {
case DynamicPruningExpression(e) => DataSourceStrategy.translateRuntimeFilter(e)
case _ => None
}

if (dataSourceFilters.nonEmpty) {
val originalPartitioning = outputPartitioning

// the cast is safe as runtime filters are only assigned if the scan can be filtered
val filterableScan = scan.asInstanceOf[SupportsRuntimeFiltering]
filterableScan.filter(dataSourceFilters.toArray)

// call toBatch again to get filtered partitions
val newPartitions = scan.toBatch.planInputPartitions()

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Calling toBatch again is a bit questionable. Any ideas are welcome.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even though I am calling toBatch one more time, I still use the original readerFactory.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should update the comment in Batch.java since we are now calling planInputPartitions more than once and people might put some logic there that they assume is only run once.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I think the comment in Batch still holds as we call planInputPartitions on a given Batch only once. I guess we need to adapt the Scan docs to point that toBatch can now be called multiple times.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a comment to Scan.


originalPartitioning match {
case p: DataSourcePartitioning if p.numPartitions != newPartitions.size =>
throw new SparkException(
"Data source must have preserved the original partitioning during runtime filtering; " +
s"reported num partitions: ${p.numPartitions}, " +
s"num partitions after runtime filtering: ${newPartitions.size}")
case _ =>
// no validation is needed as the data source did not report any specific partitioning
}

newPartitions
} else {
partitions
}
}

override lazy val readerFactory: PartitionReaderFactory = batch.createReaderFactory()

override lazy val inputRDD: RDD[InternalRow] = {
new DataSourceRDD(sparkContext, partitions, readerFactory, supportsColumnar, customMetrics)
if (filteredPartitions.isEmpty && outputPartitioning == SinglePartition) {

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this possible if we already check the number of partition in originalPartitioning must match new partition number?

@aokolnychyi aokolnychyi Jul 1, 2021

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We check the number of partitions before and after filtering match only if the source reported a specific partitioning through SupportsReportPartitioning. Only in that case we have DataSourcePartitioning. This situation, on the other hand, can happen if we inferred SinglePartition but the source did not report anything.

// return an empty RDD with 1 partition if dynamic filtering removed the only split
sparkContext.parallelize(Array.empty[InternalRow], 1)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Returning an empty RDD here.

} else {
new DataSourceRDD(
sparkContext, filteredPartitions, readerFactory, supportsColumnar, customMetrics)
}
}

override def doCanonicalize(): BatchScanExec = {
this.copy(output = output.map(QueryPlan.normalizeExpressions(_, output)))
this.copy(
output = output.map(QueryPlan.normalizeExpressions(_, output)),
runtimeFilters = QueryPlan.normalizePredicates(
runtimeFilters.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral)),
output))
}

override def simpleString(maxFields: Int): String = {
val truncatedOutputString = truncatedString(output, "[", ", ", "]", maxFields)
val runtimeFiltersString = s"RuntimeFilters: ${runtimeFilters.mkString("[", ",", "]")}"
val result = s"$nodeName$truncatedOutputString ${scan.description()} $runtimeFiltersString"
redact(result)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._

import org.apache.spark.sql.{SparkSession, Strategy}
import org.apache.spark.sql.catalyst.analysis.{ResolvedNamespace, ResolvedPartitionSpec, ResolvedTable}
import org.apache.spark.sql.catalyst.expressions.{And, Attribute, Expression, NamedExpression, PredicateHelper, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruning, Expression, NamedExpression, PredicateHelper, SubqueryExpression}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.toPrettySQL
Expand Down Expand Up @@ -114,8 +114,12 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
// projection and filters were already pushed down in the optimizer.
// this uses PhysicalOperation to get the projection and ensure that if the batch scan does
// not support columnar, a projection is added to convert the rows to UnsafeRow.
val batchExec = BatchScanExec(relation.output, relation.scan)
withProjectAndFilter(project, filters, batchExec, !batchExec.supportsColumnar) :: Nil
val (runtimeFilters, postScanFilters) = filters.partition {
case _: DynamicPruning => true
case _ => false
}
val batchExec = BatchScanExec(relation.output, relation.scan, runtimeFilters)
withProjectAndFilter(project, postScanFilters, batchExec, !batchExec.supportsColumnar) :: Nil

case PhysicalOperation(p, f, r: StreamingDataSourceV2Relation)
if r.startOffset.isDefined && r.endOffset.isDefined =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation

/**
* Removes the filter nodes with dynamic pruning that were not pushed down to the scan.
Expand All @@ -42,6 +43,7 @@ object CleanupDynamicPruningFilters extends Rule[LogicalPlan] with PredicateHelp
_.containsAnyPattern(DYNAMIC_PRUNING_EXPRESSION, DYNAMIC_PRUNING_SUBQUERY)) {
// pass through anything that is pushed down into PhysicalOperation
case p @ PhysicalOperation(_, _, LogicalRelation(_: HadoopFsRelation, _, _, _)) => p
case p @ PhysicalOperation(_, _, _: DataSourceV2ScanRelation) => p
// remove any Filters with DynamicPruning that didn't get pushed down to PhysicalOperation.
case f @ Filter(condition, _) =>
val newCondition = condition.transformWithPruning(
Expand Down
Loading