diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala index e36c71ef4b1f7..f08d58703b893 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala @@ -31,6 +31,6 @@ class AvroScanBuilder ( options: CaseInsensitiveStringMap) extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { override def build(): Scan = { - AvroScan(sparkSession, fileIndex, dataSchema, readDataSchema(), readPartitionSchema(), options) + AvroScan(sparkSession, fileIndex, dataSchema, readDataSchema, readPartitionSchema, options) } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index c15f08d78741d..45ccc0085d374 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -420,10 +420,6 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister class KafkaScan(options: CaseInsensitiveStringMap) extends Scan { val includeHeaders = options.getBoolean(INCLUDE_HEADERS, false) - override def readSchema(): StructType = { - KafkaRecordToRowConverter.kafkaSchema(includeHeaders) - } - override def toBatch(): Batch = { val caseInsensitiveOptions = CaseInsensitiveMap(options.asScala.toMap) validateBatchOptions(caseInsensitiveOptions) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java index c1584a58c117f..caa4d4bc08021 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java @@ -20,13 +20,11 @@ import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.connector.read.streaming.ContinuousStream; import org.apache.spark.sql.connector.read.streaming.MicroBatchStream; -import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.connector.catalog.Table; import org.apache.spark.sql.connector.catalog.TableCapability; /** - * A logical representation of a data source scan. This interface is used to provide logical - * information, like what the actual read schema is. + * A logical representation of a data source scan. *
* This logical representation is shared between batch scan, micro-batch streaming scan and * continuous streaming scan. Data sources must implement the corresponding methods in this @@ -38,16 +36,10 @@ @Evolving public interface Scan { - /** - * Returns the actual schema of this data source scan, which may be different from the physical - * schema of the underlying storage, as column pruning or other optimizations may happen. - */ - StructType readSchema(); - /** * A description string of this scan, which may includes information like: what filters are * configured for this scan, what's the value of some important options like path, etc. The - * description doesn't need to include {@link #readSchema()}, as Spark already knows it. + * description doesn't need to include the schema, as Spark already knows it. *
* By default this returns the class name of the implementation. Please override it to provide a * meaningful description. diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownRequiredColumns.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownRequiredColumns.java index 97143686d3efc..f11035b24ce90 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownRequiredColumns.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownRequiredColumns.java @@ -29,14 +29,11 @@ public interface SupportsPushDownRequiredColumns extends ScanBuilder { /** - * Applies column pruning w.r.t. the given requiredSchema. + * Applies column pruning w.r.t. the given `requiredSchema`, and returns the pruned schema. * * Implementation should try its best to prune the unnecessary columns or nested fields, but it's * also OK to do the pruning partially, e.g., a data source may not be able to prune nested * fields, and only prune top-level columns. - * - * Note that, {@link Scan#readSchema()} implementation should take care of the column - * pruning applied here. */ - void pruneColumns(StructType requiredSchema); + StructType pruneColumns(StructType requiredSchema); } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala index 414f9d5834868..64002d665df9d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala @@ -84,8 +84,6 @@ class InMemoryTable( } class InMemoryBatchScan(data: Array[InputPartition]) extends Scan with Batch { - override def readSchema(): StructType = schema - override def toBatch: Batch = this override def planInputPartitions(): Array[InputPartition] = data diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index c8d29520bcfce..e53b0c2ad94a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -21,16 +21,17 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import org.apache.spark.sql.{AnalysisException, Strategy} -import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression, PredicateHelper, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{And, AttributeSet, Expression, PredicateHelper, SubqueryExpression} import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, AppendData, CreateTableAsSelect, CreateV2Table, DeleteFromTable, DescribeTable, DropTable, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Repartition, ReplaceTable, ReplaceTableAsSelect, SetCatalogAndNamespace, ShowNamespaces, ShowTables} import org.apache.spark.sql.connector.catalog.{StagingTableCatalog, TableCapability} -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns} +import org.apache.spark.sql.connector.read.{ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns} import 
org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream} import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec} import org.apache.spark.sql.sources +import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap object DataSourceV2Strategy extends Strategy with PredicateHelper { @@ -83,31 +84,24 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper { /** * Applies column pruning to the data source, w.r.t. the references of the given expressions. * - * @return the created `ScanConfig`(since column pruning is the last step of operator pushdown), - * and new output attributes after column pruning. + * @return the pruned schema if column pruning is applied. */ // TODO: nested column pruning. private def pruneColumns( scanBuilder: ScanBuilder, relation: DataSourceV2Relation, - exprs: Seq[Expression]): (Scan, Seq[AttributeReference]) = { + exprs: Seq[Expression]): Option[StructType] = { scanBuilder match { case r: SupportsPushDownRequiredColumns => val requiredColumns = AttributeSet(exprs.flatMap(_.references)) val neededOutput = relation.output.filter(requiredColumns.contains) if (neededOutput != relation.output) { - r.pruneColumns(neededOutput.toStructType) - val scan = r.build() - val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap - scan -> scan.readSchema().toAttributes.map { - // We have to keep the attribute id during transformation. - a => a.withExprId(nameToAttr(a.name).exprId) - } + Some(r.pruneColumns(neededOutput.toStructType)) } else { - r.build() -> relation.output + None } - case _ => scanBuilder.build() -> relation.output + case _ => None } } @@ -127,7 +121,17 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper { val (pushedFilters, postScanFiltersWithoutSubquery) = pushFilters(scanBuilder, normalizedFilters) val postScanFilters = postScanFiltersWithoutSubquery ++ withSubquery - val (scan, output) = pruneColumns(scanBuilder, relation, project ++ postScanFilters) + + val maybePrunedSchema = pruneColumns(scanBuilder, relation, project ++ postScanFilters) + val output = maybePrunedSchema match { + case Some(prunedSchema) => + val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap + prunedSchema.toAttributes.map { + a => a.withExprId(nameToAttr(a.name).exprId) + } + case _ => relation.output + } + logInfo( s""" |Pushing operators to ${relation.name} @@ -136,6 +140,7 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper { |Output: ${output.mkString(", ")} """.stripMargin) + val scan = scanBuilder.build() val batchExec = BatchScanExec(output, scan) val filterCondition = postScanFilters.reduceLeftOption(And) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala index 55104a2b21deb..9343773eb0000 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala @@ -133,9 +133,6 @@ abstract class FileScan( override def toBatch: Batch = this - override def readSchema(): StructType = - StructType(readDataSchema.fields ++ readPartitionSchema.fields) - // Returns 
whether the two given arrays of [[Filter]]s are equivalent. protected def equivalentFilters(a: Array[Filter], b: Array[Filter]): Boolean = { a.sortBy(_.hashCode()).sameElements(b.sortBy(_.hashCode())) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala index 21bc14c577bdc..d26a11b860d38 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala @@ -27,14 +27,21 @@ abstract class FileScanBuilder( dataSchema: StructType) extends ScanBuilder with SupportsPushDownRequiredColumns { private val partitionSchema = fileIndex.partitionSchema private val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis - protected var requiredSchema = StructType(dataSchema.fields ++ partitionSchema.fields) - override def pruneColumns(requiredSchema: StructType): Unit = { - this.requiredSchema = requiredSchema + protected var readDataSchema: StructType = dataSchema + protected var readPartitionSchema: StructType = partitionSchema + + override def pruneColumns(requiredSchema: StructType): StructType = { + val requiredNameSet = requiredSchema.fields.map( + PartitioningUtils.getColName(_, isCaseSensitive)).toSet + readDataSchema = getReadDataSchema(requiredNameSet) + readPartitionSchema = getReadPartitionSchema(requiredNameSet) + StructType(readDataSchema.fields ++ readPartitionSchema.fields) } - protected def readDataSchema(): StructType = { - val requiredNameSet = createRequiredNameSet() + private def getReadDataSchema(requiredNameSet: Set[String]): StructType = { + val partitionNameSet = partitionSchema.fields.map( + PartitioningUtils.getColName(_, isCaseSensitive)).toSet val fields = dataSchema.fields.filter { field => val colName = PartitioningUtils.getColName(field, isCaseSensitive) requiredNameSet.contains(colName) && !partitionNameSet.contains(colName) @@ -42,18 +49,11 @@ abstract class FileScanBuilder( StructType(fields) } - protected def readPartitionSchema(): StructType = { - val requiredNameSet = createRequiredNameSet() + private def getReadPartitionSchema(requiredNameSet: Set[String]): StructType = { val fields = partitionSchema.fields.filter { field => val colName = PartitioningUtils.getColName(field, isCaseSensitive) requiredNameSet.contains(colName) } StructType(fields) } - - private def createRequiredNameSet(): Set[String] = - requiredSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet - - private val partitionNameSet: Set[String] = - partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala index 8b486d0344506..98c82da43e838 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala @@ -33,6 +33,6 @@ case class CSVScanBuilder( extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { override def build(): Scan = { - CSVScan(sparkSession, fileIndex, dataSchema, readDataSchema(), readPartitionSchema(), options) + CSVScan(sparkSession, fileIndex, dataSchema, readDataSchema, readPartitionSchema, 
options) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScanBuilder.scala index be53b1b1676f1..601593531ac9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScanBuilder.scala @@ -31,6 +31,6 @@ class JsonScanBuilder ( options: CaseInsensitiveStringMap) extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { override def build(): Scan = { - JsonScan(sparkSession, fileIndex, dataSchema, readDataSchema(), readPartitionSchema(), options) + JsonScan(sparkSession, fileIndex, dataSchema, readDataSchema, readPartitionSchema, options) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala index 8d1d4ec45915b..4d934d78383d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala @@ -45,7 +45,7 @@ case class OrcScanBuilder( override def build(): Scan = { OrcScan(sparkSession, hadoopConf, fileIndex, dataSchema, - readDataSchema(), readPartitionSchema(), options, pushedFilters()) + readDataSchema, readPartitionSchema, options, pushedFilters()) } private var _pushedFilters: Array[Filter] = Array.empty diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala index 87db00077e798..fecda4d64dc9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala @@ -69,7 +69,7 @@ case class ParquetScanBuilder( override def pushedFilters(): Array[Filter] = pushedParquetFilters override def build(): Scan = { - ParquetScan(sparkSession, hadoopConf, fileIndex, dataSchema, readDataSchema(), - readPartitionSchema(), pushedParquetFilters, options) + ParquetScan(sparkSession, hadoopConf, fileIndex, dataSchema, readDataSchema, + readPartitionSchema, pushedParquetFilters, options) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScanBuilder.scala index b2b518c12b01a..57de1a5e7ecb8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScanBuilder.scala @@ -33,6 +33,6 @@ case class TextScanBuilder( extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { override def build(): Scan = { - TextScan(sparkSession, fileIndex, readDataSchema(), readPartitionSchema(), options) + TextScan(sparkSession, fileIndex, readDataSchema, readPartitionSchema, options) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 911a526428cf4..dc69872829db3 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -121,8 +121,6 @@ class MemoryStreamScanBuilder(stream: MemoryStreamBase[_]) extends ScanBuilder w override def description(): String = "MemoryStreamDataSource" - override def readSchema(): StructType = stream.fullSchema() - override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = { stream.asInstanceOf[MicroBatchStream] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala index 3f7b0377f1eab..e4732dc94403d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala @@ -93,8 +93,6 @@ class RateStreamTable( } override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = () => new Scan { - override def readSchema(): StructType = RateStreamProvider.SCHEMA - override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = new RateStreamMicroBatchStream( rowsPerSecond, rampUpTimeSeconds, numPartitions, options, checkpointLocation) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala index fae3cb765c0c9..607ddcc48202b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala @@ -85,8 +85,6 @@ class TextSocketTable(host: String, port: Int, numPartitions: Int, includeTimest } override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = () => new Scan { - override def readSchema(): StructType = schema() - override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = { new TextSocketMicroBatchStream(host, port, numPartitions) } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java index 9386ab51d64f0..ebbb830eef7e1 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java @@ -49,12 +49,8 @@ static class AdvancedScanBuilder implements ScanBuilder, Scan, private Filter[] filters = new Filter[0]; @Override - public void pruneColumns(StructType requiredSchema) { + public StructType pruneColumns(StructType requiredSchema) { this.requiredSchema = requiredSchema; - } - - @Override - public StructType readSchema() { return requiredSchema; } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSchemaRequiredDataSource.java index 2181887ae54e2..c90278e44dacd 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSchemaRequiredDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSchemaRequiredDataSource.java @@ -27,18 +27,6 @@ public class JavaSchemaRequiredDataSource implements 
TableProvider { class MyScanBuilder extends JavaSimpleScanBuilder { - - private StructType schema; - - MyScanBuilder(StructType schema) { - this.schema = schema; - } - - @Override - public StructType readSchema() { - return schema; - } - @Override public InputPartition[] planInputPartitions() { return new InputPartition[0]; @@ -56,7 +44,7 @@ public StructType schema() { @Override public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { - return new MyScanBuilder(schema); + return new MyScanBuilder(); } }; } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleScanBuilder.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleScanBuilder.java index 7cbba00420928..8812d0f92708e 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleScanBuilder.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleScanBuilder.java @@ -21,7 +21,6 @@ import org.apache.spark.sql.connector.read.PartitionReaderFactory; import org.apache.spark.sql.connector.read.Scan; import org.apache.spark.sql.connector.read.ScanBuilder; -import org.apache.spark.sql.types.StructType; abstract class JavaSimpleScanBuilder implements ScanBuilder, Scan, Batch { @@ -35,11 +34,6 @@ public Batch toBatch() { return this; } - @Override - public StructType readSchema() { - return new StructType().add("i", "int").add("j", "int"); - } - @Override public PartitionReaderFactory createReaderFactory() { return new JavaSimpleReaderFactory(); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index 138bbc3f04f64..ce318210045e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -431,8 +431,6 @@ abstract class SimpleScanBuilder extends ScanBuilder override def toBatch: Batch = this - override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createReaderFactory(): PartitionReaderFactory = SimpleReaderFactory } @@ -483,12 +481,11 @@ class AdvancedScanBuilder extends ScanBuilder var requiredSchema = new StructType().add("i", "int").add("j", "int") var filters = Array.empty[Filter] - override def pruneColumns(requiredSchema: StructType): Unit = { + override def pruneColumns(requiredSchema: StructType): StructType = { this.requiredSchema = requiredSchema + requiredSchema } - override def readSchema(): StructType = requiredSchema - override def pushFilters(filters: Array[Filter]): Array[Filter] = { val (supported, unsupported) = filters.partition { case GreaterThan("i", _: Int) => true @@ -562,8 +559,6 @@ class SchemaRequiredDataSource extends TableProvider { class MyScanBuilder(schema: StructType) extends SimpleScanBuilder { override def planInputPartitions(): Array[InputPartition] = Array.empty - - override def readSchema(): StructType = schema } override def getTable(options: CaseInsensitiveStringMap): Table = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala index 22d3750022c57..c1d13a28798fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala @@ -42,8 +42,6 @@ import 
org.apache.spark.util.SerializableConfiguration */ class SimpleWritableDataSource extends TableProvider with SessionConfigSupport { - private val tableSchema = new StructType().add("i", "long").add("j", "long") - override def keyPrefix: String = "simpleWritableDataSource" class MyScanBuilder(path: String, conf: Configuration) extends SimpleScanBuilder { @@ -66,8 +64,6 @@ class SimpleWritableDataSource extends TableProvider with SessionConfigSupport { val serializableConf = new SerializableConfiguration(conf) new CSVReaderFactory(serializableConf) } - - override def readSchema(): StructType = tableSchema } class MyWriteBuilder(path: String) extends WriteBuilder with SupportsTruncate { @@ -137,7 +133,9 @@ class SimpleWritableDataSource extends TableProvider with SessionConfigSupport { private val path = options.get("path") private val conf = SparkContext.getActive.get.hadoopConfiguration - override def schema(): StructType = tableSchema + override def schema(): StructType = { + new StructType().add("i", "long").add("j", "long") + } override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MyScanBuilder(new Path(path).toUri.toString, conf) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index e9d148c38e6cb..d5215c544f1fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -61,7 +61,6 @@ class FakeDataStream extends MicroBatchStream with ContinuousStream { class FakeScanBuilder extends ScanBuilder with Scan { override def build(): Scan = this - override def readSchema(): StructType = StructType(Seq()) override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = new FakeDataStream override def toContinuousStream(checkpointLocation: String): ContinuousStream = new FakeDataStream }
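
Note on the revised contract (not part of the patch): the sketch below shows how a custom connector would implement SupportsPushDownRequiredColumns after this change, modeled loosely on the AdvancedScanBuilder test source above. pruneColumns now returns the schema the scan will actually produce, and DataSourceV2Strategy uses that returned schema, rather than a Scan.readSchema() callback (which this patch removes), to rebuild the relation's output attributes while preserving their expression IDs. The class name ProjectedScanBuilder and its two-column schema are illustrative only.

import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan, ScanBuilder, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.types.StructType

// A minimal sketch, assuming a two-column source; names here are hypothetical.
class ProjectedScanBuilder extends ScanBuilder
  with Scan
  with Batch
  with SupportsPushDownRequiredColumns {

  // Full schema of the underlying source; replaced by the pruned schema once
  // Spark pushes down the required columns.
  private var prunedSchema = new StructType().add("i", "int").add("j", "int")

  // Revised contract: apply the pruning and hand the resulting schema back to
  // Spark directly, instead of exposing it later through Scan.readSchema().
  override def pruneColumns(requiredSchema: StructType): StructType = {
    prunedSchema = requiredSchema
    prunedSchema
  }

  override def build(): Scan = this
  override def toBatch: Batch = this

  // Reading is out of scope for this sketch.
  override def planInputPartitions(): Array[InputPartition] = Array.empty
  override def createReaderFactory(): PartitionReaderFactory =
    throw new UnsupportedOperationException("sketch only; no reader factory")
}

If no pruning is pushed down (the builder does not mix in SupportsPushDownRequiredColumns, or the required columns equal the relation output), pruneColumns is never consulted and the strategy falls back to the relation's original output, as shown in the DataSourceV2Strategy hunk above.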