diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTableCatalog.scala
index 57c83ec68a649..6824efd9880a9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTableCatalog.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTableCatalog.scala
@@ -34,8 +34,8 @@ class BasicInMemoryTableCatalog extends TableCatalog {
   protected val namespaces: util.Map[List[String], Map[String, String]] =
     new ConcurrentHashMap[List[String], Map[String, String]]()
 
-  protected val tables: util.Map[Identifier, InMemoryTable] =
-    new ConcurrentHashMap[Identifier, InMemoryTable]()
+  protected val tables: util.Map[Identifier, Table] =
+    new ConcurrentHashMap[Identifier, Table]()
 
   private val invalidatedTables: util.Set[Identifier] = ConcurrentHashMap.newKeySet()
 
diff --git a/sql/core/src/main/java/org/apache/spark/sql/connector/read/V1Scan.java b/sql/core/src/main/java/org/apache/spark/sql/connector/read/V1Scan.java
new file mode 100644
index 0000000000000..c9d7cb1bf80a3
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/connector/read/V1Scan.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.read;
+
+import org.apache.spark.annotation.Unstable;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.sources.BaseRelation;
+import org.apache.spark.sql.sources.TableScan;
+
+/**
+ * A trait that should be implemented by V1 DataSources that would like to leverage the DataSource
+ * V2 read code paths.
+ *
+ * This interface is designed to provide Spark DataSources time to migrate to DataSource V2 and
+ * will be removed in a future Spark release.
+ *
+ * @since 3.0.0
+ */
+@Unstable
+public interface V1Scan extends Scan {
+
+  /**
+   * Create a `BaseRelation` with `TableScan` that can scan data from DataSource v1 to RDD[Row].
+   *
+   * @since 3.0.0
+   */
+  <T extends BaseRelation & TableScan> T toV1TableScan(SQLContext context);
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/connector/write/V1WriteBuilder.scala b/sql/core/src/main/java/org/apache/spark/sql/connector/write/V1WriteBuilder.java
similarity index 73%
rename from sql/core/src/main/scala/org/apache/spark/sql/connector/write/V1WriteBuilder.scala
rename to sql/core/src/main/java/org/apache/spark/sql/connector/write/V1WriteBuilder.java
index e738ad1ede446..89b567b5231ac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/connector/write/V1WriteBuilder.scala
+++ b/sql/core/src/main/java/org/apache/spark/sql/connector/write/V1WriteBuilder.java
@@ -15,11 +15,10 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.connector.write
+package org.apache.spark.sql.connector.write;
 
-import org.apache.spark.annotation.{Experimental, Unstable}
-import org.apache.spark.sql.connector.write.streaming.StreamingWrite
-import org.apache.spark.sql.sources.InsertableRelation
+import org.apache.spark.annotation.Unstable;
+import org.apache.spark.sql.sources.InsertableRelation;
 
 /**
  * A trait that should be implemented by V1 DataSources that would like to leverage the DataSource
@@ -32,10 +31,8 @@
  *
  * @since 3.0.0
  */
-@Experimental
 @Unstable
-trait V1WriteBuilder extends WriteBuilder {
-
+public interface V1WriteBuilder extends WriteBuilder {
   /**
    * Creates an InsertableRelation that allows appending a DataFrame to a
    * destination (using data source-specific parameters). The insert method will only be
@@ -44,11 +41,5 @@
    *
    * @since 3.0.0
    */
-  def buildForV1Write(): InsertableRelation
-
-  // These methods cannot be implemented by a V1WriteBuilder. The super class will throw
-  // an Unsupported OperationException
-  override final def buildForBatch(): BatchWrite = super.buildForBatch()
-
-  override final def buildForStreaming(): StreamingWrite = super.buildForStreaming()
+  InsertableRelation buildForV1Write();
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
index dc7fb7741e7a7..895eeedd86b8b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.SparkContext
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -27,7 +26,7 @@ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Strategy
 import org.apache.spark.sql.internal.SQLConf
 
 class SparkPlanner(
-    val sparkContext: SparkContext,
+    val session: SparkSession,
     val conf: SQLConf,
     val experimentalMethods: ExperimentalMethods)
   extends SparkStrategies {
@@ -39,7 +38,7 @@ class SparkPlanner(
       extraPlanningStrategies ++ (
         LogicalQueryStageStrategy ::
         PythonEvals ::
-        DataSourceV2Strategy ::
+        new DataSourceV2Strategy(session) ::
         FileSourceStrategy ::
         DataSourceStrategy(conf) ::
         SpecialLimits ::
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 418401ac4e5cc..00ad4e0fe0c11 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -570,7 +570,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     }
   }
 
-  protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1)
+  protected lazy val singleRowRdd = session.sparkContext.parallelize(Seq(InternalRow()), 1)
 
   object InMemoryScans extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index d44cb11e28762..e3a0a0a6c34e5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -409,14 +409,7 @@ case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with
       relation: LogicalRelation,
       output: Seq[Attribute],
       rdd: RDD[Row]): RDD[InternalRow] = {
-    if (relation.relation.needConversion) {
-      val converters = RowEncoder(StructType.fromAttributes(output))
-      rdd.mapPartitions { iterator =>
-        iterator.map(converters.toRow)
-      }
-    } else {
-      rdd.asInstanceOf[RDD[InternalRow]]
-    }
+    DataSourceStrategy.toCatalystRDD(relation.relation, output, rdd)
   }
 
   /**
@@ -624,4 +617,21 @@ object DataSourceStrategy {
     (nonconvertiblePredicates ++ unhandledPredicates, pushedFilters, handledFilters)
   }
+
+  /**
+   * Convert RDD of Row into RDD of InternalRow with objects in catalyst types
+   */
+  private[sql] def toCatalystRDD(
+      relation: BaseRelation,
+      output: Seq[Attribute],
+      rdd: RDD[Row]): RDD[InternalRow] = {
+    if (relation.needConversion) {
+      val converters = RowEncoder(StructType.fromAttributes(output))
+      rdd.mapPartitions { iterator =>
+        iterator.map(converters.toRow)
+      }
+    } else {
+      rdd.asInstanceOf[RDD[InternalRow]]
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index b452b66e03813..568ffba4854cc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -19,40 +19,69 @@ package org.apache.spark.sql.execution.datasources.v2
 
 import scala.collection.JavaConverters._
 
-import org.apache.spark.sql.{AnalysisException, Strategy}
+import org.apache.spark.sql.{AnalysisException, SparkSession, Strategy}
 import org.apache.spark.sql.catalyst.analysis.{ResolvedNamespace, ResolvedTable}
-import org.apache.spark.sql.catalyst.expressions.{And, PredicateHelper, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{And, Expression, NamedExpression, PredicateHelper, SubqueryExpression}
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.connector.catalog.{StagingTableCatalog, SupportsNamespaces, TableCapability, TableCatalog, TableChange}
 import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream}
-import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}
+import org.apache.spark.sql.execution.{FilterExec, LeafExecNode, ProjectExec, RowDataSourceScanExec, SparkPlan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec}
+import org.apache.spark.sql.sources.{BaseRelation, TableScan}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
-object DataSourceV2Strategy extends Strategy with PredicateHelper {
+class DataSourceV2Strategy(session: SparkSession) extends Strategy with PredicateHelper {
 
   import DataSourceV2Implicits._
   import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
 
+  private def withProjectAndFilter(
+      project: Seq[NamedExpression],
+      filters: Seq[Expression],
+      scan: LeafExecNode,
+      needsUnsafeConversion: Boolean): SparkPlan = {
+    val filterCondition = filters.reduceLeftOption(And)
+    val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan)
+
+    if (withFilter.output != project || needsUnsafeConversion) {
+      ProjectExec(project, withFilter)
+    } else {
+      withFilter
+    }
+  }
+
   override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+    case PhysicalOperation(project, filters,
+        relation @ DataSourceV2ScanRelation(_, V1ScanWrapper(scan, translated, pushed), output)) =>
+      val v1Relation = scan.toV1TableScan[BaseRelation with TableScan](session.sqlContext)
+      if (v1Relation.schema != scan.readSchema()) {
+        throw new IllegalArgumentException(
+          "The fallback v1 relation reports inconsistent schema:\n" +
+            "Schema of v2 scan: " + scan.readSchema() + "\n" +
+            "Schema of v1 relation: " + v1Relation.schema)
+      }
+      val rdd = v1Relation.buildScan()
+      val unsafeRowRDD = DataSourceStrategy.toCatalystRDD(v1Relation, output, rdd)
+      val originalOutputNames = relation.table.schema().map(_.name)
+      val requiredColumnsIndex = output.map(_.name).map(originalOutputNames.indexOf)
+      val dsScan = RowDataSourceScanExec(
+        output,
+        requiredColumnsIndex,
+        translated.toSet,
+        pushed.toSet,
+        unsafeRowRDD,
+        v1Relation,
+        tableIdentifier = None)
+      withProjectAndFilter(project, filters, dsScan, needsUnsafeConversion = false) :: Nil
+
     case PhysicalOperation(project, filters, relation: DataSourceV2ScanRelation) =>
       // projection and filters were already pushed down in the optimizer.
       // this uses PhysicalOperation to get the projection and ensure that if the batch scan does
       // not support columnar, a projection is added to convert the rows to UnsafeRow.
       val batchExec = BatchScanExec(relation.output, relation.scan)
-
-      val filterCondition = filters.reduceLeftOption(And)
-      val withFilter = filterCondition.map(FilterExec(_, batchExec)).getOrElse(batchExec)
-
-      val withProjection = if (withFilter.output != project || !batchExec.supportsColumnar) {
-        ProjectExec(project, withFilter)
-      } else {
-        withFilter
-      }
-
-      withProjection :: Nil
+      withProjectAndFilter(project, filters, batchExec, !batchExec.supportsColumnar) :: Nil
 
     case r: StreamingDataSourceV2Relation if r.startOffset.isDefined && r.endOffset.isDefined =>
       val microBatchStream = r.stream.asInstanceOf[MicroBatchStream]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
index 09a8a7ebb6ddc..33338b06565c9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
@@ -34,7 +34,7 @@ object PushDownUtils extends PredicateHelper {
    */
   def pushFilters(
       scanBuilder: ScanBuilder,
-      filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
+      filters: Seq[Expression]): (Seq[sources.Filter], Seq[Expression]) = {
     scanBuilder match {
       case r: SupportsPushDownFilters =>
         // A map from translated data source leaf node filters to original catalyst filter
@@ -62,11 +62,7 @@ object PushDownUtils extends PredicateHelper {
         val postScanFilters = r.pushFilters(translatedFilters.toArray).map { filter =>
           DataSourceStrategy.rebuildExpressionFromFilter(filter, translatedFilterToExpr)
         }
-        // The filters which are marked as pushed to this data source
-        val pushedFilters = r.pushedFilters().map { filter =>
-          DataSourceStrategy.rebuildExpressionFromFilter(filter, translatedFilterToExpr)
-        }
-        (pushedFilters, untranslatableExprs ++ postScanFilters)
+        (r.pushedFilters(), untranslatableExprs ++ postScanFilters)
       case _ => (Nil, filters)
     }
   }
 
@@ -75,7 +71,7 @@ object PushDownUtils extends PredicateHelper {
   /**
    * Applies column pruning to the data source, w.r.t. the references of the given expressions.
    *
-   * @return the created `ScanConfig`(since column pruning is the last step of operator pushdown),
+   * @return the `Scan` instance (since column pruning is the last step of operator pushdown),
   *          and new output attributes after column pruning.
    */
   def pruneColumns(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index 92acd3ba8d902..59089fa6b77e9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -21,7 +21,10 @@ import org.apache.spark.sql.catalyst.expressions.{And, Expression, NamedExpressi
 import org.apache.spark.sql.catalyst.planning.ScanOperation
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.connector.read.{Scan, V1Scan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
+import org.apache.spark.sql.sources
+import org.apache.spark.sql.types.StructType
 
 object V2ScanRelationPushDown extends Rule[LogicalPlan] {
   import DataSourceV2Implicits._
@@ -54,7 +57,14 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] {
          |Output: ${output.mkString(", ")}
        """.stripMargin)
 
-      val scanRelation = DataSourceV2ScanRelation(relation.table, scan, output)
+      val wrappedScan = scan match {
+        case v1: V1Scan =>
+          val translated = filters.flatMap(DataSourceStrategy.translateFilter)
+          V1ScanWrapper(v1, translated, pushedFilters)
+        case _ => scan
+      }
+
+      val scanRelation = DataSourceV2ScanRelation(relation.table, wrappedScan, output)
 
       val projectionOverSchema = ProjectionOverSchema(output.toStructType)
       val projectionFunc = (expr: Expression) => expr transformDown {
@@ -77,3 +87,12 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] {
       withProjection
   }
 }
+
+// A wrapper for v1 scan to carry the translated filters and the handled ones. This is required by
+// the physical v1 scan node.
+case class V1ScanWrapper(
+    v1Scan: V1Scan,
+    translatedFilters: Seq[sources.Filter],
+    handledFilters: Seq[sources.Filter]) extends Scan {
+  override def readSchema(): StructType = v1Scan.readSchema()
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index b8e18b89b54bc..bf80a0b1c167a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -50,7 +50,7 @@ class IncrementalExecution(
 
   // Modified planner with stateful operations.
   override val planner: SparkPlanner = new SparkPlanner(
-      sparkSession.sparkContext,
+      sparkSession,
       sparkSession.sessionState.conf,
       sparkSession.sessionState.experimentalMethods) {
     override def strategies: Seq[Strategy] =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 91c693ab34c8e..eb658e2d8850e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -250,7 +250,7 @@ abstract class BaseSessionStateBuilder(
    * Note: this depends on the `conf` and `experimentalMethods` fields.
    */
   protected def planner: SparkPlanner = {
-    new SparkPlanner(session.sparkContext, conf, experimentalMethods) {
+    new SparkPlanner(session, conf, experimentalMethods) {
       override def extraPlanningStrategies: Seq[Strategy] =
         super.extraPlanningStrategies ++ customPlanningStrategies
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala
new file mode 100644
index 0000000000000..8e2c63417b377
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector
+
+import java.util
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, QueryTest, Row, SparkSession, SQLContext}
+import org.apache.spark.sql.connector.catalog.{Identifier, SupportsRead, Table, TableCapability, TableProvider}
+import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns, V1Scan}
+import org.apache.spark.sql.execution.RowDataSourceScanExec
+import org.apache.spark.sql.sources.{BaseRelation, Filter, GreaterThan, TableScan}
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+abstract class V1ReadFallbackSuite extends QueryTest with SharedSparkSession {
+  protected def baseTableScan(): DataFrame
+
+  test("full scan") {
+    val df = baseTableScan()
+    val v1Scan = df.queryExecution.executedPlan.collect {
+      case s: RowDataSourceScanExec => s
+    }
+    assert(v1Scan.length == 1)
+    checkAnswer(df, Seq(Row(1, 10), Row(2, 20), Row(3, 30)))
+  }
+
+  test("column pruning") {
+    val df = baseTableScan().select("i")
+    val v1Scan = df.queryExecution.executedPlan.collect {
+      case s: RowDataSourceScanExec => s
+    }
+    assert(v1Scan.length == 1)
+    assert(v1Scan.head.output.map(_.name) == Seq("i"))
+    checkAnswer(df, Seq(Row(1), Row(2), Row(3)))
+  }
+
+  test("filter push down") {
+    val df = baseTableScan().filter("i > 1 and j < 30")
+    val v1Scan = df.queryExecution.executedPlan.collect {
+      case s: RowDataSourceScanExec => s
+    }
+    assert(v1Scan.length == 1)
+    // `j < 30` can't be pushed.
+    assert(v1Scan.head.handledFilters.size == 1)
+    checkAnswer(df, Seq(Row(2, 20)))
+  }
+
+  test("filter push down + column pruning") {
+    val df = baseTableScan().filter("i > 1").select("i")
+    val v1Scan = df.queryExecution.executedPlan.collect {
+      case s: RowDataSourceScanExec => s
+    }
+    assert(v1Scan.length == 1)
+    assert(v1Scan.head.output.map(_.name) == Seq("i"))
+    assert(v1Scan.head.handledFilters.size == 1)
+    checkAnswer(df, Seq(Row(2), Row(3)))
+  }
+}
+
+class V1ReadFallbackWithDataFrameReaderSuite extends V1ReadFallbackSuite {
+  override protected def baseTableScan(): DataFrame = {
+    spark.read.format(classOf[V1ReadFallbackTableProvider].getName).load()
+  }
+}
+
+class V1ReadFallbackWithCatalogSuite extends V1ReadFallbackSuite {
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    spark.conf.set("spark.sql.catalog.read_fallback", classOf[V1ReadFallbackCatalog].getName)
+    sql("CREATE TABLE read_fallback.tbl(i int, j int) USING foo")
+  }
+
+  override def afterAll(): Unit = {
+    spark.conf.unset("spark.sql.catalog.read_fallback")
+    super.afterAll()
+  }
+
+  override protected def baseTableScan(): DataFrame = {
+    spark.table("read_fallback.tbl")
+  }
+}
+
+class V1ReadFallbackCatalog extends BasicInMemoryTableCatalog {
+  override def createTable(
+      ident: Identifier,
+      schema: StructType,
+      partitions: Array[Transform],
+      properties: util.Map[String, String]): Table = {
+    // To simplify the test implementation, only support fixed schema.
+    if (schema != V1ReadFallbackCatalog.schema || partitions.nonEmpty) {
+      throw new UnsupportedOperationException
+    }
+    val table = new TableWithV1ReadFallback(ident.toString)
+    tables.put(ident, table)
+    table
+  }
+}
+
+object V1ReadFallbackCatalog {
+  val schema = new StructType().add("i", "int").add("j", "int")
+}
+
+class V1ReadFallbackTableProvider extends TableProvider {
+  override def getTable(options: CaseInsensitiveStringMap): Table = {
+    new TableWithV1ReadFallback("v1-read-fallback")
+  }
+}
+
+class TableWithV1ReadFallback(override val name: String) extends Table with SupportsRead {
+
+  override def schema(): StructType = V1ReadFallbackCatalog.schema
+
+  override def capabilities(): util.Set[TableCapability] = {
+    Set(TableCapability.BATCH_READ).asJava
+  }
+
+  override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
+    new V1ReadFallbackScanBuilder
+  }
+
+  private class V1ReadFallbackScanBuilder extends ScanBuilder
+    with SupportsPushDownRequiredColumns with SupportsPushDownFilters {
+
+    private var requiredSchema: StructType = schema()
+    override def pruneColumns(requiredSchema: StructType): Unit = {
+      this.requiredSchema = requiredSchema
+    }
+
+    private var filters: Array[Filter] = Array.empty
+    override def pushFilters(filters: Array[Filter]): Array[Filter] = {
+      val (supported, unsupported) = filters.partition {
+        case GreaterThan("i", _: Int) => true
+        case _ => false
+      }
+      this.filters = supported
+      unsupported
+    }
+    override def pushedFilters(): Array[Filter] = filters
+
+    override def build(): Scan = new V1ReadFallbackScan(requiredSchema, filters)
+  }
+
+  private class V1ReadFallbackScan(
+      requiredSchema: StructType,
+      filters: Array[Filter]) extends V1Scan {
+    override def readSchema(): StructType = requiredSchema
+
+    override def toV1TableScan[T <: BaseRelation with TableScan](context: SQLContext): T = {
+      new V1TableScan(context, requiredSchema, filters).asInstanceOf[T]
+    }
+  }
+}
+
+class V1TableScan(
+    context: SQLContext,
+    requiredSchema: StructType,
+    filters: Array[Filter]) extends BaseRelation with TableScan {
+  override def sqlContext: SQLContext = context
+  override def schema: StructType = requiredSchema
+  override def buildScan(): RDD[Row] = {
+    val lowerBound = if (filters.isEmpty) {
+      0
+    } else {
+      filters.collect { case GreaterThan("i", v: Int) => v }.max
+    }
+    val data = Seq(Row(1, 10), Row(2, 20), Row(3, 30)).filter(_.getInt(0) > lowerBound)
+    val result = if (requiredSchema.length == 2) {
+      data
+    } else if (requiredSchema.map(_.name) == Seq("i")) {
+      data.map(row => Row(row.getInt(0)))
+    } else if (requiredSchema.map(_.name) == Seq("j")) {
+      data.map(row => Row(row.getInt(1)))
+    } else {
+      throw new UnsupportedOperationException
+    }
+
+    SparkSession.active.sparkContext.makeRDD(result)
+  }
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
index 3df77fec20993..de21a13e6edb8 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
@@ -97,7 +97,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session
    * Planner that takes into account Hive-specific strategies.
    */
   override protected def planner: SparkPlanner = {
-    new SparkPlanner(session.sparkContext, conf, experimentalMethods) with HiveStrategies {
+    new SparkPlanner(session, conf, experimentalMethods) with HiveStrategies {
       override val sparkSession: SparkSession = session
 
       override def extraPlanningStrategies: Seq[Strategy] =
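
For connector authors following this change, here is a minimal read-side sketch of the new fallback hook. The table and class names (`LegacyBackedTable`, `makeRelation`) are hypothetical; the `toV1TableScan` signature and the `BATCH_READ` capability mirror the `V1Scan` interface and the `TableWithV1ReadFallback` test table in this diff:

import java.util

import scala.collection.JavaConverters._

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, V1Scan}
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

// Hypothetical v2 table whose reads are still served by an existing v1 relation.
class LegacyBackedTable(
    tableSchema: StructType,
    makeRelation: SQLContext => BaseRelation with TableScan) extends Table with SupportsRead {

  override def name(): String = "legacy_backed"
  override def schema(): StructType = tableSchema
  override def capabilities(): util.Set[TableCapability] = Set(TableCapability.BATCH_READ).asJava

  override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new ScanBuilder {
    override def build(): Scan = new V1Scan {
      override def readSchema(): StructType = tableSchema

      // DataSourceV2Strategy calls this and plans the result as a RowDataSourceScanExec.
      override def toV1TableScan[T <: BaseRelation with TableScan](context: SQLContext): T =
        makeRelation(context).asInstanceOf[T]
    }
  }
}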
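The write-side counterpart is analogous: after this change `V1WriteBuilder` is a plain Java interface, so a builder only needs to produce an `InsertableRelation`. A minimal sketch, assuming a hypothetical `legacyInsert` callback that performs the actual V1 append:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.connector.write.V1WriteBuilder
import org.apache.spark.sql.sources.InsertableRelation

// Hypothetical builder that appends through an existing v1 code path.
class LegacyBackedWriteBuilder(legacyInsert: DataFrame => Unit) extends V1WriteBuilder {
  override def buildForV1Write(): InsertableRelation = new InsertableRelation {
    // Per the interface's javadoc, this fallback is only used for appends.
    override def insert(data: DataFrame, overwrite: Boolean): Unit = legacyInsert(data)
  }
}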
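Wiring such a table up through a `TableProvider` (as `V1ReadFallbackTableProvider` does in the new suite) lets the fallback be exercised end to end; the resulting plan contains a `RowDataSourceScanExec` rather than a `BatchScanExec`, which is exactly what the suite asserts. A hedged usage sketch, with `com.example.LegacyBackedTableProvider` standing in for a real provider class name:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.RowDataSourceScanExec

val spark = SparkSession.builder().master("local[*]").getOrCreate()
// Hypothetical provider that returns the LegacyBackedTable sketched above.
val df = spark.read.format("com.example.LegacyBackedTableProvider").load()
val v1Scans = df.queryExecution.executedPlan.collect { case s: RowDataSourceScanExec => s }
assert(v1Scans.length == 1)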