diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala similarity index 93% rename from external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala rename to external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala index 1c1d26a901b4..3ae9bd3399f4 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala @@ -30,17 +30,16 @@ import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset, SimpleStreamingScanConfig, SimpleStreamingScanConfigBuilder} -import org.apache.spark.sql.execution.streaming.sources.RateControlMicroBatchReadSupport +import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset} +import org.apache.spark.sql.execution.streaming.sources.RateControlMicroBatchStream import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchStream, Offset} import org.apache.spark.util.UninterruptibleThread /** - * A [[MicroBatchReadSupport]] that reads data from Kafka. + * A [[MicroBatchStream]] that reads data from Kafka. * * The [[KafkaSourceOffset]] is the custom [[Offset]] defined for this source that contains * a map of TopicPartition -> offset. Note that this offset is 1 + (available offset). For @@ -55,13 +54,13 @@ import org.apache.spark.util.UninterruptibleThread * To avoid this issue, you should make sure stopping the query before stopping the Kafka brokers * and not use wrong broker addresses. 
*/ -private[kafka010] class KafkaMicroBatchReadSupport( +private[kafka010] class KafkaMicroBatchStream( kafkaOffsetReader: KafkaOffsetReader, executorKafkaParams: ju.Map[String, Object], options: DataSourceOptions, metadataPath: String, startingOffsets: KafkaOffsetRangeLimit, - failOnDataLoss: Boolean) extends RateControlMicroBatchReadSupport with Logging { + failOnDataLoss: Boolean) extends RateControlMicroBatchStream with Logging { private val pollTimeoutMs = options.getLong( "kafkaConsumer.pollTimeoutMs", @@ -94,16 +93,9 @@ private[kafka010] class KafkaMicroBatchReadSupport( endPartitionOffsets } - override def fullSchema(): StructType = KafkaOffsetReader.kafkaSchema - - override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder = { - new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end)) - } - - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { - val sc = config.asInstanceOf[SimpleStreamingScanConfig] - val startPartitionOffsets = sc.start.asInstanceOf[KafkaSourceOffset].partitionToOffsets - val endPartitionOffsets = sc.end.get.asInstanceOf[KafkaSourceOffset].partitionToOffsets + override def planInputPartitions(start: Offset, end: Offset): Array[InputPartition] = { + val startPartitionOffsets = start.asInstanceOf[KafkaSourceOffset].partitionToOffsets + val endPartitionOffsets = end.asInstanceOf[KafkaSourceOffset].partitionToOffsets // Find the new partitions, and get their earliest offsets val newPartitions = endPartitionOffsets.keySet.diff(startPartitionOffsets.keySet) @@ -168,7 +160,7 @@ private[kafka010] class KafkaMicroBatchReadSupport( }.toArray } - override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + override def createReaderFactory(): PartitionReaderFactory = { KafkaMicroBatchReaderFactory } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index b59f21ab130a..58c90b897091 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -31,6 +31,8 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.reader.{Scan, ScanBuilder} +import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchStream import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -47,7 +49,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with CreatableRelationProvider with StreamingWriteSupportProvider with ContinuousReadSupportProvider - with MicroBatchReadSupportProvider + with TableProvider with Logging { import KafkaSourceProvider._ @@ -101,40 +103,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister failOnDataLoss(caseInsensitiveParams)) } - /** - * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReadSupport]] to read - * batches of Kafka data in a micro-batch streaming query. 
- */ - override def createMicroBatchReadSupport( - metadataPath: String, - options: DataSourceOptions): KafkaMicroBatchReadSupport = { - - val parameters = options.asMap().asScala.toMap - validateStreamOptions(parameters) - // Each running query should use its own group id. Otherwise, the query may be only assigned - // partial data since Kafka will assign partitions to multiple consumers having the same group - // id. Hence, we should generate a unique id for each query. - val uniqueGroupId = streamingUniqueGroupId(parameters, metadataPath) - - val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } - val specifiedKafkaParams = convertToSpecifiedParams(parameters) - - val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams, - STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) - - val kafkaOffsetReader = new KafkaOffsetReader( - strategy(caseInsensitiveParams), - kafkaParamsForDriver(specifiedKafkaParams), - parameters, - driverGroupIdPrefix = s"$uniqueGroupId-driver") - - new KafkaMicroBatchReadSupport( - kafkaOffsetReader, - kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), - options, - metadataPath, - startingStreamOffsets, - failOnDataLoss(caseInsensitiveParams)) + override def getTable(options: DataSourceOptions): KafkaTable = { + new KafkaTable(strategy(options.asMap().asScala.toMap)) } /** @@ -434,6 +404,52 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister logWarning("maxOffsetsPerTrigger option ignored in batch queries") } } + + class KafkaTable(strategy: => ConsumerStrategy) extends Table + with SupportsMicroBatchRead { + + override def name(): String = s"Kafka $strategy" + + override def schema(): StructType = KafkaOffsetReader.kafkaSchema + + override def newScanBuilder(options: DataSourceOptions): ScanBuilder = new ScanBuilder { + override def build(): Scan = new KafkaScan(options) + } + } + + class KafkaScan(options: DataSourceOptions) extends Scan { + + override def readSchema(): StructType = KafkaOffsetReader.kafkaSchema + + override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = { + val parameters = options.asMap().asScala.toMap + validateStreamOptions(parameters) + // Each running query should use its own group id. Otherwise, the query may be only assigned + // partial data since Kafka will assign partitions to multiple consumers having the same group + // id. Hence, we should generate a unique id for each query. 
+ val uniqueGroupId = streamingUniqueGroupId(parameters, checkpointLocation) + + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } + val specifiedKafkaParams = convertToSpecifiedParams(parameters) + + val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit( + caseInsensitiveParams, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) + + val kafkaOffsetReader = new KafkaOffsetReader( + strategy(parameters), + kafkaParamsForDriver(specifiedKafkaParams), + parameters, + driverGroupIdPrefix = s"$uniqueGroupId-driver") + + new KafkaMicroBatchStream( + kafkaOffsetReader, + kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), + options, + checkpointLocation, + startingStreamOffsets, + failOnDataLoss(caseInsensitiveParams)) + } + } } private[kafka010] object KafkaSourceProvider extends Logging { diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala index 9ba066a4cdc3..2f7fd7f7d47b 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.kafka010 import org.apache.kafka.clients.producer.ProducerRecord import org.apache.spark.sql.Dataset -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2StreamingScanExec +import org.apache.spark.sql.execution.datasources.v2.ContinuousScanExec import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.streaming.Trigger @@ -208,7 +208,7 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.executedPlan.collectFirst { - case scan: DataSourceV2StreamingScanExec + case scan: ContinuousScanExec if scan.readSupport.isInstanceOf[KafkaContinuousReadSupport] => scan.scanConfig.asInstanceOf[KafkaContinuousScanConfig] }.exists { config => diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala index 5549e821be75..fa3b623586aa 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicInteger import org.apache.spark.SparkContext import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd, SparkListenerTaskStart} -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2StreamingScanExec +import org.apache.spark.sql.execution.datasources.v2.ContinuousScanExec import org.apache.spark.sql.execution.streaming.StreamExecution import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.streaming.Trigger @@ -47,7 +47,7 @@ trait KafkaContinuousTest extends KafkaSourceTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.executedPlan.collectFirst { - case scan: DataSourceV2StreamingScanExec + case scan: ContinuousScanExec if scan.readSupport.isInstanceOf[KafkaContinuousReadSupport] => scan.scanConfig.asInstanceOf[KafkaContinuousScanConfig] 
}.exists(_.knownPartitions.size == newCount), diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index cb453846134e..90b501573a95 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -35,7 +35,7 @@ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ import org.apache.spark.sql.{Dataset, ForeachWriter, SparkSession} -import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation +import org.apache.spark.sql.execution.datasources.v2.{OldStreamingDataSourceV2Relation, StreamingDataSourceV2Relation} import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution @@ -118,11 +118,13 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext with Kaf val sources: Seq[BaseStreamingSource] = { query.get.logicalPlan.collect { case StreamingExecutionRelation(source: KafkaSource, _) => source - case StreamingExecutionRelation(source: KafkaMicroBatchReadSupport, _) => source + case r: StreamingDataSourceV2Relation + if r.stream.isInstanceOf[KafkaMicroBatchStream] => + r.stream.asInstanceOf[KafkaMicroBatchStream] } ++ (query.get.lastExecution match { case null => Seq() case e => e.logical.collect { - case r: StreamingDataSourceV2Relation + case r: OldStreamingDataSourceV2Relation if r.readSupport.isInstanceOf[KafkaContinuousReadSupport] => r.readSupport.asInstanceOf[KafkaContinuousReadSupport] } @@ -1062,9 +1064,10 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { testStream(kafka)( makeSureGetOffsetCalled, AssertOnQuery { query => - query.logicalPlan.collect { - case StreamingExecutionRelation(_: KafkaMicroBatchReadSupport, _) => true - }.nonEmpty + query.logicalPlan.find { + case r: StreamingDataSourceV2Relation => r.stream.isInstanceOf[KafkaMicroBatchStream] + case _ => false + }.isDefined } ) } @@ -1088,13 +1091,12 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { "kafka.bootstrap.servers" -> testUtils.brokerAddress, "subscribe" -> topic ) ++ Option(minPartitions).map { p => "minPartitions" -> p} - val readSupport = provider.createMicroBatchReadSupport( - dir.getAbsolutePath, new DataSourceOptions(options.asJava)) - val config = readSupport.newScanConfigBuilder( + val dsOptions = new DataSourceOptions(options.asJava) + val table = provider.getTable(dsOptions) + val stream = table.newScanBuilder(dsOptions).build().toMicroBatchStream(dir.getAbsolutePath) + val inputPartitions = stream.planInputPartitions( KafkaSourceOffset(Map(tp -> 0L)), - KafkaSourceOffset(Map(tp -> 100L))).build() - val inputPartitions = readSupport.planInputPartitions(config) - .map(_.asInstanceOf[KafkaMicroBatchInputPartition]) + KafkaSourceOffset(Map(tp -> 100L))).map(_.asInstanceOf[KafkaMicroBatchInputPartition]) withClue(s"minPartitions = $minPartitions generated factories $inputPartitions\n\t") { assert(inputPartitions.size == numPartitionsGenerated) inputPartitions.foreach { f => assert(f.reuseKafkaConsumer == reusesConsumers) } @@ -1410,7 +1412,7 @@ abstract class KafkaSourceSuiteBase extends KafkaSourceTest { val reader 
= spark .readStream .format("kafka") - .option("startingOffsets", s"latest") + .option("startingOffsets", "latest") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") .option("failOnDataLoss", failOnDataLoss.toString) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java deleted file mode 100644 index c4d9ef88f607..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils; -import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReadSupport; -import org.apache.spark.sql.types.StructType; - -/** - * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide data reading ability for micro-batch stream processing. - * - * This interface is used to create {@link MicroBatchReadSupport} instances when end users run - * {@code SparkSession.readStream.format(...).option(...).load()} with a micro-batch trigger. - */ -@Evolving -public interface MicroBatchReadSupportProvider extends DataSourceV2 { - - /** - * Creates a {@link MicroBatchReadSupport} instance to scan the data from this streaming data - * source with a user specified schema, which is called by Spark at the beginning of each - * micro-batch streaming query. - * - * By default this method throws {@link UnsupportedOperationException}, implementations should - * override this method to handle user specified schema. - * - * @param schema the user provided schema. - * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure - * recovery. Readers for the same logical source in the same query - * will be given the same checkpointLocation. - * @param options the options for the returned data source reader, which is an immutable - * case-insensitive string-to-string map. - */ - default MicroBatchReadSupport createMicroBatchReadSupport( - StructType schema, - String checkpointLocation, - DataSourceOptions options) { - return DataSourceV2Utils.failForUserSpecifiedSchema(this); - } - - /** - * Creates a {@link MicroBatchReadSupport} instance to scan the data from this streaming data - * source, which is called by Spark at the beginning of each micro-batch streaming query. - * - * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure - * recovery. 
Readers for the same logical source in the same query - * will be given the same checkpointLocation. - * @param options the options for the returned data source reader, which is an immutable - * case-insensitive string-to-string map. - */ - MicroBatchReadSupport createMicroBatchReadSupport( - String checkpointLocation, - DataSourceOptions options); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsMicroBatchRead.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsMicroBatchRead.java new file mode 100644 index 000000000000..9408e323f9da --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsMicroBatchRead.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.sources.v2.reader.Scan; +import org.apache.spark.sql.sources.v2.reader.ScanBuilder; + +/** + * An empty mix-in interface for {@link Table}, to indicate this table supports streaming scan with + * micro-batch mode. + *
+ * If a {@link Table} implements this interface, the + * {@link SupportsRead#newScanBuilder(DataSourceOptions)} must return a {@link ScanBuilder} that + * builds {@link Scan} with {@link Scan#toMicroBatchStream(String)} implemented. + *
+ */ +@Evolving +public interface SupportsMicroBatchRead extends SupportsRead { } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java index 4d84fb19aa02..c60fb2ba0b0b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java @@ -18,8 +18,10 @@ package org.apache.spark.sql.sources.v2.reader; import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchStream; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.sources.v2.SupportsBatchRead; +import org.apache.spark.sql.sources.v2.SupportsMicroBatchRead; import org.apache.spark.sql.sources.v2.Table; /** @@ -65,4 +67,20 @@ default String description() { default Batch toBatch() { throw new UnsupportedOperationException("Batch scans are not supported"); } + + /** + * Returns the physical representation of this scan for streaming query with micro-batch mode. By + * default this method throws exception, data sources must overwrite this method to provide an + * implementation, if the {@link Table} that creates this scan implements + * {@link SupportsMicroBatchRead}. + * + * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure + * recovery. Data streams for the same logical source in the same query + * will be given the same checkpointLocation. + * + * @throws UnsupportedOperationException + */ + default MicroBatchStream toMicroBatchStream(String checkpointLocation) { + throw new UnsupportedOperationException("Micro-batch scans are not supported"); + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java deleted file mode 100644 index f56066c63938..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.reader.streaming; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.execution.streaming.BaseStreamingSource; -import org.apache.spark.sql.sources.v2.reader.*; - -/** - * An interface that defines how to scan the data from data source for micro-batch streaming - * processing. - * - * The execution engine will get an instance of this interface from a data source provider - * (e.g. 
{@link org.apache.spark.sql.sources.v2.MicroBatchReadSupportProvider}) at the start of a - * streaming query, then call {@link #newScanConfigBuilder(Offset, Offset)} and create an instance - * of {@link ScanConfig} for each micro-batch. The {@link ScanConfig} will be used to create input - * partitions and reader factory to scan a micro-batch with a Spark job. At the end {@link #stop()} - * will be called when the streaming execution is completed. Note that a single query may have - * multiple executions due to restart or failure recovery. - */ -@Evolving -public interface MicroBatchReadSupport extends StreamingReadSupport, BaseStreamingSource { - - /** - * Returns a builder of {@link ScanConfig}. Spark will call this method and create a - * {@link ScanConfig} for each data scanning job. - * - * The builder can take some query specific information to do operators pushdown, store streaming - * offsets, etc., and keep these information in the created {@link ScanConfig}. - * - * This is the first step of the data scan. All other methods in {@link MicroBatchReadSupport} - * needs to take {@link ScanConfig} as an input. - */ - ScanConfigBuilder newScanConfigBuilder(Offset start, Offset end); - - /** - * Returns a factory, which produces one {@link PartitionReader} for one {@link InputPartition}. - */ - PartitionReaderFactory createReaderFactory(ScanConfig config); - - /** - * Returns the most recent offset available. - */ - Offset latestOffset(); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchStream.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchStream.java new file mode 100644 index 000000000000..2fb3957293df --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchStream.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader.streaming; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.sources.v2.reader.InputPartition; +import org.apache.spark.sql.sources.v2.reader.PartitionReader; +import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory; +import org.apache.spark.sql.sources.v2.reader.Scan; + +/** + * A {@link SparkDataStream} for streaming queries with micro-batch mode. + */ +@Evolving +public interface MicroBatchStream extends SparkDataStream { + + /** + * Returns the most recent offset available. + */ + Offset latestOffset(); + + /** + * Returns a list of {@link InputPartition input partitions} given the start and end offsets. Each + * {@link InputPartition} represents a data split that can be processed by one Spark task. 
The + * number of input partitions returned here is the same as the number of RDD partitions this scan + * outputs. + *
+ * If the {@link Scan} supports filter pushdown, this stream is likely configured with a filter + * and is responsible for creating splits for that filter, which is not a full scan. + *
+ *
+ * This method will be called multiple times, to launch one Spark job for each micro-batch in this + * data stream. + *
+ */ + InputPartition[] planInputPartitions(Offset start, Offset end); + + /** + * Returns a factory, which produces one {@link PartitionReader} for one {@link InputPartition}. + */ + PartitionReaderFactory createReaderFactory(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java index 6104175d2c9e..67bff0c27e8a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java @@ -20,7 +20,7 @@ import org.apache.spark.annotation.Evolving; /** - * An abstract representation of progress through a {@link MicroBatchReadSupport} or + * An abstract representation of progress through a {@link MicroBatchStream} or * {@link ContinuousReadSupport}. * During execution, offsets provided by the data source implementation will be logged and used as * restart checkpoints. Each source should provide an offset implementation which the source can use diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SparkDataStream.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SparkDataStream.java new file mode 100644 index 000000000000..8ea34be8d839 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SparkDataStream.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader.streaming; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.execution.streaming.BaseStreamingSource; + +/** + * The base interface representing a readable data stream in a Spark streaming query. It's + * responsible to manage the offsets of the streaming source in the streaming query. + * + * Data sources should implement concrete data stream interfaces: {@link MicroBatchStream}. + */ +@Evolving +public interface SparkDataStream extends BaseStreamingSource { + + /** + * Returns the initial offset for a streaming query to start reading from. Note that the + * streaming data source should not assume that it will start reading from its initial offset: + * if Spark is restarting an existing query, it will restart from the check-pointed offset rather + * than the initial one. + */ + Offset initialOffset(); + + /** + * Deserialize a JSON string into an Offset of the implementation-defined offset type. 
+ * + * @throws IllegalArgumentException if the JSON does not encode a valid offset for this reader + */ + Offset deserializeOffset(String json); + + /** + * Informs the source that Spark has completed processing all data for offsets less than or + * equal to `end` and will only request offsets greater than `end` in the future. + */ + void commit(Offset end); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/StreamingReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/StreamingReadSupport.java index bd39fc858d3b..9a8c1bdd23be 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/StreamingReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/StreamingReadSupport.java @@ -23,7 +23,7 @@ /** * A base interface for streaming read support. Data sources should implement concrete streaming - * read support interfaces: {@link MicroBatchReadSupport} or {@link ContinuousReadSupport}. + * read support interfaces: {@link ContinuousReadSupport}. * This is exposed for a testing purpose. */ @VisibleForTesting diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StreamingScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ContinuousScanExec.scala similarity index 91% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StreamingScanExec.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ContinuousScanExec.scala index be75fe4f596d..c735b0ef68a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StreamingScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ContinuousScanExec.scala @@ -26,14 +26,13 @@ import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeSta import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.DataSourceV2 import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousPartitionReaderFactory, ContinuousReadSupport, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousPartitionReaderFactory, ContinuousReadSupport} /** - * Physical plan node for scanning data from a data source. + * Physical plan node for scanning data from a streaming data source with continuous mode. */ -// TODO: micro-batch should be handled by `DataSourceV2ScanExec`, after we finish the API refactor -// completely. -case class DataSourceV2StreamingScanExec( +// TODO: merge it and `MicroBatchScanExec`. +case class ContinuousScanExec( output: Seq[AttributeReference], @transient source: DataSourceV2, @transient options: Map[String, String], @@ -46,7 +45,7 @@ case class DataSourceV2StreamingScanExec( // TODO: unify the equal/hashCode implementation for all data source v2 query plans. 
override def equals(other: Any): Boolean = other match { - case other: DataSourceV2StreamingScanExec => + case other: ContinuousScanExec => output == other.output && readSupport.getClass == other.readSupport.getClass && options == other.options case _ => false @@ -70,7 +69,6 @@ case class DataSourceV2StreamingScanExec( private lazy val partitions: Seq[InputPartition] = readSupport.planInputPartitions(scanConfig) private lazy val readerFactory = readSupport match { - case r: MicroBatchReadSupport => r.createReaderFactory(scanConfig) case r: ContinuousReadSupport => r.createContinuousReaderFactory(scanConfig) case _ => throw new IllegalStateException("unknown read support: " + readSupport) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 632157818434..63e97e67dc64 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -23,11 +23,12 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset, SparkDataStream} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.types.StructType @@ -92,6 +93,28 @@ case class DataSourceV2Relation( * after we figure out how to apply operator push-down for streaming data sources. */ case class StreamingDataSourceV2Relation( + output: Seq[Attribute], + scanDesc: String, + stream: SparkDataStream, + startOffset: Option[Offset] = None, + endOffset: Option[Offset] = None) + extends LeafNode with MultiInstanceRelation { + + override def isStreaming: Boolean = true + + override def newInstance(): LogicalPlan = copy(output = output.map(_.newInstance())) + + override def computeStats(): Statistics = stream match { + case r: SupportsReportStatistics => + val statistics = r.estimateStatistics() + Statistics(sizeInBytes = statistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) + case _ => + Statistics(sizeInBytes = conf.defaultSizeInBytes) + } +} + +// TODO: remove it after finish API refactor for continuous streaming. +case class OldStreamingDataSourceV2Relation( output: Seq[AttributeReference], source: DataSourceV2, options: Map[String, String], @@ -111,7 +134,7 @@ case class StreamingDataSourceV2Relation( // TODO: unify the equal/hashCode implementation for all data source v2 query plans. 
override def equals(other: Any): Boolean = other match { - case other: StreamingDataSourceV2Relation => + case other: OldStreamingDataSourceV2Relation => output == other.output && readSupport.getClass == other.readSupport.getClass && options == other.options case _ => false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 79540b024621..b4c547104c4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec} import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReadSupport +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, MicroBatchStream} import org.apache.spark.sql.sources.v2.writer.SupportsSaveMode object DataSourceV2Strategy extends Strategy { @@ -125,12 +125,19 @@ object DataSourceV2Strategy extends Strategy { // always add the projection, which will produce unsafe rows required by some operators ProjectExec(project, withFilter) :: Nil - case r: StreamingDataSourceV2Relation => + case r: StreamingDataSourceV2Relation if r.startOffset.isDefined && r.endOffset.isDefined => + val microBatchStream = r.stream.asInstanceOf[MicroBatchStream] + // ensure there is a projection, which will produce unsafe rows required by some operators + ProjectExec(r.output, + MicroBatchScanExec( + r.output, r.scanDesc, microBatchStream, r.startOffset.get, r.endOffset.get)) :: Nil + + case r: OldStreamingDataSourceV2Relation => // TODO: support operator pushdown for streaming data sources. val scanConfig = r.scanConfigBuilder.build() // ensure there is a projection, which will produce unsafe rows required by some operators ProjectExec(r.output, - DataSourceV2StreamingScanExec( + ContinuousScanExec( r.output, r.source, r.options, r.pushedFilters, r.readSupport, scanConfig)) :: Nil case WriteToDataSourceV2(writer, query) => @@ -151,7 +158,8 @@ object DataSourceV2Strategy extends Strategy { case Repartition(1, false, child) => val isContinuous = child.find { - case s: StreamingDataSourceV2Relation => s.readSupport.isInstanceOf[ContinuousReadSupport] + case s: OldStreamingDataSourceV2Relation => + s.readSupport.isInstanceOf[ContinuousReadSupport] case _ => false }.isDefined diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MicroBatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MicroBatchScanExec.scala new file mode 100644 index 000000000000..feea8bcb80c8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MicroBatchScanExec.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} +import org.apache.spark.sql.catalyst.plans.physical +import org.apache.spark.sql.catalyst.plans.physical.SinglePartition +import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} +import org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchStream, Offset} + +/** + * Physical plan node for scanning a micro-batch of data from a data source. + */ +case class MicroBatchScanExec( + output: Seq[Attribute], + scanDesc: String, + @transient stream: MicroBatchStream, + @transient start: Offset, + @transient end: Offset) extends LeafExecNode with ColumnarBatchScan { + + override def simpleString(maxFields: Int): String = { + s"ScanV2${truncatedString(output, "[", ", ", "]", maxFields)} $scanDesc" + } + + // TODO: unify the equal/hashCode implementation for all data source v2 query plans. + override def equals(other: Any): Boolean = other match { + case other: MicroBatchScanExec => this.stream == other.stream + case _ => false + } + + override def hashCode(): Int = stream.hashCode() + + private lazy val partitions = stream.planInputPartitions(start, end) + + private lazy val readerFactory = stream.createReaderFactory() + + override def outputPartitioning: physical.Partitioning = stream match { + case _ if partitions.length == 1 => + SinglePartition + + case s: SupportsReportPartitioning => + new DataSourcePartitioning( + s.outputPartitioning(), AttributeMap(output.map(a => a -> a.name))) + + case _ => super.outputPartitioning + } + + override def supportsBatch: Boolean = { + require(partitions.forall(readerFactory.supportColumnarReads) || + !partitions.exists(readerFactory.supportColumnarReads), + "Cannot mix row-based and columnar input partitions.") + + partitions.exists(readerFactory.supportColumnarReads) + } + + private lazy val inputRDD: RDD[InternalRow] = { + new DataSourceRDD(sparkContext, partitions, readerFactory, supportsBatch) + } + + override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD) + + override protected def doExecute(): RDD[InternalRow] = { + if (supportsBatch) { + WholeStageCodegenExec(this)(codegenStageId = 0).execute() + } else { + val numOutputRows = longMetric("numOutputRows") + inputRDD.map { r => + numOutputRows += 1 + r + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index db1bf32a156c..64270e1f44a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -22,15 +22,15 @@ import scala.collection.mutable.{Map => MutableMap} import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.encoders.RowEncoder -import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, StreamWriterCommitProgress, WriteToDataSourceV2, WriteToDataSourceV2Exec} -import org.apache.spark.sql.execution.streaming.sources.{MicroBatchWrite, RateControlMicroBatchReadSupport} +import org.apache.spark.sql.execution.streaming.sources.{MicroBatchWrite, RateControlMicroBatchStream} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset => OffsetV2} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchStream, Offset => OffsetV2} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.Clock @@ -51,9 +51,6 @@ class MicroBatchExecution( @volatile protected var sources: Seq[BaseStreamingSource] = Seq.empty - private val readSupportToDataSourceMap = - MutableMap.empty[MicroBatchReadSupport, (DataSourceV2, Map[String, String])] - private val triggerExecutor = trigger match { case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock) case OneTimeTrigger => OneTimeExecutor() @@ -69,6 +66,7 @@ class MicroBatchExecution( var nextSourceId = 0L val toExecutionRelationMap = MutableMap[StreamingRelation, StreamingExecutionRelation]() val v2ToExecutionRelationMap = MutableMap[StreamingRelationV2, StreamingExecutionRelation]() + val v2ToRelationMap = MutableMap[StreamingRelationV2, StreamingDataSourceV2Relation]() // We transform each distinct streaming relation into a StreamingExecutionRelation, keeping a // map as we go to ensure each identical relation gets the same StreamingExecutionRelation // object. 
For each microbatch, the StreamingExecutionRelation will be replaced with a logical @@ -90,36 +88,39 @@ class MicroBatchExecution( logInfo(s"Using Source [$source] from DataSourceV1 named '$sourceName' [$dataSourceV1]") StreamingExecutionRelation(source, output)(sparkSession) }) - case s @ StreamingRelationV2( - dataSourceV2: MicroBatchReadSupportProvider, sourceName, options, output, _) if - !disabledSources.contains(dataSourceV2.getClass.getCanonicalName) => - v2ToExecutionRelationMap.getOrElseUpdate(s, { + case s @ StreamingRelationV2(ds, dsName, table: SupportsMicroBatchRead, options, output, _) + if !disabledSources.contains(ds.getClass.getCanonicalName) => + v2ToRelationMap.getOrElseUpdate(s, { // Materialize source to avoid creating it in every batch val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" - val readSupport = dataSourceV2.createMicroBatchReadSupport( - metadataPath, - new DataSourceOptions(options.asJava)) nextSourceId += 1 - readSupportToDataSourceMap(readSupport) = dataSourceV2 -> options - logInfo(s"Using MicroBatchReadSupport [$readSupport] from " + - s"DataSourceV2 named '$sourceName' [$dataSourceV2]") - StreamingExecutionRelation(readSupport, output)(sparkSession) + logInfo(s"Reading table [$table] from DataSourceV2 named '$dsName' [$ds]") + val dsOptions = new DataSourceOptions(options.asJava) + // TODO: operator pushdown. + val scan = table.newScanBuilder(dsOptions).build() + val stream = scan.toMicroBatchStream(metadataPath) + StreamingDataSourceV2Relation(output, scan.description(), stream) }) - case s @ StreamingRelationV2(dataSourceV2, sourceName, _, output, v1Relation) => + case s @ StreamingRelationV2(ds, dsName, _, _, output, v1Relation) => v2ToExecutionRelationMap.getOrElseUpdate(s, { // Materialize source to avoid creating it in every batch val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" if (v1Relation.isEmpty) { throw new UnsupportedOperationException( - s"Data source $sourceName does not support microbatch processing.") + s"Data source $dsName does not support microbatch processing.") } val source = v1Relation.get.dataSource.createSource(metadataPath) nextSourceId += 1 - logInfo(s"Using Source [$source] from DataSourceV2 named '$sourceName' [$dataSourceV2]") + logInfo(s"Using Source [$source] from DataSourceV2 named '$dsName' [$ds]") StreamingExecutionRelation(source, output)(sparkSession) }) } - sources = _logicalPlan.collect { case s: StreamingExecutionRelation => s.source } + sources = _logicalPlan.collect { + // v1 source + case s: StreamingExecutionRelation => s.source + // v2 source + case r: StreamingDataSourceV2Relation => r.stream + } uniqueSources = sources.distinct _logicalPlan } @@ -350,7 +351,7 @@ class MicroBatchExecution( reportTimeTaken("getOffset") { (s, s.getOffset) } - case s: RateControlMicroBatchReadSupport => + case s: RateControlMicroBatchStream => updateStatusMessage(s"Getting offsets from $s") reportTimeTaken("latestOffset") { val startOffset = availableOffsets @@ -358,7 +359,7 @@ class MicroBatchExecution( .getOrElse(s.initialOffset()) (s, Option(s.latestOffset(startOffset))) } - case s: MicroBatchReadSupport => + case s: MicroBatchStream => updateStatusMessage(s"Getting offsets from $s") reportTimeTaken("latestOffset") { (s, Option(s.latestOffset())) @@ -402,8 +403,8 @@ class MicroBatchExecution( if (prevBatchOff.isDefined) { prevBatchOff.get.toStreamProgress(sources).foreach { case (src: Source, off) => src.commit(off) - case (readSupport: MicroBatchReadSupport, off) => - 
readSupport.commit(readSupport.deserializeOffset(off.json)) + case (stream: MicroBatchStream, off) => + stream.commit(stream.deserializeOffset(off.json)) case (src, _) => throw new IllegalArgumentException( s"Unknown source is found at constructNextBatch: $src") @@ -448,39 +449,30 @@ class MicroBatchExecution( logDebug(s"Retrieving data from $source: $current -> $available") Some(source -> batch.logicalPlan) - // TODO(cloud-fan): for data source v2, the new batch is just a new `ScanConfigBuilder`, but - // to be compatible with streaming source v1, we return a logical plan as a new batch here. - case (readSupport: MicroBatchReadSupport, available) - if committedOffsets.get(readSupport).map(_ != available).getOrElse(true) => - val current = committedOffsets.get(readSupport).map { - off => readSupport.deserializeOffset(off.json) + case (stream: MicroBatchStream, available) + if committedOffsets.get(stream).map(_ != available).getOrElse(true) => + val current = committedOffsets.get(stream).map { + off => stream.deserializeOffset(off.json) } val endOffset: OffsetV2 = available match { - case v1: SerializedOffset => readSupport.deserializeOffset(v1.json) + case v1: SerializedOffset => stream.deserializeOffset(v1.json) case v2: OffsetV2 => v2 } - val startOffset = current.getOrElse(readSupport.initialOffset) - val scanConfigBuilder = readSupport.newScanConfigBuilder(startOffset, endOffset) - logDebug(s"Retrieving data from $readSupport: $current -> $endOffset") - - val (source, options) = readSupport match { - // `MemoryStream` is special. It's for test only and doesn't have a `DataSourceV2` - // implementation. We provide a fake one here for explain. - case _: MemoryStream[_] => MemoryStreamDataSource -> Map.empty[String, String] - // Provide a fake value here just in case something went wrong, e.g. the reader gives - // a wrong `equals` implementation. - case _ => readSupportToDataSourceMap.getOrElse(readSupport, { - FakeDataSourceV2 -> Map.empty[String, String] - }) - } - Some(readSupport -> StreamingDataSourceV2Relation( - readSupport.fullSchema().toAttributes, source, options, readSupport, scanConfigBuilder)) + val startOffset = current.getOrElse(stream.initialOffset) + logDebug(s"Retrieving data from $stream: $current -> $endOffset") + + // To be compatible with the v1 source, the `newData` is represented as a logical plan, + // while the `newData` of v2 source is just the start and end offsets. Here we return a + // fake logical plan to carry the offsets. + Some(stream -> OffsetHolder(startOffset, endOffset)) + case _ => None } } // Replace sources in the logical plan with data that has arrived since the last batch. val newBatchesPlan = logicalPlan transform { + // For v1 sources. case StreamingExecutionRelation(source, output) => newData.get(source).map { dataPlan => val maxFields = SQLConf.get.maxToStringFields @@ -495,6 +487,15 @@ class MicroBatchExecution( }.getOrElse { LocalRelation(output, isStreaming = true) } + + // For v2 sources. + case r: StreamingDataSourceV2Relation => + newData.get(r.stream).map { + case OffsetHolder(start, end) => + r.copy(startOffset = Some(start), endOffset = Some(end)) + }.getOrElse { + LocalRelation(r.output, isStreaming = true) + } } // Rewire the plan to use the new attributes that were returned by the source. 
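For reference, the calling sequence that the engine now follows against a MicroBatchStream can be summarized with the short Scala sketch below. It is an illustration only, not code from this patch: the wrapper object, the method name and the `committed` parameter are invented, and the real MicroBatchExecution issues commit() for the previous batch's end offset while constructing the next batch (see the hunk above); only the calls on `stream` come from the MicroBatchStream and SparkDataStream interfaces introduced in this change.

import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchStream, Offset}

object MicroBatchProtocolSketch {
  def runOneBatch(stream: MicroBatchStream, committed: Option[Offset]): Offset = {
    // Resume from the last committed offset, or ask the source for its initial offset.
    val start = committed.getOrElse(stream.initialOffset())
    // Ask the source how far it can currently read; this becomes the batch's end offset.
    val end = stream.latestOffset()
    // Plan the micro-batch and obtain the reader factory; in this patch these two calls
    // are made by MicroBatchScanExec when the physical plan executes.
    val partitions = stream.planInputPartitions(start, end)
    val readerFactory = stream.createReaderFactory()
    // ... a Spark job runs one PartitionReader per InputPartition via `readerFactory` ...
    // Once the batch is durably recorded, the source may discard data up to `end`.
    stream.commit(end)
    end
  }
}

Compared with the old MicroBatchReadSupport, the start and end offsets are passed straight to planInputPartitions instead of being carried through a ScanConfig, which is what allows this patch to remove the ScanConfigBuilder plumbing from KafkaMicroBatchStream and MemoryStream.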
@@ -580,6 +581,6 @@ object MicroBatchExecution { val BATCH_ID_KEY = "streaming.sql.batchId" } -object MemoryStreamDataSource extends DataSourceV2 - -object FakeDataSourceV2 extends DataSourceV2 +case class OffsetHolder(start: OffsetV2, end: OffsetV2) extends LeafNode { + override def output: Seq[Attribute] = Nil +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index d1f3f74c5e73..25283515b882 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -28,8 +28,8 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.QueryExecution -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2StreamingScanExec, StreamWriterCommitProgress} -import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReadSupport +import org.apache.spark.sql.execution.datasources.v2.{MicroBatchScanExec, StreamingDataSourceV2Relation, StreamWriterCommitProgress} +import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchStream import org.apache.spark.sql.streaming._ import org.apache.spark.sql.streaming.StreamingQueryListener.QueryProgressEvent import org.apache.spark.util.Clock @@ -247,10 +247,12 @@ trait ProgressReporter extends Logging { } val onlyDataSourceV2Sources = { - // Check whether the streaming query's logical plan has only V2 data sources - val allStreamingLeaves = - logicalPlan.collect { case s: StreamingExecutionRelation => s } - allStreamingLeaves.forall { _.source.isInstanceOf[MicroBatchReadSupport] } + // Check whether the streaming query's logical plan has only V2 micro-batch data sources + val allStreamingLeaves = logicalPlan.collect { + case s: StreamingDataSourceV2Relation => s.stream.isInstanceOf[MicroBatchStream] + case _: StreamingExecutionRelation => false + } + allStreamingLeaves.forall(_ == true) } if (onlyDataSourceV2Sources) { @@ -258,9 +260,9 @@ trait ProgressReporter extends Logging { // (can happen with self-unions or self-joins). This means the source is scanned multiple // times in the query, we should count the numRows for each scan. 
val sourceToInputRowsTuples = lastExecution.executedPlan.collect { - case s: DataSourceV2StreamingScanExec if s.readSupport.isInstanceOf[BaseStreamingSource] => + case s: MicroBatchScanExec => val numRows = s.metrics.get("numOutputRows").map(_.value).getOrElse(0L) - val source = s.readSupport.asInstanceOf[BaseStreamingSource] + val source = s.stream.asInstanceOf[BaseStreamingSource] source -> numRows } logDebug("Source -> # input rows\n\t" + sourceToInputRowsTuples.mkString("\n\t")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index 4b696dfa5735..535fa1c70b3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceV2} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceV2, Table} object StreamingRelation { def apply(dataSource: DataSource): StreamingRelation = { @@ -94,6 +94,7 @@ case class StreamingExecutionRelation( case class StreamingRelationV2( dataSource: DataSourceV2, sourceName: String, + table: Table, extraOptions: Map[String, String], output: Seq[Attribute], v1Relation: Option[StreamingRelation])(session: SparkSession) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 89033b70f143..c74fa141372d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Curre import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2StreamingScanExec, StreamingDataSourceV2Relation} +import org.apache.spark.sql.execution.datasources.v2.{ContinuousScanExec, OldStreamingDataSourceV2Relation} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2 @@ -64,12 +64,12 @@ class ContinuousExecution( val toExecutionRelationMap = MutableMap[StreamingRelationV2, ContinuousExecutionRelation]() analyzedPlan.transform { case r @ StreamingRelationV2( - source: ContinuousReadSupportProvider, _, extraReaderOptions, output, _) => + source: ContinuousReadSupportProvider, _, _, extraReaderOptions, output, _) => // TODO: shall we create `ContinuousReadSupport` here instead of each reconfiguration? 
toExecutionRelationMap.getOrElseUpdate(r, { ContinuousExecutionRelation(source, extraReaderOptions, output)(sparkSession) }) - case StreamingRelationV2(_, sourceName, _, _, _) => + case StreamingRelationV2(_, sourceName, _, _, _, _) => throw new UnsupportedOperationException( s"Data source $sourceName does not support continuous processing.") } @@ -177,7 +177,7 @@ class ContinuousExecution( val realOffset = loggedOffset.map(off => readSupport.deserializeOffset(off.json)) val startOffset = realOffset.getOrElse(readSupport.initialOffset) val scanConfigBuilder = readSupport.newScanConfigBuilder(startOffset) - StreamingDataSourceV2Relation(newOutput, source, options, readSupport, scanConfigBuilder) + OldStreamingDataSourceV2Relation(newOutput, source, options, readSupport, scanConfigBuilder) } // Rewire the plan to use the new attributes that were returned by the source. @@ -211,7 +211,7 @@ class ContinuousExecution( } val (readSupport, scanConfig) = lastExecution.executedPlan.collect { - case scan: DataSourceV2StreamingScanExec + case scan: ContinuousScanExec if scan.readSupport.isInstanceOf[ContinuousReadSupport] => scan.readSupport.asInstanceOf[ContinuousReadSupport] -> scan.scanConfig }.head diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 13b75ae4a433..5406679630e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -33,8 +33,9 @@ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUti import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsMicroBatchRead, Table, TableProvider} import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset => OffsetV2} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchStream, Offset => OffsetV2} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -50,7 +51,7 @@ object MemoryStream { * A base class for memory stream implementations. Supports adding data and resetting. */ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends BaseStreamingSource { - protected val encoder = encoderFor[A] + val encoder = encoderFor[A] protected val attributes = encoder.schema.toAttributes def toDS(): Dataset[A] = { @@ -72,16 +73,56 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Bas def addData(data: TraversableOnce[A]): Offset } +// This class is used to indicate the memory stream data source. We don't actually use it, as +// memory stream is for test only and we never look it up by name. 
+object MemoryStreamTableProvider extends TableProvider { + override def getTable(options: DataSourceOptions): Table = { + throw new IllegalStateException("MemoryStreamTableProvider should not be used.") + } +} + +class MemoryStreamTable(val stream: MemoryStreamBase[_]) extends Table with SupportsMicroBatchRead { + + override def name(): String = "MemoryStreamDataSource" + + override def schema(): StructType = stream.fullSchema() + + override def newScanBuilder(options: DataSourceOptions): ScanBuilder = { + new MemoryStreamScanBuilder(stream) + } +} + +class MemoryStreamScanBuilder(stream: MemoryStreamBase[_]) extends ScanBuilder with Scan { + + override def build(): Scan = this + + override def description(): String = "MemoryStreamDataSource" + + override def readSchema(): StructType = stream.fullSchema() + + override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = { + stream.asInstanceOf[MemoryStream[_]] + } +} + /** * A [[Source]] that produces value stored in memory as they are added by the user. This [[Source]] * is intended for use in unit tests as it can only replay data when the object is still * available. */ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) - extends MemoryStreamBase[A](sqlContext) with MicroBatchReadSupport with Logging { + extends MemoryStreamBase[A](sqlContext) with MicroBatchStream with Logging { + + protected val logicalPlan: LogicalPlan = { + StreamingRelationV2( + MemoryStreamTableProvider, + "memory", + new MemoryStreamTable(this), + Map.empty, + attributes, + None)(sqlContext.sparkSession) + } - protected val logicalPlan: LogicalPlan = - StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession) protected val output = logicalPlan.output /** @@ -130,14 +171,9 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) if (currentOffset.offset == -1) null else currentOffset } - override def newScanConfigBuilder(start: OffsetV2, end: OffsetV2): ScanConfigBuilder = { - new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end)) - } - - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { - val sc = config.asInstanceOf[SimpleStreamingScanConfig] - val startOffset = sc.start.asInstanceOf[LongOffset] - val endOffset = sc.end.get.asInstanceOf[LongOffset] + override def planInputPartitions(start: OffsetV2, end: OffsetV2): Array[InputPartition] = { + val startOffset = start.asInstanceOf[LongOffset] + val endOffset = end.asInstanceOf[LongOffset] synchronized { // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal) val startOrdinal = startOffset.offset.toInt + 1 @@ -159,7 +195,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } } - override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + override def createReaderFactory(): PartitionReaderFactory = { MemoryStreamReaderFactory } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala index dbcc4483e577..8c5c9eff55ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -50,7 +50,8 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa private implicit val formats = 
Serialization.formats(NoTypeHints) protected val logicalPlan = - StreamingRelationV2(this, "memory", Map(), attributes, None)(sqlContext.sparkSession) + // TODO: don't pass null as table after finish API refactor for continuous stream. + StreamingRelationV2(this, "memory", null, Map(), attributes, None)(sqlContext.sparkSession) // ContinuousReader implementation diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchStream.scala similarity index 86% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchReadSupport.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchStream.scala index 90680ea38fbd..6a66f52c8f73 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchStream.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.execution.streaming.sources -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchStream, Offset} -// A special `MicroBatchReadSupport` that can get latestOffset with a start offset. -trait RateControlMicroBatchReadSupport extends MicroBatchReadSupport { +// A special `MicroBatchStream` that can get latestOffset with a start offset. +trait RateControlMicroBatchStream extends MicroBatchStream { override def latestOffset(): Offset = { throw new IllegalAccessException( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchStream.scala similarity index 82% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReadSupport.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchStream.scala index f5364047adff..a8feed34b96d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchStream.scala @@ -24,19 +24,23 @@ import java.util.concurrent.TimeUnit import org.apache.commons.io.IOUtils import org.apache.spark.internal.Logging -import org.apache.spark.network.util.JavaUtils import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchStream, Offset} import org.apache.spark.util.{ManualClock, SystemClock} -class RateStreamMicroBatchReadSupport(options: DataSourceOptions, checkpointLocation: String) - extends MicroBatchReadSupport with Logging { +class RateStreamMicroBatchStream( + rowsPerSecond: Long, + // The default values here are used in tests. 
+ rampUpTimeSeconds: Long = 0, + numPartitions: Int = 1, + options: DataSourceOptions, + checkpointLocation: String) + extends MicroBatchStream with Logging { import RateStreamProvider._ private[sources] val clock = { @@ -44,14 +48,6 @@ class RateStreamMicroBatchReadSupport(options: DataSourceOptions, checkpointLoca if (options.getBoolean("useManualClock", false)) new ManualClock else new SystemClock } - private val rowsPerSecond = - options.get(ROWS_PER_SECOND).orElse("1").toLong - - private val rampUpTimeSeconds = - Option(options.get(RAMP_UP_TIME).orElse(null.asInstanceOf[String])) - .map(JavaUtils.timeStringAsSec(_)) - .getOrElse(0L) - private val maxSeconds = Long.MaxValue / rowsPerSecond if (rampUpTimeSeconds > maxSeconds) { @@ -117,16 +113,10 @@ class RateStreamMicroBatchReadSupport(options: DataSourceOptions, checkpointLoca LongOffset(json.toLong) } - override def fullSchema(): StructType = SCHEMA - override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder = { - new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end)) - } - - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { - val sc = config.asInstanceOf[SimpleStreamingScanConfig] - val startSeconds = sc.start.asInstanceOf[LongOffset].offset - val endSeconds = sc.end.get.asInstanceOf[LongOffset].offset + override def planInputPartitions(start: Offset, end: Offset): Array[InputPartition] = { + val startSeconds = start.asInstanceOf[LongOffset].offset + val endSeconds = end.asInstanceOf[LongOffset].offset assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") if (endSeconds > maxSeconds) { throw new ArithmeticException("Integer overflow. Max offset with " + @@ -148,21 +138,14 @@ class RateStreamMicroBatchReadSupport(options: DataSourceOptions, checkpointLoca val localStartTimeMs = creationTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) val relativeMsPerValue = TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart) - val numPartitions = { - val activeSession = SparkSession.getActiveSession - require(activeSession.isDefined) - Option(options.get(NUM_PARTITIONS).orElse(null.asInstanceOf[String])) - .map(_.toInt) - .getOrElse(activeSession.get.sparkContext.defaultParallelism) - } (0 until numPartitions).map { p => - new RateStreamMicroBatchInputPartition( + RateStreamMicroBatchInputPartition( p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) }.toArray } - override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + override def createReaderFactory(): PartitionReaderFactory = { RateStreamMicroBatchReaderFactory } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala index 6942dfbfe0ec..8d334f0afd0f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala @@ -18,10 +18,12 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.network.util.JavaUtils +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReadSupport import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2._ -import 
org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.{Scan, ScanBuilder} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, MicroBatchStream} import org.apache.spark.sql.types._ /** @@ -39,38 +41,31 @@ import org.apache.spark.sql.types._ * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. */ class RateStreamProvider extends DataSourceV2 - with MicroBatchReadSupportProvider with ContinuousReadSupportProvider with DataSourceRegister { + with TableProvider with ContinuousReadSupportProvider with DataSourceRegister { import RateStreamProvider._ - override def createMicroBatchReadSupport( - checkpointLocation: String, - options: DataSourceOptions): MicroBatchReadSupport = { - if (options.get(ROWS_PER_SECOND).isPresent) { - val rowsPerSecond = options.get(ROWS_PER_SECOND).get().toLong - if (rowsPerSecond <= 0) { - throw new IllegalArgumentException( - s"Invalid value '$rowsPerSecond'. The option 'rowsPerSecond' must be positive") - } + override def getTable(options: DataSourceOptions): Table = { + val rowsPerSecond = options.getLong(ROWS_PER_SECOND, 1) + if (rowsPerSecond <= 0) { + throw new IllegalArgumentException( + s"Invalid value '$rowsPerSecond'. The option 'rowsPerSecond' must be positive") } - if (options.get(RAMP_UP_TIME).isPresent) { - val rampUpTimeSeconds = - JavaUtils.timeStringAsSec(options.get(RAMP_UP_TIME).get()) - if (rampUpTimeSeconds < 0) { - throw new IllegalArgumentException( - s"Invalid value '$rampUpTimeSeconds'. The option 'rampUpTime' must not be negative") - } + val rampUpTimeSeconds = Option(options.get(RAMP_UP_TIME).orElse(null)) + .map(JavaUtils.timeStringAsSec) + .getOrElse(0L) + if (rampUpTimeSeconds < 0) { + throw new IllegalArgumentException( + s"Invalid value '$rampUpTimeSeconds'. The option 'rampUpTime' must not be negative") } - if (options.get(NUM_PARTITIONS).isPresent) { - val numPartitions = options.get(NUM_PARTITIONS).get().toInt - if (numPartitions <= 0) { - throw new IllegalArgumentException( - s"Invalid value '$numPartitions'. The option 'numPartitions' must be positive") - } + val numPartitions = options.getInt( + NUM_PARTITIONS, SparkSession.active.sparkContext.defaultParallelism) + if (numPartitions <= 0) { + throw new IllegalArgumentException( + s"Invalid value '$numPartitions'. 
The option 'numPartitions' must be positive") } - - new RateStreamMicroBatchReadSupport(options, checkpointLocation) + new RateStreamTable(rowsPerSecond, rampUpTimeSeconds, numPartitions) } override def createContinuousReadSupport( @@ -82,6 +77,31 @@ class RateStreamProvider extends DataSourceV2 override def shortName(): String = "rate" } +class RateStreamTable( + rowsPerSecond: Long, + rampUpTimeSeconds: Long, + numPartitions: Int) + extends Table with SupportsMicroBatchRead { + + override def name(): String = { + s"RateStream(rowsPerSecond=$rowsPerSecond, rampUpTimeSeconds=$rampUpTimeSeconds, " + + s"numPartitions=$numPartitions)" + } + + override def schema(): StructType = RateStreamProvider.SCHEMA + + override def newScanBuilder(options: DataSourceOptions): ScanBuilder = new ScanBuilder { + override def build(): Scan = new Scan { + override def readSchema(): StructType = RateStreamProvider.SCHEMA + + override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = { + new RateStreamMicroBatchStream( + rowsPerSecond, rampUpTimeSeconds, numPartitions, options, checkpointLocation) + } + } + } +} + object RateStreamProvider { val SCHEMA = StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketMicroBatchStream.scala similarity index 62% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketMicroBatchStream.scala index b2a573eae504..ddf398b7752e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketMicroBatchStream.scala @@ -19,44 +19,29 @@ package org.apache.spark.sql.execution.streaming.sources import java.io.{BufferedReader, InputStreamReader, IOException} import java.net.Socket -import java.text.SimpleDateFormat -import java.util.{Calendar, Locale} +import java.util.Calendar import java.util.concurrent.atomic.AtomicBoolean import javax.annotation.concurrent.GuardedBy import scala.collection.mutable.ListBuffer -import scala.util.{Failure, Success, Try} import org.apache.spark.internal.Logging -import org.apache.spark.sql._ +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.{LongOffset, SimpleStreamingScanConfig, SimpleStreamingScanConfigBuilder} -import org.apache.spark.sql.execution.streaming.continuous.TextSocketContinuousReadSupport -import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, DataSourceV2, MicroBatchReadSupportProvider} -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, MicroBatchReadSupport, Offset} -import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} +import org.apache.spark.sql.execution.streaming.LongOffset +import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.sources.v2.reader.{InputPartition, PartitionReader, PartitionReaderFactory} +import 
org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchStream, Offset} import org.apache.spark.unsafe.types.UTF8String -object TextSocketReader { - val SCHEMA_REGULAR = StructType(StructField("value", StringType) :: Nil) - val SCHEMA_TIMESTAMP = StructType(StructField("value", StringType) :: - StructField("timestamp", TimestampType) :: Nil) - val DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) -} - /** * A MicroBatchReadSupport that reads text lines through a TCP socket, designed only for tutorials * and debugging. This MicroBatchReadSupport will *not* work in production applications due to * multiple reasons, including no support for fault recovery. */ -class TextSocketMicroBatchReadSupport(options: DataSourceOptions) - extends MicroBatchReadSupport with Logging { - - private val host: String = options.get("host").get() - private val port: Int = options.get("port").get().toInt +class TextSocketMicroBatchStream(host: String, port: Int, options: DataSourceOptions) + extends MicroBatchStream with Logging { @GuardedBy("this") private var socket: Socket = null @@ -99,7 +84,7 @@ class TextSocketMicroBatchReadSupport(options: DataSourceOptions) logWarning(s"Stream closed by $host:$port") return } - TextSocketMicroBatchReadSupport.this.synchronized { + TextSocketMicroBatchStream.this.synchronized { val newData = ( UTF8String.fromString(line), DateTimeUtils.fromMillis(Calendar.getInstance().getTimeInMillis) @@ -124,22 +109,9 @@ class TextSocketMicroBatchReadSupport(options: DataSourceOptions) LongOffset(json.toLong) } - override def fullSchema(): StructType = { - if (options.getBoolean("includeTimestamp", false)) { - TextSocketReader.SCHEMA_TIMESTAMP - } else { - TextSocketReader.SCHEMA_REGULAR - } - } - - override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder = { - new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end)) - } - - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { - val sc = config.asInstanceOf[SimpleStreamingScanConfig] - val startOrdinal = sc.start.asInstanceOf[LongOffset].offset.toInt + 1 - val endOrdinal = sc.end.get.asInstanceOf[LongOffset].offset.toInt + 1 + override def planInputPartitions(start: Offset, end: Offset): Array[InputPartition] = { + val startOrdinal = start.asInstanceOf[LongOffset].offset.toInt + 1 + val endOrdinal = end.asInstanceOf[LongOffset].offset.toInt + 1 // Internal buffer only holds the batches after lastOffsetCommitted val rawList = synchronized { @@ -164,7 +136,7 @@ class TextSocketMicroBatchReadSupport(options: DataSourceOptions) slices.map(TextSocketInputPartition) } - override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + override def createReaderFactory(): PartitionReaderFactory = { new PartitionReaderFactory { override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { val slice = partition.asInstanceOf[TextSocketInputPartition].slice @@ -220,43 +192,3 @@ class TextSocketMicroBatchReadSupport(options: DataSourceOptions) } case class TextSocketInputPartition(slice: ListBuffer[(UTF8String, Long)]) extends InputPartition - -class TextSocketSourceProvider extends DataSourceV2 - with MicroBatchReadSupportProvider with ContinuousReadSupportProvider - with DataSourceRegister with Logging { - - private def checkParameters(params: DataSourceOptions): Unit = { - logWarning("The socket source should not be used for production applications! 
" + - "It does not support recovery.") - if (!params.get("host").isPresent) { - throw new AnalysisException("Set a host to read from with option(\"host\", ...).") - } - if (!params.get("port").isPresent) { - throw new AnalysisException("Set a port to read from with option(\"port\", ...).") - } - Try { - params.get("includeTimestamp").orElse("false").toBoolean - } match { - case Success(_) => - case Failure(_) => - throw new AnalysisException("includeTimestamp must be set to either \"true\" or \"false\"") - } - } - - override def createMicroBatchReadSupport( - checkpointLocation: String, - options: DataSourceOptions): MicroBatchReadSupport = { - checkParameters(options) - new TextSocketMicroBatchReadSupport(options) - } - - override def createContinuousReadSupport( - checkpointLocation: String, - options: DataSourceOptions): ContinuousReadSupport = { - checkParameters(options) - new TextSocketContinuousReadSupport(options) - } - - /** String that represents the format that this data source provider uses. */ - override def shortName(): String = "socket" -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala new file mode 100644 index 000000000000..35007785b41a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.text.SimpleDateFormat +import java.util.Locale + +import scala.util.{Failure, Success, Try} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.streaming.continuous.TextSocketContinuousReadSupport +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.reader.{Scan, ScanBuilder} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, MicroBatchStream} +import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} + +class TextSocketSourceProvider extends DataSourceV2 + with TableProvider with ContinuousReadSupportProvider + with DataSourceRegister with Logging { + + private def checkParameters(params: DataSourceOptions): Unit = { + logWarning("The socket source should not be used for production applications! 
" + + "It does not support recovery.") + if (!params.get("host").isPresent) { + throw new AnalysisException("Set a host to read from with option(\"host\", ...).") + } + if (!params.get("port").isPresent) { + throw new AnalysisException("Set a port to read from with option(\"port\", ...).") + } + Try { + params.get("includeTimestamp").orElse("false").toBoolean + } match { + case Success(_) => + case Failure(_) => + throw new AnalysisException("includeTimestamp must be set to either \"true\" or \"false\"") + } + } + + override def getTable(options: DataSourceOptions): Table = { + checkParameters(options) + new TextSocketTable( + options.get("host").get, + options.getInt("port", -1), + options.getBoolean("includeTimestamp", false)) + } + + override def createContinuousReadSupport( + checkpointLocation: String, + options: DataSourceOptions): ContinuousReadSupport = { + checkParameters(options) + new TextSocketContinuousReadSupport(options) + } + + /** String that represents the format that this data source provider uses. */ + override def shortName(): String = "socket" +} + +class TextSocketTable(host: String, port: Int, includeTimestamp: Boolean) + extends Table with SupportsMicroBatchRead { + + override def name(): String = s"Socket[$host:$port]" + + override def schema(): StructType = { + if (includeTimestamp) { + TextSocketReader.SCHEMA_TIMESTAMP + } else { + TextSocketReader.SCHEMA_REGULAR + } + } + + override def newScanBuilder(options: DataSourceOptions): ScanBuilder = new ScanBuilder { + override def build(): Scan = new Scan { + override def readSchema(): StructType = schema() + + override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = { + new TextSocketMicroBatchStream(host, port, options) + } + } + } +} + +object TextSocketReader { + val SCHEMA_REGULAR = StructType(StructField("value", StringType) :: Nil) + val SCHEMA_TIMESTAMP = StructType(StructField("value", StringType) :: + StructField("timestamp", TimestampType) :: Nil) + val DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 98589da9552c..417dd5584b30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -29,8 +29,8 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} import org.apache.spark.sql.sources.StreamSourceProvider -import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, MicroBatchReadSupportProvider} -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReadSupport import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -173,60 +173,54 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo case _ => None } ds match { - case s: MicroBatchReadSupportProvider => + case provider: TableProvider => val sessionOptions = DataSourceV2Utils.extractSessionConfigs( - ds = s, conf = sparkSession.sessionState.conf) + ds = provider, conf = sparkSession.sessionState.conf) val 
options = sessionOptions ++ extraOptions - val dataSourceOptions = new DataSourceOptions(options.asJava) - var tempReadSupport: MicroBatchReadSupport = null - val schema = try { - val tmpCheckpointPath = Utils.createTempDir(namePrefix = s"tempCP").getCanonicalPath - tempReadSupport = if (userSpecifiedSchema.isDefined) { - s.createMicroBatchReadSupport( - userSpecifiedSchema.get, tmpCheckpointPath, dataSourceOptions) - } else { - s.createMicroBatchReadSupport(tmpCheckpointPath, dataSourceOptions) - } - tempReadSupport.fullSchema() - } finally { - // Stop tempReader to avoid side-effect thing - if (tempReadSupport != null) { - tempReadSupport.stop() - tempReadSupport = null - } + val dsOptions = new DataSourceOptions(options.asJava) + val table = userSpecifiedSchema match { + case Some(schema) => provider.getTable(dsOptions, schema) + case _ => provider.getTable(dsOptions) } - Dataset.ofRows( - sparkSession, - StreamingRelationV2( - s, source, options, - schema.toAttributes, v1Relation)(sparkSession)) - case s: ContinuousReadSupportProvider => - val sessionOptions = DataSourceV2Utils.extractSessionConfigs( - ds = s, conf = sparkSession.sessionState.conf) - val options = sessionOptions ++ extraOptions - val dataSourceOptions = new DataSourceOptions(options.asJava) - var tempReadSupport: ContinuousReadSupport = null - val schema = try { - val tmpCheckpointPath = Utils.createTempDir(namePrefix = s"tempCP").getCanonicalPath - tempReadSupport = if (userSpecifiedSchema.isDefined) { - s.createContinuousReadSupport( - userSpecifiedSchema.get, tmpCheckpointPath, dataSourceOptions) - } else { - s.createContinuousReadSupport(tmpCheckpointPath, dataSourceOptions) - } - tempReadSupport.fullSchema() - } finally { - // Stop tempReader to avoid side-effect thing - if (tempReadSupport != null) { - tempReadSupport.stop() - tempReadSupport = null - } + table match { + case s: SupportsMicroBatchRead => + Dataset.ofRows( + sparkSession, + StreamingRelationV2( + provider, source, s, options, + table.schema.toAttributes, v1Relation)(sparkSession)) + + case _ if ds.isInstanceOf[ContinuousReadSupportProvider] => + val provider = ds.asInstanceOf[ContinuousReadSupportProvider] + var tempReadSupport: ContinuousReadSupport = null + val schema = try { + val tmpCheckpointPath = Utils.createTempDir(namePrefix = s"tempCP").getCanonicalPath + tempReadSupport = if (userSpecifiedSchema.isDefined) { + provider.createContinuousReadSupport( + userSpecifiedSchema.get, tmpCheckpointPath, dsOptions) + } else { + provider.createContinuousReadSupport(tmpCheckpointPath, dsOptions) + } + tempReadSupport.fullSchema() + } finally { + // Stop tempReader to avoid side-effect thing + if (tempReadSupport != null) { + tempReadSupport.stop() + tempReadSupport = null + } + } + Dataset.ofRows( + sparkSession, + // TODO: do not pass null as table after finish the API refactor for continuous + // stream. + StreamingRelationV2( + provider, source, table = null, options, + schema.toAttributes, v1Relation)(sparkSession)) + + // fallback to v1 + case _ => Dataset.ofRows(sparkSession, StreamingRelation(v1DataSource)) } - Dataset.ofRows( - sparkSession, - StreamingRelationV2( - s, source, options, - schema.toAttributes, v1Relation)(sparkSession)) + case _ => // Code path for data source v1. 
Dataset.ofRows(sparkSession, StreamingRelation(v1DataSource)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index be3efed71403..d40a1fdec0bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -25,15 +25,16 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, MicroBatchReadSupportProvider} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions} import org.apache.spark.sql.sources.v2.reader.streaming.Offset import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.util.ManualClock -class RateSourceSuite extends StreamTest { +class RateStreamProviderSuite extends StreamTest { import testImplicits._ @@ -41,7 +42,9 @@ class RateSourceSuite extends StreamTest { override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { assert(query.nonEmpty) val rateSource = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: RateStreamMicroBatchReadSupport, _) => source + case r: StreamingDataSourceV2Relation + if r.stream.isInstanceOf[RateStreamMicroBatchStream] => + r.stream.asInstanceOf[RateStreamMicroBatchStream] }.head rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) @@ -51,28 +54,16 @@ class RateSourceSuite extends StreamTest { } } - test("microbatch in registry") { - withTempDir { temp => - DataSource.lookupDataSource("rate", spark.sqlContext.conf). 
- getConstructor().newInstance() match { - case ds: MicroBatchReadSupportProvider => - val readSupport = ds.createMicroBatchReadSupport( - temp.getCanonicalPath, DataSourceOptions.empty()) - assert(readSupport.isInstanceOf[RateStreamMicroBatchReadSupport]) - case _ => - throw new IllegalStateException("Could not find read support for rate") - } - } + test("RateStreamProvider in registry") { + val ds = DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() + assert(ds.isInstanceOf[RateStreamProvider], "Could not find rate source") } test("compatible with old path in registry") { - DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.RateSourceProvider", - spark.sqlContext.conf).getConstructor().newInstance() match { - case ds: MicroBatchReadSupportProvider => - assert(ds.isInstanceOf[RateStreamProvider]) - case _ => - throw new IllegalStateException("Could not find read support for rate") - } + val ds = DataSource.lookupDataSource( + "org.apache.spark.sql.execution.streaming.RateSourceProvider", + spark.sqlContext.conf).newInstance() + assert(ds.isInstanceOf[RateStreamProvider], "Could not find rate source") } test("microbatch - basic") { @@ -142,17 +133,17 @@ class RateSourceSuite extends StreamTest { test("microbatch - infer offsets") { withTempDir { temp => - val readSupport = new RateStreamMicroBatchReadSupport( - new DataSourceOptions( - Map("numPartitions" -> "1", "rowsPerSecond" -> "100", "useManualClock" -> "true").asJava), - temp.getCanonicalPath) - readSupport.clock.asInstanceOf[ManualClock].advance(100000) - val startOffset = readSupport.initialOffset() + val stream = new RateStreamMicroBatchStream( + rowsPerSecond = 100, + options = new DataSourceOptions(Map("useManualClock" -> "true").asJava), + checkpointLocation = temp.getCanonicalPath) + stream.clock.asInstanceOf[ManualClock].advance(100000) + val startOffset = stream.initialOffset() startOffset match { case r: LongOffset => assert(r.offset === 0L) case _ => throw new IllegalStateException("unexpected offset type") } - readSupport.latestOffset() match { + stream.latestOffset() match { case r: LongOffset => assert(r.offset >= 100) case _ => throw new IllegalStateException("unexpected offset type") } @@ -161,16 +152,14 @@ class RateSourceSuite extends StreamTest { test("microbatch - predetermined batch size") { withTempDir { temp => - val readSupport = new RateStreamMicroBatchReadSupport( - new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava), - temp.getCanonicalPath) - val startOffset = LongOffset(0L) - val endOffset = LongOffset(1L) - val config = readSupport.newScanConfigBuilder(startOffset, endOffset).build() - val tasks = readSupport.planInputPartitions(config) - val readerFactory = readSupport.createReaderFactory(config) - assert(tasks.size == 1) - val dataReader = readerFactory.createReader(tasks(0)) + val stream = new RateStreamMicroBatchStream( + rowsPerSecond = 20, + options = DataSourceOptions.empty(), + checkpointLocation = temp.getCanonicalPath) + val partitions = stream.planInputPartitions(LongOffset(0L), LongOffset(1L)) + val readerFactory = stream.createReaderFactory() + assert(partitions.size == 1) + val dataReader = readerFactory.createReader(partitions(0)) val data = ArrayBuffer[InternalRow]() while (dataReader.next()) { data.append(dataReader.get()) @@ -181,17 +170,16 @@ class RateSourceSuite extends StreamTest { test("microbatch - data read") { withTempDir { temp => - val readSupport = new RateStreamMicroBatchReadSupport( - new 
DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava), - temp.getCanonicalPath) - val startOffset = LongOffset(0L) - val endOffset = LongOffset(1L) - val config = readSupport.newScanConfigBuilder(startOffset, endOffset).build() - val tasks = readSupport.planInputPartitions(config) - val readerFactory = readSupport.createReaderFactory(config) - assert(tasks.size == 11) - - val readData = tasks + val stream = new RateStreamMicroBatchStream( + rowsPerSecond = 33, + numPartitions = 11, + options = DataSourceOptions.empty(), + checkpointLocation = temp.getCanonicalPath) + val partitions = stream.planInputPartitions(LongOffset(0L), LongOffset(1L)) + val readerFactory = stream.createReaderFactory() + assert(partitions.size == 11) + + val readData = partitions .map(readerFactory.createReader) .flatMap { reader => val buf = scala.collection.mutable.ListBuffer[InternalRow]() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala index 7db31f1f8f69..cf069d571081 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala @@ -30,10 +30,11 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupportProvider} +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader.streaming.Offset import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} import org.apache.spark.sql.test.SharedSQLContext @@ -59,7 +60,9 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before "Cannot add data when there is no query for finding the active socket source") val sources = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: TextSocketMicroBatchReadSupport, _) => source + case r: StreamingDataSourceV2Relation + if r.stream.isInstanceOf[TextSocketMicroBatchStream] => + r.stream.asInstanceOf[TextSocketMicroBatchStream] } if (sources.isEmpty) { throw new Exception( @@ -83,13 +86,10 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before } test("backward compatibility with old path") { - DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.TextSocketSourceProvider", - spark.sqlContext.conf).getConstructor().newInstance() match { - case ds: MicroBatchReadSupportProvider => - assert(ds.isInstanceOf[TextSocketSourceProvider]) - case _ => - throw new IllegalStateException("Could not find socket source") - } + val ds = DataSource.lookupDataSource( + "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider", + spark.sqlContext.conf).newInstance() + assert(ds.isInstanceOf[TextSocketSourceProvider], "Could not find socket source") } test("basic usage") { @@ -175,16 +175,13 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before test("params not given") { val 
provider = new TextSocketSourceProvider intercept[AnalysisException] { - provider.createMicroBatchReadSupport( - "", new DataSourceOptions(Map.empty[String, String].asJava)) + provider.getTable(new DataSourceOptions(Map.empty[String, String].asJava)) } intercept[AnalysisException] { - provider.createMicroBatchReadSupport( - "", new DataSourceOptions(Map("host" -> "localhost").asJava)) + provider.getTable(new DataSourceOptions(Map("host" -> "localhost").asJava)) } intercept[AnalysisException] { - provider.createMicroBatchReadSupport( - "", new DataSourceOptions(Map("port" -> "1234").asJava)) + provider.getTable(new DataSourceOptions(Map("port" -> "1234").asJava)) } } @@ -192,8 +189,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before val provider = new TextSocketSourceProvider val params = Map("host" -> "localhost", "port" -> "1234", "includeTimestamp" -> "fasle") intercept[AnalysisException] { - val a = new DataSourceOptions(params.asJava) - provider.createMicroBatchReadSupport("", a) + provider.getTable(new DataSourceOptions(params.asJava)) } } @@ -204,8 +200,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before StructField("area", StringType) :: Nil) val params = Map("host" -> "localhost", "port" -> "1234") val exception = intercept[UnsupportedOperationException] { - provider.createMicroBatchReadSupport( - userSpecifiedSchema, "", new DataSourceOptions(params.asJava)) + provider.getTable(new DataSourceOptions(params.asJava), userSpecifiedSchema) } assert(exception.getMessage.contains( "socket source does not support user-specified schema")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 55fdcee83f11..72321c418f9b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -36,7 +36,6 @@ import org.apache.spark.sql.catalyst.plans.logical.Range import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProvider} import org.apache.spark.sql.functions._ @@ -96,18 +95,16 @@ class StreamSuite extends StreamTest { val streamingRelation = spark.readStream.format("rate").load().logicalPlan collect { case s: StreamingRelationV2 => s } - assert(streamingRelation.nonEmpty, "cannot find StreamingExecutionRelation") + assert(streamingRelation.nonEmpty, "cannot find StreamingRelationV2") assert( streamingRelation.head.computeStats.sizeInBytes == spark.sessionState.conf.defaultSizeInBytes) } test("StreamingExecutionRelation.computeStats") { - val streamingExecutionRelation = MemoryStream[Int].toDF.logicalPlan collect { - case s: StreamingExecutionRelation => s - } - assert(streamingExecutionRelation.nonEmpty, "cannot find StreamingExecutionRelation") - assert(streamingExecutionRelation.head.computeStats.sizeInBytes - == spark.sessionState.conf.defaultSizeInBytes) + val memoryStream = MemoryStream[Int] + val executionRelation = StreamingExecutionRelation( + memoryStream, 
memoryStream.encoder.schema.toAttributes)(memoryStream.sqlContext.sparkSession) + assert(executionRelation.computeStats.sizeInBytes == spark.sessionState.conf.defaultSizeInBytes) } test("explain join with a normal source") { @@ -495,9 +492,9 @@ class StreamSuite extends StreamTest { val explainWithoutExtended = q.explainInternal(false) // `extended = false` only displays the physical plan. - assert("Streaming RelationV2 MemoryStreamDataSource".r + assert("StreamingDataSourceV2Relation".r .findAllMatchIn(explainWithoutExtended).size === 0) - assert("ScanV2 MemoryStreamDataSource".r + assert("ScanV2".r .findAllMatchIn(explainWithoutExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithoutExtended.contains("StateStoreRestore")) @@ -505,9 +502,9 @@ class StreamSuite extends StreamTest { val explainWithExtended = q.explainInternal(true) // `extended = true` displays 3 logical plans (Parsed/Optimized/Optimized) and 1 physical // plan. - assert("Streaming RelationV2 MemoryStreamDataSource".r + assert("StreamingDataSourceV2Relation".r .findAllMatchIn(explainWithExtended).size === 3) - assert("ScanV2 MemoryStreamDataSource".r + assert("ScanV2".r .findAllMatchIn(explainWithExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithExtended.contains("StateStoreRestore")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index d878c345c298..b4bd6f6b2edc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -39,12 +39,12 @@ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, Ro import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical.AllTuples import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation +import org.apache.spark.sql.execution.datasources.v2.{OldStreamingDataSourceV2Relation, StreamingDataSourceV2Relation} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch} import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReadSupport import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.{Clock, SystemClock, Utils} @@ -688,8 +688,20 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def findSourceIndex(plan: LogicalPlan): Option[Int] = { plan .collect { + // v1 source case r: StreamingExecutionRelation => r.source - case r: StreamingDataSourceV2Relation => r.readSupport + // v2 source + case r: StreamingDataSourceV2Relation => r.stream + case r: OldStreamingDataSourceV2Relation => r.readSupport + // We can add data to memory stream before starting it. Then the input plan has + // not been processed by the streaming engine and contains `StreamingRelationV2`. 
+ case r: StreamingRelationV2 if r.sourceName == "memory" => + // TODO: remove this null hack after finish API refactor for continuous stream. + if (r.table == null) { + r.dataSource.asInstanceOf[ContinuousReadSupport] + } else { + r.table.asInstanceOf[MemoryStreamTable].stream + } } .zipWithIndex .find(_._1 == source) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala index 46eec736d402..13b8866c22b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala @@ -24,15 +24,14 @@ import scala.util.Random import scala.util.control.NonFatal import org.scalatest.BeforeAndAfter -import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.Span import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkException -import org.apache.spark.sql.{AnalysisException, Dataset} +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.BlockingSource import org.apache.spark.util.Utils @@ -304,8 +303,8 @@ class StreamingQueryManagerSuite extends StreamTest with BeforeAndAfter { if (withError) { logDebug(s"Terminating query ${queryToStop.name} with error") queryToStop.asInstanceOf[StreamingQueryWrapper].streamingQuery.logicalPlan.collect { - case StreamingExecutionRelation(source, _) => - source.asInstanceOf[MemoryStream[Int]].addData(0) + case r: StreamingDataSourceV2Relation => + r.stream.asInstanceOf[MemoryStream[Int]].addData(0) } } else { logDebug(s"Stopping query ${queryToStop.name}") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 29b816486a1f..62fde98e40dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -220,10 +220,10 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } // getBatch should take 100 ms the first time it is called - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + override def planInputPartitions(start: OffsetV2, end: OffsetV2): Array[InputPartition] = { synchronized { clock.waitTillTime(1150) - super.planInputPartitions(config) + super.planInputPartitions(start, end) } } } @@ -906,7 +906,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(df.logicalPlan.toJSON.contains("StreamingRelationV2")) testStream(df)( - AssertOnQuery(_.logicalPlan.toJSON.contains("StreamingExecutionRelation")) + AssertOnQuery(_.logicalPlan.toJSON.contains("StreamingDataSourceV2Relation")) ) testStream(df, useV2Sink = true)( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 756092fc7ff5..f85cae9fa433 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.streaming.continuous import org.apache.spark.{SparkContext, SparkException} import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} import org.apache.spark.sql._ -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2StreamingScanExec +import org.apache.spark.sql.execution.datasources.v2.ContinuousScanExec import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream @@ -41,7 +41,7 @@ class ContinuousSuiteBase extends StreamTest { case s: ContinuousExecution => assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized") val reader = s.lastExecution.executedPlan.collectFirst { - case DataSourceV2StreamingScanExec(_, _, _, _, r: RateStreamContinuousReadSupport, _) => r + case ContinuousScanExec(_, _, _, _, r: RateStreamContinuousReadSupport, _) => r }.get val deltaMs = numTriggers * 1000 + 300 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index 31fce46c2dab..d98cc41de9b0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -24,26 +24,35 @@ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader.{InputPartition, PartitionReaderFactory, ScanConfig, ScanConfigBuilder} +import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming._ import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery, StreamTest, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils -case class FakeReadSupport() extends MicroBatchReadSupport with ContinuousReadSupport { +class FakeDataStream extends MicroBatchStream { override def deserializeOffset(json: String): Offset = RateStreamOffset(Map()) override def commit(end: Offset): Unit = {} override def stop(): Unit = {} - override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = RateStreamOffset(Map()) - override def fullSchema(): StructType = StructType(Seq()) - override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder = null override def initialOffset(): Offset = RateStreamOffset(Map()) override def latestOffset(): Offset = RateStreamOffset(Map()) - override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = null - override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + override def planInputPartitions(start: Offset, end: Offset): Array[InputPartition] = { + throw new IllegalStateException("fake source - cannot actually read") + } + override def createReaderFactory(): PartitionReaderFactory = { throw new IllegalStateException("fake source - cannot actually read") } +} + +case class FakeReadSupport() extends ContinuousReadSupport { + override def deserializeOffset(json: String): 
Offset = RateStreamOffset(Map()) + override def commit(end: Offset): Unit = {} + override def stop(): Unit = {} + override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = RateStreamOffset(Map()) + override def fullSchema(): StructType = StructType(Seq()) + override def initialOffset(): Offset = RateStreamOffset(Map()) + override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = null override def createContinuousReaderFactory( config: ScanConfig): ContinuousPartitionReaderFactory = { throw new IllegalStateException("fake source - cannot actually read") @@ -53,13 +62,16 @@ case class FakeReadSupport() extends MicroBatchReadSupport with ContinuousReadSu } } -trait FakeMicroBatchReadSupportProvider extends MicroBatchReadSupportProvider { - override def createMicroBatchReadSupport( - checkpointLocation: String, - options: DataSourceOptions): MicroBatchReadSupport = { - LastReadOptions.options = options - FakeReadSupport() - } +class FakeScanBuilder extends ScanBuilder with Scan { + override def build(): Scan = this + override def readSchema(): StructType = StructType(Seq()) + override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = new FakeDataStream +} + +class FakeMicroBatchReadTable extends Table with SupportsMicroBatchRead { + override def name(): String = "fake" + override def schema(): StructType = StructType(Seq()) + override def newScanBuilder(options: DataSourceOptions): ScanBuilder = new FakeScanBuilder } trait FakeContinuousReadSupportProvider extends ContinuousReadSupportProvider { @@ -84,25 +96,38 @@ trait FakeStreamingWriteSupportProvider extends StreamingWriteSupportProvider { class FakeReadMicroBatchOnly extends DataSourceRegister - with FakeMicroBatchReadSupportProvider + with TableProvider with SessionConfigSupport { override def shortName(): String = "fake-read-microbatch-only" override def keyPrefix: String = shortName() + + override def getTable(options: DataSourceOptions): Table = { + LastReadOptions.options = options + new FakeMicroBatchReadTable {} + } } class FakeReadContinuousOnly extends DataSourceRegister + with TableProvider with FakeContinuousReadSupportProvider with SessionConfigSupport { override def shortName(): String = "fake-read-continuous-only" override def keyPrefix: String = shortName() + + override def getTable(options: DataSourceOptions): Table = new Table { + override def schema(): StructType = StructType(Seq()) + override def name(): String = "fake" + } } class FakeReadBothModes extends DataSourceRegister - with FakeMicroBatchReadSupportProvider with FakeContinuousReadSupportProvider { + with TableProvider with FakeContinuousReadSupportProvider { override def shortName(): String = "fake-read-microbatch-continuous" + + override def getTable(options: DataSourceOptions): Table = new FakeMicroBatchReadTable {} } class FakeReadNeitherMode extends DataSourceRegister { @@ -303,10 +328,18 @@ class StreamingDataSourceV2Suite extends StreamTest { getConstructor().newInstance() val writeSource = DataSource.lookupDataSource(write, spark.sqlContext.conf). getConstructor().newInstance() + + def isMicroBatch(ds: Any): Boolean = ds match { + case provider: TableProvider => + val table = provider.getTable(DataSourceOptions.empty()) + table.isInstanceOf[SupportsMicroBatchRead] + case _ => false + } + (readSource, writeSource, trigger) match { // Valid microbatch queries. 
- case (_: MicroBatchReadSupportProvider, _: StreamingWriteSupportProvider, t) - if !t.isInstanceOf[ContinuousTrigger] => + case (_: TableProvider, _: StreamingWriteSupportProvider, t) + if isMicroBatch(readSource) && !t.isInstanceOf[ContinuousTrigger] => testPositiveCase(read, write, trigger) // Valid continuous queries. @@ -316,7 +349,7 @@ class StreamingDataSourceV2Suite extends StreamTest { // Invalid - can't read at all case (r, _, _) - if !r.isInstanceOf[MicroBatchReadSupportProvider] + if !r.isInstanceOf[TableProvider] && !r.isInstanceOf[ContinuousReadSupportProvider] => testNegativeCase(read, write, trigger, s"Data source $read does not support streamed reading") @@ -334,7 +367,7 @@ class StreamingDataSourceV2Suite extends StreamTest { // Invalid - trigger is microbatch but reader is not case (r, _, t) - if !r.isInstanceOf[MicroBatchReadSupportProvider] && + if !isMicroBatch(r) && !t.isInstanceOf[ContinuousTrigger] => testPostCreationNegativeCase(read, write, trigger, s"Data source $read does not support microbatch processing")
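For reference, the micro-batch read path after this refactor chains TableProvider -> Table (with SupportsMicroBatchRead) -> ScanBuilder -> Scan -> MicroBatchStream, and the engine then drives the stream with explicit start/end offsets instead of a ScanConfig. The sketch below walks that chain for the rate source using only the interfaces that appear in this diff. It is illustrative and not part of the patch: the object name, app name, and temporary checkpoint directory are placeholders, and it assumes a build of Spark that contains these internal, still-evolving v2 APIs.

// Illustrative sketch of the new micro-batch read path (not part of this patch).
// Assumes the org.apache.spark.sql.sources.v2 interfaces exactly as shown in the diff above.
import java.nio.file.Files

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.streaming.LongOffset
import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider
import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsMicroBatchRead}

object MicroBatchStreamSketch {
  def main(args: Array[String]): Unit = {
    // RateStreamProvider.getTable falls back to SparkSession.active for the default
    // numPartitions, so an active session must exist before the table is resolved.
    val spark = SparkSession.builder().master("local[2]").appName("rate-sketch").getOrCreate()

    // 1. TableProvider turns the user options into a Table describing the source.
    val provider = new RateStreamProvider
    val table = provider.getTable(DataSourceOptions.empty())
    assert(table.isInstanceOf[SupportsMicroBatchRead])

    // 2. The Table hands out a ScanBuilder; the resulting Scan becomes a MicroBatchStream
    //    once the checkpoint location is known.
    val checkpointDir = Files.createTempDirectory("rate-sketch-cp").toString
    val stream = table.asInstanceOf[SupportsMicroBatchRead]
      .newScanBuilder(DataSourceOptions.empty())
      .build()
      .toMicroBatchStream(checkpointDir)

    // 3. Planning now takes explicit start/end offsets instead of a ScanConfig,
    //    mirroring the updated RateStreamProviderSuite tests.
    val partitions = stream.planInputPartitions(LongOffset(0L), LongOffset(1L))
    val readerFactory = stream.createReaderFactory()
    println(s"planned ${partitions.length} partition(s) with $readerFactory")

    stream.stop()
    spark.stop()
  }
}

The same shape applies to the socket and memory sources above: the provider only resolves a Table, the Table only exposes a schema and a ScanBuilder, and everything offset-related now lives behind MicroBatchStream.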