@@ -31,6 +31,6 @@ class AvroScanBuilder (
options: CaseInsensitiveStringMap)
extends FileScanBuilder(sparkSession, fileIndex, dataSchema) {
override def build(): Scan = {
AvroScan(sparkSession, fileIndex, dataSchema, readDataSchema(), readPartitionSchema(), options)
AvroScan(sparkSession, fileIndex, dataSchema, readDataSchema, readPartitionSchema, options)
}
}
@@ -420,10 +420,6 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
class KafkaScan(options: CaseInsensitiveStringMap) extends Scan {
val includeHeaders = options.getBoolean(INCLUDE_HEADERS, false)

override def readSchema(): StructType = {
KafkaRecordToRowConverter.kafkaSchema(includeHeaders)
}

override def toBatch(): Batch = {
val caseInsensitiveOptions = CaseInsensitiveMap(options.asScala.toMap)
validateBatchOptions(caseInsensitiveOptions)
@@ -20,13 +20,11 @@
import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.read.streaming.ContinuousStream;
import org.apache.spark.sql.connector.read.streaming.MicroBatchStream;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.connector.catalog.Table;
import org.apache.spark.sql.connector.catalog.TableCapability;

/**
* A logical representation of a data source scan. This interface is used to provide logical
* information, like what the actual read schema is.
* A logical representation of a data source scan.
* <p>
* This logical representation is shared between batch scan, micro-batch streaming scan and
* continuous streaming scan. Data sources must implement the corresponding methods in this
@@ -38,16 +36,10 @@
@Evolving
public interface Scan {

/**
* Returns the actual schema of this data source scan, which may be different from the physical
* schema of the underlying storage, as column pruning or other optimizations may happen.
*/
StructType readSchema();

/**
* A description string of this scan, which may include information like: what filters are
* configured for this scan, what's the value of some important options like path, etc. The
* description doesn't need to include {@link #readSchema()}, as Spark already knows it.
* description doesn't need to include the schema, as Spark already knows it.
Contributor:

I would rather not remove readSchema. The scan should be self-describing back to Spark, and the read schema is a key piece of information. In fact, I'd like to add more methods to access other things, like pushed filters and residual filters.

Contributor Author:

> ... like pushed filters and residual filters.

Hmm, are these already available from SupportsPushDownFilters?

Contributor:

They are, but pushedFilters should also be available from the Scan.

Contributor Author:

I'm not sure this is the right direction. If we add more pushdown in the future (limit, aggregate, etc.), are we going to add methods to Scan every time?

* <p>
* By default this returns the class name of the implementation. Please override it to provide a
* meaningful description.
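For orientation, here is a minimal sketch of a source written against the API as changed in this PR: the builder reports the (possibly pruned) schema from pruneColumns, and the Scan it builds no longer exposes readSchema(). The class names and the two hard-coded columns are invented for illustration, and the sketch only compiles against this PR's branch.

```scala
import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan, ScanBuilder, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.types.StructType

// Illustration only: a two-column source under the proposed contract.
class DummyScanBuilder extends ScanBuilder with SupportsPushDownRequiredColumns {
  private var prunedSchema = new StructType().add("i", "int").add("j", "int")

  // The schema is reported back to Spark here, not from the Scan.
  override def pruneColumns(requiredSchema: StructType): StructType = {
    prunedSchema = requiredSchema
    prunedSchema
  }

  override def build(): Scan = new DummyScan
}

class DummyScan extends Scan with Batch {
  override def toBatch: Batch = this

  override def planInputPartitions(): Array[InputPartition] = Array.empty

  override def createReaderFactory(): PartitionReaderFactory =
    throw new UnsupportedOperationException("illustration only")
}
```

Whether Scan should still describe more of itself back to Spark (pushed filters, residual filters) is exactly the open question in the review thread above.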
@@ -29,14 +29,11 @@
public interface SupportsPushDownRequiredColumns extends ScanBuilder {

/**
* Applies column pruning w.r.t. the given requiredSchema.
* Applies column pruning w.r.t. the given `requiredSchema`, and returns the pruned schema.
*
* Implementation should try its best to prune the unnecessary columns or nested fields, but it's
* also OK to do the pruning partially, e.g., a data source may not be able to prune nested
* fields, and only prune top-level columns.
*
* Note that, {@link Scan#readSchema()} implementation should take care of the column
* pruning applied here.
*/
void pruneColumns(StructType requiredSchema);
StructType pruneColumns(StructType requiredSchema);
}
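To make the "partial pruning" clause above concrete, here is a hedged sketch of a builder that can only drop whole top-level columns; under the new contract it returns the schema it actually applied, which may be wider than the one Spark requested. The class name and the fullSchema parameter are invented, and build() is left unimplemented.

```scala
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.types.StructType

// Illustration only: a source that cannot prune nested fields.
class TopLevelOnlyScanBuilder(fullSchema: StructType)
  extends ScanBuilder with SupportsPushDownRequiredColumns {

  private var appliedSchema: StructType = fullSchema

  override def pruneColumns(requiredSchema: StructType): StructType = {
    val requestedNames = requiredSchema.fieldNames.toSet
    // Keep the source's own definition of each requested top-level column,
    // including any nested fields Spark did not ask for.
    appliedSchema = StructType(fullSchema.fields.filter(f => requestedNames.contains(f.name)))
    appliedSchema // report what was actually applied
  }

  override def build(): Scan = ??? // omitted in this sketch
}
```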
@@ -84,8 +84,6 @@ class InMemoryTable(
}

class InMemoryBatchScan(data: Array[InputPartition]) extends Scan with Batch {
override def readSchema(): StructType = schema

override def toBatch: Batch = this

override def planInputPartitions(): Array[InputPartition] = data
@@ -21,16 +21,17 @@ import scala.collection.JavaConverters._
import scala.collection.mutable

import org.apache.spark.sql.{AnalysisException, Strategy}
import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression, PredicateHelper, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.{And, AttributeSet, Expression, PredicateHelper, SubqueryExpression}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, AppendData, CreateTableAsSelect, CreateV2Table, DeleteFromTable, DescribeTable, DropTable, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Repartition, ReplaceTable, ReplaceTableAsSelect, SetCatalogAndNamespace, ShowNamespaces, ShowTables}
import org.apache.spark.sql.connector.catalog.{StagingTableCatalog, TableCapability}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.connector.read.{ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream}
import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec}
import org.apache.spark.sql.sources
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

object DataSourceV2Strategy extends Strategy with PredicateHelper {
@@ -83,31 +84,24 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper {
/**
* Applies column pruning to the data source, w.r.t. the references of the given expressions.
*
* @return the created `ScanConfig`(since column pruning is the last step of operator pushdown),
* and new output attributes after column pruning.
* @return the pruned schema if column pruning is applied.
*/
// TODO: nested column pruning.
private def pruneColumns(
scanBuilder: ScanBuilder,
relation: DataSourceV2Relation,
exprs: Seq[Expression]): (Scan, Seq[AttributeReference]) = {
exprs: Seq[Expression]): Option[StructType] = {
scanBuilder match {
case r: SupportsPushDownRequiredColumns =>
val requiredColumns = AttributeSet(exprs.flatMap(_.references))
val neededOutput = relation.output.filter(requiredColumns.contains)
if (neededOutput != relation.output) {
r.pruneColumns(neededOutput.toStructType)
val scan = r.build()
val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap
scan -> scan.readSchema().toAttributes.map {
// We have to keep the attribute id during transformation.
a => a.withExprId(nameToAttr(a.name).exprId)
}
Some(r.pruneColumns(neededOutput.toStructType))
} else {
r.build() -> relation.output
None
}

case _ => scanBuilder.build() -> relation.output
case _ => None
}
}

@@ -127,7 +121,17 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper {
val (pushedFilters, postScanFiltersWithoutSubquery) =
pushFilters(scanBuilder, normalizedFilters)
val postScanFilters = postScanFiltersWithoutSubquery ++ withSubquery
val (scan, output) = pruneColumns(scanBuilder, relation, project ++ postScanFilters)

val maybePrunedSchema = pruneColumns(scanBuilder, relation, project ++ postScanFilters)
val output = maybePrunedSchema match {
case Some(prunedSchema) =>
val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap
prunedSchema.toAttributes.map {
a => a.withExprId(nameToAttr(a.name).exprId)
}
case _ => relation.output
}

logInfo(
s"""
|Pushing operators to ${relation.name}
@@ -136,6 +140,7 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper {
|Output: ${output.mkString(", ")}
""".stripMargin)

val scan = scanBuilder.build()
val batchExec = BatchScanExec(output, scan)

val filterCondition = postScanFilters.reduceLeftOption(And)
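As a standalone, REPL-style sketch of the exprId-preserving step introduced above (the column names and types are made up; it assumes the same catalyst classes this file already uses):

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.{IntegerType, StructType}

// Hypothetical relation output: i#1, j#2 (exprIds are assigned at creation).
val relationOutput = Seq(
  AttributeReference("i", IntegerType)(),
  AttributeReference("j", IntegerType)())

// Suppose pruneColumns() reported that only `i` survives.
val prunedSchema = new StructType().add("i", IntegerType)

// Rebuild the output from the pruned schema, but keep the original exprIds so
// that references elsewhere in the query plan still resolve to the same attribute.
val nameToAttr = relationOutput.map(a => a.name -> a).toMap
val output = prunedSchema.toAttributes.map(a => a.withExprId(nameToAttr(a.name).exprId))
// output: a single AttributeReference "i" sharing relationOutput.head's exprId
```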
@@ -133,9 +133,6 @@ abstract class FileScan(

override def toBatch: Batch = this

override def readSchema(): StructType =
StructType(readDataSchema.fields ++ readPartitionSchema.fields)

// Returns whether the two given arrays of [[Filter]]s are equivalent.
protected def equivalentFilters(a: Array[Filter], b: Array[Filter]): Boolean = {
a.sortBy(_.hashCode()).sameElements(b.sortBy(_.hashCode()))
@@ -27,33 +27,33 @@ abstract class FileScanBuilder(
dataSchema: StructType) extends ScanBuilder with SupportsPushDownRequiredColumns {
private val partitionSchema = fileIndex.partitionSchema
private val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
protected var requiredSchema = StructType(dataSchema.fields ++ partitionSchema.fields)

override def pruneColumns(requiredSchema: StructType): Unit = {
this.requiredSchema = requiredSchema
protected var readDataSchema: StructType = dataSchema
protected var readPartitionSchema: StructType = partitionSchema

override def pruneColumns(requiredSchema: StructType): StructType = {
val requiredNameSet = requiredSchema.fields.map(
PartitioningUtils.getColName(_, isCaseSensitive)).toSet
readDataSchema = getReadDataSchema(requiredNameSet)
readPartitionSchema = getReadPartitionSchema(requiredNameSet)
StructType(readDataSchema.fields ++ readPartitionSchema.fields)
}

protected def readDataSchema(): StructType = {
val requiredNameSet = createRequiredNameSet()
private def getReadDataSchema(requiredNameSet: Set[String]): StructType = {
val partitionNameSet = partitionSchema.fields.map(
PartitioningUtils.getColName(_, isCaseSensitive)).toSet
val fields = dataSchema.fields.filter { field =>
val colName = PartitioningUtils.getColName(field, isCaseSensitive)
requiredNameSet.contains(colName) && !partitionNameSet.contains(colName)
}
StructType(fields)
}

protected def readPartitionSchema(): StructType = {
val requiredNameSet = createRequiredNameSet()
private def getReadPartitionSchema(requiredNameSet: Set[String]): StructType = {
val fields = partitionSchema.fields.filter { field =>
val colName = PartitioningUtils.getColName(field, isCaseSensitive)
requiredNameSet.contains(colName)
}
StructType(fields)
}

private def createRequiredNameSet(): Set[String] =
requiredSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet

private val partitionNameSet: Set[String] =
partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet
}
@@ -33,6 +33,6 @@ case class CSVScanBuilder(
extends FileScanBuilder(sparkSession, fileIndex, dataSchema) {

override def build(): Scan = {
CSVScan(sparkSession, fileIndex, dataSchema, readDataSchema(), readPartitionSchema(), options)
CSVScan(sparkSession, fileIndex, dataSchema, readDataSchema, readPartitionSchema, options)
}
}
@@ -31,6 +31,6 @@ class JsonScanBuilder (
options: CaseInsensitiveStringMap)
extends FileScanBuilder(sparkSession, fileIndex, dataSchema) {
override def build(): Scan = {
JsonScan(sparkSession, fileIndex, dataSchema, readDataSchema(), readPartitionSchema(), options)
JsonScan(sparkSession, fileIndex, dataSchema, readDataSchema, readPartitionSchema, options)
}
}
@@ -45,7 +45,7 @@ case class OrcScanBuilder(

override def build(): Scan = {
OrcScan(sparkSession, hadoopConf, fileIndex, dataSchema,
readDataSchema(), readPartitionSchema(), options, pushedFilters())
readDataSchema, readPartitionSchema, options, pushedFilters())
}

private var _pushedFilters: Array[Filter] = Array.empty
@@ -69,7 +69,7 @@ case class ParquetScanBuilder(
override def pushedFilters(): Array[Filter] = pushedParquetFilters

override def build(): Scan = {
ParquetScan(sparkSession, hadoopConf, fileIndex, dataSchema, readDataSchema(),
readPartitionSchema(), pushedParquetFilters, options)
ParquetScan(sparkSession, hadoopConf, fileIndex, dataSchema, readDataSchema,
readPartitionSchema, pushedParquetFilters, options)
}
}
@@ -33,6 +33,6 @@ case class TextScanBuilder(
extends FileScanBuilder(sparkSession, fileIndex, dataSchema) {

override def build(): Scan = {
TextScan(sparkSession, fileIndex, readDataSchema(), readPartitionSchema(), options)
TextScan(sparkSession, fileIndex, readDataSchema, readPartitionSchema, options)
}
}
@@ -121,8 +121,6 @@ class MemoryStreamScanBuilder(stream: MemoryStreamBase[_]) extends ScanBuilder w

override def description(): String = "MemoryStreamDataSource"

override def readSchema(): StructType = stream.fullSchema()

override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = {
stream.asInstanceOf[MicroBatchStream]
}
@@ -93,8 +93,6 @@ class RateStreamTable(
}

override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = () => new Scan {
override def readSchema(): StructType = RateStreamProvider.SCHEMA

override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream =
new RateStreamMicroBatchStream(
rowsPerSecond, rampUpTimeSeconds, numPartitions, options, checkpointLocation)
@@ -85,8 +85,6 @@ class TextSocketTable(host: String, port: Int, numPartitions: Int, includeTimest
}

override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = () => new Scan {
override def readSchema(): StructType = schema()

override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = {
new TextSocketMicroBatchStream(host, port, numPartitions)
}
@@ -49,12 +49,8 @@ static class AdvancedScanBuilder implements ScanBuilder, Scan,
private Filter[] filters = new Filter[0];

@Override
public void pruneColumns(StructType requiredSchema) {
public StructType pruneColumns(StructType requiredSchema) {
this.requiredSchema = requiredSchema;
}

@Override
public StructType readSchema() {
return requiredSchema;
}

@@ -27,18 +27,6 @@
public class JavaSchemaRequiredDataSource implements TableProvider {

class MyScanBuilder extends JavaSimpleScanBuilder {

private StructType schema;

MyScanBuilder(StructType schema) {
this.schema = schema;
}

@Override
public StructType readSchema() {
return schema;
}

@Override
public InputPartition[] planInputPartitions() {
return new InputPartition[0];
@@ -56,7 +44,7 @@ public StructType schema() {

@Override
public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) {
return new MyScanBuilder(schema);
return new MyScanBuilder();
}
};
}
@@ -21,7 +21,6 @@
import org.apache.spark.sql.connector.read.PartitionReaderFactory;
import org.apache.spark.sql.connector.read.Scan;
import org.apache.spark.sql.connector.read.ScanBuilder;
import org.apache.spark.sql.types.StructType;

abstract class JavaSimpleScanBuilder implements ScanBuilder, Scan, Batch {

@@ -35,11 +34,6 @@ public Batch toBatch() {
return this;
}

@Override
public StructType readSchema() {
return new StructType().add("i", "int").add("j", "int");
}

@Override
public PartitionReaderFactory createReaderFactory() {
return new JavaSimpleReaderFactory();
@@ -431,8 +431,6 @@ abstract class SimpleScanBuilder extends ScanBuilder

override def toBatch: Batch = this

override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int")

override def createReaderFactory(): PartitionReaderFactory = SimpleReaderFactory
}

@@ -483,12 +481,11 @@ class AdvancedScanBuilder extends ScanBuilder
var requiredSchema = new StructType().add("i", "int").add("j", "int")
var filters = Array.empty[Filter]

override def pruneColumns(requiredSchema: StructType): Unit = {
override def pruneColumns(requiredSchema: StructType): StructType = {
this.requiredSchema = requiredSchema
requiredSchema
}

override def readSchema(): StructType = requiredSchema

override def pushFilters(filters: Array[Filter]): Array[Filter] = {
val (supported, unsupported) = filters.partition {
case GreaterThan("i", _: Int) => true
@@ -562,8 +559,6 @@ class SchemaRequiredDataSource extends TableProvider {

class MyScanBuilder(schema: StructType) extends SimpleScanBuilder {
override def planInputPartitions(): Array[InputPartition] = Array.empty

override def readSchema(): StructType = schema
}

override def getTable(options: CaseInsensitiveStringMap): Table = {