diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReadSupport.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousInputStream.scala
similarity index 83%
rename from external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReadSupport.scala
rename to external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousInputStream.scala
index 1753a28fba2f..bd301503d1db 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReadSupport.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousInputStream.scala
@@ -30,10 +30,9 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming._
-import org.apache.spark.sql.types.StructType
/**
- * A [[ContinuousReadSupport]] for data from kafka.
+ * A [[ContinuousInputStream]] that reads data from Kafka.
*
* @param offsetReader a reader used to get kafka offsets. Note that the actual data will be
* read by per-task consumers generated later.
@@ -46,17 +45,22 @@ import org.apache.spark.sql.types.StructType
* scenarios, where some offsets after the specified initial ones can't be
* properly read.
*/
-class KafkaContinuousReadSupport(
+class KafkaContinuousInputStream(
offsetReader: KafkaOffsetReader,
kafkaParams: ju.Map[String, Object],
sourceOptions: Map[String, String],
metadataPath: String,
initialOffsets: KafkaOffsetRangeLimit,
failOnDataLoss: Boolean)
- extends ContinuousReadSupport with Logging {
+ extends ContinuousInputStream with Logging {
private val pollTimeoutMs = sourceOptions.getOrElse("kafkaConsumer.pollTimeoutMs", "512").toLong
+  // Initialized when creating a continuous scan. If this diverges from the partitions at the
+  // latest offsets, we need to reconfigure.
+  // Exposed outside this class only for unit tests.
+ @volatile private[sql] var knownPartitions: Set[TopicPartition] = _
+
override def initialOffset(): Offset = {
val offsets = initialOffsets match {
case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets())
@@ -67,28 +71,29 @@ class KafkaContinuousReadSupport(
offsets
}
- override def fullSchema(): StructType = KafkaOffsetReader.kafkaSchema
-
- override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = {
- new KafkaContinuousScanConfigBuilder(fullSchema(), start, offsetReader, reportDataLoss)
- }
-
override def deserializeOffset(json: String): Offset = {
KafkaSourceOffset(JsonUtils.partitionOffsets(json))
}
- override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
- val startOffsets = config.asInstanceOf[KafkaContinuousScanConfig].startOffsets
- startOffsets.toSeq.map {
- case (topicPartition, start) =>
- KafkaContinuousInputPartition(
- topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss)
- }.toArray
- }
+ override def createContinuousScan(start: Offset): ContinuousScan = {
+ val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(start)
- override def createContinuousReaderFactory(
- config: ScanConfig): ContinuousPartitionReaderFactory = {
- KafkaContinuousReaderFactory
+ val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet
+ val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet)
+ val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq)
+
+ val deletedPartitions = oldStartPartitionOffsets.keySet.diff(currentPartitionSet)
+ if (deletedPartitions.nonEmpty) {
+ reportDataLoss(s"Some partitions were deleted: $deletedPartitions")
+ }
+
+ val startOffsets = newPartitionOffsets ++
+ oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_))
+
+ knownPartitions = startOffsets.keySet
+
+ new KafkaContinuousScan(
+ offsetReader, kafkaParams, pollTimeoutMs, failOnDataLoss, startOffsets)
}
/** Stop this source and free any resources it has allocated. */
@@ -105,9 +110,8 @@ class KafkaContinuousReadSupport(
KafkaSourceOffset(mergedMap)
}
- override def needsReconfiguration(config: ScanConfig): Boolean = {
- val knownPartitions = config.asInstanceOf[KafkaContinuousScanConfig].knownPartitions
- offsetReader.fetchLatestOffsets().keySet != knownPartitions
+ override def needsReconfiguration(): Boolean = {
+ knownPartitions != null && offsetReader.fetchLatestOffsets().keySet != knownPartitions
}
override def toString(): String = s"KafkaSource[$offsetReader]"
@@ -125,6 +129,25 @@ class KafkaContinuousReadSupport(
}
}
+class KafkaContinuousScan(
+ offsetReader: KafkaOffsetReader,
+ kafkaParams: ju.Map[String, Object],
+ pollTimeoutMs: Long,
+ failOnDataLoss: Boolean,
+ startOffsets: Map[TopicPartition, Long]) extends ContinuousScan {
+
+ override def createContinuousReaderFactory(): ContinuousPartitionReaderFactory = {
+ KafkaContinuousReaderFactory
+ }
+
+ override def planInputPartitions(): Array[InputPartition] = {
+ startOffsets.toSeq.map { case (topicPartition, start) =>
+ KafkaContinuousInputPartition(
+ topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss)
+ }.toArray
+ }
+}
+
/**
* An input partition for continuous Kafka processing. This will be serialized and transformed
* into a full reader on executors.
@@ -151,41 +174,6 @@ object KafkaContinuousReaderFactory extends ContinuousPartitionReaderFactory {
}
}
-class KafkaContinuousScanConfigBuilder(
- schema: StructType,
- startOffset: Offset,
- offsetReader: KafkaOffsetReader,
- reportDataLoss: String => Unit)
- extends ScanConfigBuilder {
-
- override def build(): ScanConfig = {
- val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(startOffset)
-
- val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet
- val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet)
- val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq)
-
- val deletedPartitions = oldStartPartitionOffsets.keySet.diff(currentPartitionSet)
- if (deletedPartitions.nonEmpty) {
- reportDataLoss(s"Some partitions were deleted: $deletedPartitions")
- }
-
- val startOffsets = newPartitionOffsets ++
- oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_))
- KafkaContinuousScanConfig(schema, startOffsets)
- }
-}
-
-case class KafkaContinuousScanConfig(
- readSchema: StructType,
- startOffsets: Map[TopicPartition, Long])
- extends ScanConfig {
-
- // Created when building the scan config builder. If this diverges from the partitions at the
- // latest offsets, we need to reconfigure the kafka read support.
- def knownPartitions: Set[TopicPartition] = startOffsets.keySet
-}
-
/**
* A per-task data reader for continuous Kafka processing.
*
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchInputStream.scala
similarity index 92%
rename from external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala
rename to external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchInputStream.scala
index bb4de674c3c7..afacd81043fa 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchInputStream.scala
@@ -29,17 +29,16 @@ import org.apache.spark.scheduler.ExecutorCacheTaskLocation
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset, SimpleStreamingScanConfig, SimpleStreamingScanConfigBuilder}
-import org.apache.spark.sql.execution.streaming.sources.RateControlMicroBatchReadSupport
+import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset}
+import org.apache.spark.sql.execution.streaming.sources.RateControlMicroBatchInputStream
import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.sources.v2.reader._
-import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchInputStream, MicroBatchScan, Offset}
import org.apache.spark.util.UninterruptibleThread
/**
- * A [[MicroBatchReadSupport]] that reads data from Kafka.
+ * A [[MicroBatchInputStream]] that reads data from Kafka.
*
* The [[KafkaSourceOffset]] is the custom [[Offset]] defined for this source that contains
* a map of TopicPartition -> offset. Note that this offset is 1 + (available offset). For
@@ -54,13 +53,13 @@ import org.apache.spark.util.UninterruptibleThread
* To avoid this issue, you should make sure stopping the query before stopping the Kafka brokers
* and not use wrong broker addresses.
*/
-private[kafka010] class KafkaMicroBatchReadSupport(
+private[kafka010] class KafkaMicroBatchInputStream(
kafkaOffsetReader: KafkaOffsetReader,
executorKafkaParams: ju.Map[String, Object],
options: DataSourceOptions,
metadataPath: String,
startingOffsets: KafkaOffsetRangeLimit,
- failOnDataLoss: Boolean) extends RateControlMicroBatchReadSupport with Logging {
+ failOnDataLoss: Boolean) extends RateControlMicroBatchInputStream with Logging {
private val pollTimeoutMs = options.getLong(
"kafkaConsumer.pollTimeoutMs",
@@ -93,65 +92,16 @@ private[kafka010] class KafkaMicroBatchReadSupport(
endPartitionOffsets
}
- override def fullSchema(): StructType = KafkaOffsetReader.kafkaSchema
-
- override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder = {
- new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end))
- }
-
- override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
- val sc = config.asInstanceOf[SimpleStreamingScanConfig]
- val startPartitionOffsets = sc.start.asInstanceOf[KafkaSourceOffset].partitionToOffsets
- val endPartitionOffsets = sc.end.get.asInstanceOf[KafkaSourceOffset].partitionToOffsets
-
- // Find the new partitions, and get their earliest offsets
- val newPartitions = endPartitionOffsets.keySet.diff(startPartitionOffsets.keySet)
- val newPartitionInitialOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq)
- if (newPartitionInitialOffsets.keySet != newPartitions) {
- // We cannot get from offsets for some partitions. It means they got deleted.
- val deletedPartitions = newPartitions.diff(newPartitionInitialOffsets.keySet)
- reportDataLoss(
- s"Cannot find earliest offsets of ${deletedPartitions}. Some data may have been missed")
- }
- logInfo(s"Partitions added: $newPartitionInitialOffsets")
- newPartitionInitialOffsets.filter(_._2 != 0).foreach { case (p, o) =>
- reportDataLoss(
- s"Added partition $p starts from $o instead of 0. Some data may have been missed")
- }
-
- // Find deleted partitions, and report data loss if required
- val deletedPartitions = startPartitionOffsets.keySet.diff(endPartitionOffsets.keySet)
- if (deletedPartitions.nonEmpty) {
- reportDataLoss(s"$deletedPartitions are gone. Some data may have been missed")
- }
-
- // Use the end partitions to calculate offset ranges to ignore partitions that have
- // been deleted
- val topicPartitions = endPartitionOffsets.keySet.filter { tp =>
- // Ignore partitions that we don't know the from offsets.
- newPartitionInitialOffsets.contains(tp) || startPartitionOffsets.contains(tp)
- }.toSeq
- logDebug("TopicPartitions: " + topicPartitions.mkString(", "))
-
- // Calculate offset ranges
- val offsetRanges = rangeCalculator.getRanges(
- fromOffsets = startPartitionOffsets ++ newPartitionInitialOffsets,
- untilOffsets = endPartitionOffsets,
- executorLocations = getSortedExecutorList())
-
- // Reuse Kafka consumers only when all the offset ranges have distinct TopicPartitions,
- // that is, concurrent tasks will not read the same TopicPartitions.
- val reuseKafkaConsumer = offsetRanges.map(_.topicPartition).toSet.size == offsetRanges.size
-
- // Generate factories based on the offset ranges
- offsetRanges.map { range =>
- KafkaMicroBatchInputPartition(
- range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer)
- }.toArray
- }
-
- override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = {
- KafkaMicroBatchReaderFactory
+ override def createMicroBatchScan(start: Offset, end: Offset): MicroBatchScan = {
+ new KafkaMicroBatchScan(
+ kafkaOffsetReader,
+ rangeCalculator,
+ executorKafkaParams,
+ pollTimeoutMs,
+ failOnDataLoss,
+ reportDataLoss,
+ start.asInstanceOf[KafkaSourceOffset],
+ end.asInstanceOf[KafkaSourceOffset])
}
override def deserializeOffset(json: String): Offset = {
@@ -229,23 +179,6 @@ private[kafka010] class KafkaMicroBatchReadSupport(
}
}
- private def getSortedExecutorList(): Array[String] = {
-
- def compare(a: ExecutorCacheTaskLocation, b: ExecutorCacheTaskLocation): Boolean = {
- if (a.host == b.host) {
- a.executorId > b.executorId
- } else {
- a.host > b.host
- }
- }
-
- val bm = SparkEnv.get.blockManager
- bm.master.getPeers(bm.blockManagerId).toArray
- .map(x => ExecutorCacheTaskLocation(x.host, x.executorId))
- .sortWith(compare)
- .map(_.toString)
- }
-
/**
* If `failOnDataLoss` is true, this method will throw an `IllegalStateException`.
* Otherwise, just log a warning.
@@ -294,6 +227,88 @@ private[kafka010] class KafkaMicroBatchReadSupport(
}
}
+private[kafka010] class KafkaMicroBatchScan(
+ kafkaOffsetReader: KafkaOffsetReader,
+ rangeCalculator: KafkaOffsetRangeCalculator,
+ executorKafkaParams: ju.Map[String, Object],
+ pollTimeoutMs: Long,
+ failOnDataLoss: Boolean,
+ reportDataLoss: String => Unit,
+ start: KafkaSourceOffset,
+ end: KafkaSourceOffset) extends MicroBatchScan with Logging {
+
+ override def createReaderFactory(): PartitionReaderFactory = {
+ KafkaMicroBatchReaderFactory
+ }
+
+ override def planInputPartitions(): Array[InputPartition] = {
+ val startPartitionOffsets = start.partitionToOffsets
+ val endPartitionOffsets = end.partitionToOffsets
+
+ // Find the new partitions, and get their earliest offsets
+ val newPartitions = endPartitionOffsets.keySet.diff(startPartitionOffsets.keySet)
+ val newPartitionInitialOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq)
+ if (newPartitionInitialOffsets.keySet != newPartitions) {
+      // We cannot get the start offsets for some partitions, which means they were deleted.
+ val deletedPartitions = newPartitions.diff(newPartitionInitialOffsets.keySet)
+ reportDataLoss(
+ s"Cannot find earliest offsets of ${deletedPartitions}. Some data may have been missed")
+ }
+ logInfo(s"Partitions added: $newPartitionInitialOffsets")
+ newPartitionInitialOffsets.filter(_._2 != 0).foreach { case (p, o) =>
+ reportDataLoss(
+ s"Added partition $p starts from $o instead of 0. Some data may have been missed")
+ }
+
+ // Find deleted partitions, and report data loss if required
+ val deletedPartitions = startPartitionOffsets.keySet.diff(endPartitionOffsets.keySet)
+ if (deletedPartitions.nonEmpty) {
+ reportDataLoss(s"$deletedPartitions are gone. Some data may have been missed")
+ }
+
+ // Use the end partitions to calculate offset ranges to ignore partitions that have
+ // been deleted
+ val topicPartitions = endPartitionOffsets.keySet.filter { tp =>
+      // Ignore partitions whose start offsets we don't know.
+ newPartitionInitialOffsets.contains(tp) || startPartitionOffsets.contains(tp)
+ }.toSeq
+ logDebug("TopicPartitions: " + topicPartitions.mkString(", "))
+
+ // Calculate offset ranges
+ val offsetRanges = rangeCalculator.getRanges(
+ fromOffsets = startPartitionOffsets ++ newPartitionInitialOffsets,
+ untilOffsets = endPartitionOffsets,
+ executorLocations = getSortedExecutorList())
+
+ // Reuse Kafka consumers only when all the offset ranges have distinct TopicPartitions,
+ // that is, concurrent tasks will not read the same TopicPartitions.
+ val reuseKafkaConsumer = offsetRanges.map(_.topicPartition).toSet.size == offsetRanges.size
+
+ // Generate factories based on the offset ranges
+ offsetRanges.map { range =>
+ KafkaMicroBatchInputPartition(
+ range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer)
+ }.toArray
+ }
+
+ private def getSortedExecutorList(): Array[String] = {
+
+ def compare(a: ExecutorCacheTaskLocation, b: ExecutorCacheTaskLocation): Boolean = {
+ if (a.host == b.host) {
+ a.executorId > b.executorId
+ } else {
+ a.host > b.host
+ }
+ }
+
+ val bm = SparkEnv.get.blockManager
+ bm.master.getPeers(bm.blockManagerId).toArray
+ .map(x => ExecutorCacheTaskLocation(x.host, x.executorId))
+ .sortWith(compare)
+ .map(_.toString)
+ }
+}
+
/** A [[InputPartition]] for reading Kafka data in a micro-batch streaming query. */
private[kafka010] case class KafkaMicroBatchInputPartition(
offsetRange: KafkaOffsetRange,
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
index 28c9853bfea9..86f3f38837e7 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.kafka010
import java.{util => ju}
-import java.util.{Locale, Optional, UUID}
+import java.util.{Locale, UUID}
import scala.collection.JavaConverters._
@@ -31,6 +31,8 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSessio
import org.apache.spark.sql.execution.streaming.{Sink, Source}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.sources.v2._
+import org.apache.spark.sql.sources.v2.reader.ScanConfig
+import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputStream, MicroBatchInputStream}
import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
@@ -46,8 +48,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
with RelationProvider
with CreatableRelationProvider
with StreamingWriteSupportProvider
- with ContinuousReadSupportProvider
- with MicroBatchReadSupportProvider
+ with Format
with Logging {
import KafkaSourceProvider._
@@ -106,85 +107,96 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
failOnDataLoss(caseInsensitiveParams))
}
- /**
- * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReadSupport]] to read
- * batches of Kafka data in a micro-batch streaming query.
- */
- override def createMicroBatchReadSupport(
- metadataPath: String,
- options: DataSourceOptions): KafkaMicroBatchReadSupport = {
-
- val parameters = options.asMap().asScala.toMap
- validateStreamOptions(parameters)
- // Each running query should use its own group id. Otherwise, the query may be only assigned
- // partial data since Kafka will assign partitions to multiple consumers having the same group
- // id. Hence, we should generate a unique id for each query.
- val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}"
-
- val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
- val specifiedKafkaParams =
- parameters
- .keySet
- .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka."))
- .map { k => k.drop(6).toString -> parameters(k) }
- .toMap
-
- val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams,
- STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
-
- val kafkaOffsetReader = new KafkaOffsetReader(
- strategy(caseInsensitiveParams),
- kafkaParamsForDriver(specifiedKafkaParams),
- parameters,
- driverGroupIdPrefix = s"$uniqueGroupId-driver")
-
- new KafkaMicroBatchReadSupport(
- kafkaOffsetReader,
- kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId),
- options,
- metadataPath,
- startingStreamOffsets,
- failOnDataLoss(caseInsensitiveParams))
+ override def getTable(options: DataSourceOptions): KafkaTable.type = {
+ KafkaTable
}
- /**
- * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReadSupport]] to read
- * Kafka data in a continuous streaming query.
- */
- override def createContinuousReadSupport(
- metadataPath: String,
- options: DataSourceOptions): KafkaContinuousReadSupport = {
- val parameters = options.asMap().asScala.toMap
- validateStreamOptions(parameters)
- // Each running query should use its own group id. Otherwise, the query may be only assigned
- // partial data since Kafka will assign partitions to multiple consumers having the same group
- // id. Hence, we should generate a unique id for each query.
- val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}"
-
- val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
- val specifiedKafkaParams =
- parameters
- .keySet
- .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka."))
- .map { k => k.drop(6).toString -> parameters(k) }
- .toMap
+ object KafkaTable extends Table
+ with SupportsMicroBatchRead with SupportsContinuousRead {
- val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams,
- STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
+ override def schema(): StructType = KafkaOffsetReader.kafkaSchema
- val kafkaOffsetReader = new KafkaOffsetReader(
- strategy(caseInsensitiveParams),
- kafkaParamsForDriver(specifiedKafkaParams),
- parameters,
- driverGroupIdPrefix = s"$uniqueGroupId-driver")
+ /**
+ * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchInputStream]] to read
+ * batches of Kafka data in a micro-batch streaming query.
+ */
+ override def createMicroBatchInputStream(
+ checkpointLocation: String,
+ config: ScanConfig,
+ options: DataSourceOptions): MicroBatchInputStream = {
+ val parameters = options.asMap().asScala.toMap
+ validateStreamOptions(parameters)
+ // Each running query should use its own group id. Otherwise, the query may be only assigned
+ // partial data since Kafka will assign partitions to multiple consumers having the same group
+ // id. Hence, we should generate a unique id for each query.
+ val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${checkpointLocation.hashCode}"
+
+ val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
+ val specifiedKafkaParams =
+ parameters
+ .keySet
+ .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka."))
+ .map { k => k.drop(6).toString -> parameters(k) }
+ .toMap
+
+ val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
+ caseInsensitiveParams, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
+
+ val kafkaOffsetReader = new KafkaOffsetReader(
+ strategy(caseInsensitiveParams),
+ kafkaParamsForDriver(specifiedKafkaParams),
+ parameters,
+ driverGroupIdPrefix = s"$uniqueGroupId-driver")
+
+ new KafkaMicroBatchInputStream(
+ kafkaOffsetReader,
+ kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId),
+ options,
+ checkpointLocation,
+ startingStreamOffsets,
+ failOnDataLoss(caseInsensitiveParams))
+ }
- new KafkaContinuousReadSupport(
- kafkaOffsetReader,
- kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId),
- parameters,
- metadataPath,
- startingStreamOffsets,
- failOnDataLoss(caseInsensitiveParams))
+ /**
+ * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.ContinuousInputStream]] to read
+ * Kafka data in a continuous streaming query.
+ */
+ override def createContinuousInputStream(
+ checkpointLocation: String,
+ config: ScanConfig,
+ options: DataSourceOptions): ContinuousInputStream = {
+ val parameters = options.asMap().asScala.toMap
+ validateStreamOptions(parameters)
+ // Each running query should use its own group id. Otherwise, the query may be only assigned
+ // partial data since Kafka will assign partitions to multiple consumers having the same group
+ // id. Hence, we should generate a unique id for each query.
+ val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${checkpointLocation.hashCode}"
+
+ val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
+ val specifiedKafkaParams =
+ parameters
+ .keySet
+ .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka."))
+ .map { k => k.drop(6).toString -> parameters(k) }
+ .toMap
+
+ val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
+ caseInsensitiveParams, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
+
+ val kafkaOffsetReader = new KafkaOffsetReader(
+ strategy(caseInsensitiveParams),
+ kafkaParamsForDriver(specifiedKafkaParams),
+ parameters,
+ driverGroupIdPrefix = s"$uniqueGroupId-driver")
+
+ new KafkaContinuousInputStream(
+ kafkaOffsetReader,
+ kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId),
+ parameters,
+ checkpointLocation,
+ startingStreamOffsets,
+ failOnDataLoss(caseInsensitiveParams))
+ }
}
/**
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala
index af510219a6f6..f2b796b78a34 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.kafka010
import org.apache.kafka.clients.producer.ProducerRecord
import org.apache.spark.sql.Dataset
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec
+import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
import org.apache.spark.sql.streaming.Trigger
@@ -207,13 +207,13 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest {
testUtils.createTopic(topic2, partitions = 5)
eventually(timeout(streamingTimeout)) {
assert(
- query.lastExecution.executedPlan.collectFirst {
- case scan: DataSourceV2ScanExec
- if scan.readSupport.isInstanceOf[KafkaContinuousReadSupport] =>
- scan.scanConfig.asInstanceOf[KafkaContinuousScanConfig]
- }.exists { config =>
+ query.lastExecution.logical.collectFirst {
+ case r: StreamingDataSourceV2Relation
+ if r.stream.isInstanceOf[KafkaContinuousInputStream] =>
+ r.stream.asInstanceOf[KafkaContinuousInputStream]
+ }.exists { stream =>
// Ensure the new topic is present and the old topic is gone.
- config.knownPartitions.exists(_.topic == topic2)
+ stream.knownPartitions.exists(_.topic == topic2)
},
s"query never reconfigured to new topic $topic2")
}
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala
index fa6bdc20bd4f..e7ada6b52c37 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala
@@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicInteger
import org.apache.spark.SparkContext
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd, SparkListenerTaskStart}
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec
+import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
import org.apache.spark.sql.execution.streaming.StreamExecution
import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
import org.apache.spark.sql.streaming.Trigger
@@ -46,10 +46,10 @@ trait KafkaContinuousTest extends KafkaSourceTest {
testUtils.addPartitions(topic, newCount)
eventually(timeout(streamingTimeout)) {
assert(
- query.lastExecution.executedPlan.collectFirst {
- case scan: DataSourceV2ScanExec
- if scan.readSupport.isInstanceOf[KafkaContinuousReadSupport] =>
- scan.scanConfig.asInstanceOf[KafkaContinuousScanConfig]
+ query.lastExecution.logical.collectFirst {
+ case r: StreamingDataSourceV2Relation
+ if r.stream.isInstanceOf[KafkaContinuousInputStream] =>
+ r.stream.asInstanceOf[KafkaContinuousInputStream]
}.exists(_.knownPartitions.size == newCount),
s"query never reconfigured to $newCount partitions")
}
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
index 5ee76990b54f..b8712f0ecb78 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
@@ -117,13 +117,15 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext with Kaf
val sources: Seq[BaseStreamingSource] = {
query.get.logicalPlan.collect {
case StreamingExecutionRelation(source: KafkaSource, _) => source
- case StreamingExecutionRelation(source: KafkaMicroBatchReadSupport, _) => source
+ case r: StreamingDataSourceV2Relation
+ if r.stream.isInstanceOf[KafkaMicroBatchInputStream] =>
+ r.stream.asInstanceOf[KafkaMicroBatchInputStream]
} ++ (query.get.lastExecution match {
case null => Seq()
case e => e.logical.collect {
case r: StreamingDataSourceV2Relation
- if r.readSupport.isInstanceOf[KafkaContinuousReadSupport] =>
- r.readSupport.asInstanceOf[KafkaContinuousReadSupport]
+ if r.stream.isInstanceOf[KafkaContinuousInputStream] =>
+ r.stream.asInstanceOf[KafkaContinuousInputStream]
}
})
}.distinct
@@ -978,7 +980,8 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase {
makeSureGetOffsetCalled,
AssertOnQuery { query =>
query.logicalPlan.collect {
- case StreamingExecutionRelation(_: KafkaMicroBatchReadSupport, _) => true
+ case r: StreamingDataSourceV2Relation
+ if r.stream.isInstanceOf[KafkaMicroBatchInputStream] => true
}.nonEmpty
}
)
@@ -1003,12 +1006,14 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase {
"kafka.bootstrap.servers" -> testUtils.brokerAddress,
"subscribe" -> topic
) ++ Option(minPartitions).map { p => "minPartitions" -> p}
- val readSupport = provider.createMicroBatchReadSupport(
- dir.getAbsolutePath, new DataSourceOptions(options.asJava))
- val config = readSupport.newScanConfigBuilder(
+ val dsOptions = new DataSourceOptions(options.asJava)
+ val table = provider.getTable(dsOptions)
+ val config = table.newScanConfigBuilder(dsOptions).build()
+ val stream = table.createMicroBatchInputStream(dir.getAbsolutePath, config, dsOptions)
+ val scan = stream.createMicroBatchScan(
KafkaSourceOffset(Map(tp -> 0L)),
- KafkaSourceOffset(Map(tp -> 100L))).build()
- val inputPartitions = readSupport.planInputPartitions(config)
+ KafkaSourceOffset(Map(tp -> 100L)))
+ val inputPartitions = scan.planInputPartitions()
.map(_.asInstanceOf[KafkaMicroBatchInputPartition])
withClue(s"minPartitions = $minPartitions generated factories $inputPartitions\n\t") {
assert(inputPartitions.size == numPartitionsGenerated)
@@ -1326,7 +1331,7 @@ abstract class KafkaSourceSuiteBase extends KafkaSourceTest {
val reader = spark
.readStream
.format("kafka")
- .option("startingOffsets", s"latest")
+ .option("startingOffsets", "latest")
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
.option("kafka.metadata.max.age.ms", "1")
.option("failOnDataLoss", failOnDataLoss.toString)
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java
deleted file mode 100644
index f403dc619e86..000000000000
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java
+++ /dev/null
@@ -1,61 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.sources.v2;
-
-import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils;
-import org.apache.spark.sql.sources.v2.reader.BatchReadSupport;
-import org.apache.spark.sql.types.StructType;
-
-/**
- * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to
- * provide data reading ability for batch processing.
- *
- * This interface is used to create {@link BatchReadSupport} instances when end users run
- * {@code SparkSession.read.format(...).option(...).load()}.
- */
-@InterfaceStability.Evolving
-public interface BatchReadSupportProvider extends DataSourceV2 {
-
- /**
- * Creates a {@link BatchReadSupport} instance to load the data from this data source with a user
- * specified schema, which is called by Spark at the beginning of each batch query.
- *
- * Spark will call this method at the beginning of each batch query to create a
- * {@link BatchReadSupport} instance.
- *
- * By default this method throws {@link UnsupportedOperationException}, implementations should
- * override this method to handle user specified schema.
- *
- * @param schema the user specified schema.
- * @param options the options for the returned data source reader, which is an immutable
- * case-insensitive string-to-string map.
- */
- default BatchReadSupport createBatchReadSupport(StructType schema, DataSourceOptions options) {
- return DataSourceV2Utils.failForUserSpecifiedSchema(this);
- }
-
- /**
- * Creates a {@link BatchReadSupport} instance to scan the data from this data source, which is
- * called by Spark at the beginning of each batch query.
- *
- * @param options the options for the returned data source reader, which is an immutable
- * case-insensitive string-to-string map.
- */
- BatchReadSupport createBatchReadSupport(DataSourceOptions options);
-}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java
deleted file mode 100644
index 824c290518ac..000000000000
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java
+++ /dev/null
@@ -1,70 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.sources.v2;
-
-import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils;
-import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReadSupport;
-import org.apache.spark.sql.types.StructType;
-
-/**
- * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to
- * provide data reading ability for continuous stream processing.
- *
- * This interface is used to create {@link ContinuousReadSupport} instances when end users run
- * {@code SparkSession.readStream.format(...).option(...).load()} with a continuous trigger.
- */
-@InterfaceStability.Evolving
-public interface ContinuousReadSupportProvider extends DataSourceV2 {
-
- /**
- * Creates a {@link ContinuousReadSupport} instance to scan the data from this streaming data
- * source with a user specified schema, which is called by Spark at the beginning of each
- * continuous streaming query.
- *
- * By default this method throws {@link UnsupportedOperationException}, implementations should
- * override this method to handle user specified schema.
- *
- * @param schema the user provided schema.
- * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure
- * recovery. Readers for the same logical source in the same query
- * will be given the same checkpointLocation.
- * @param options the options for the returned data source reader, which is an immutable
- * case-insensitive string-to-string map.
- */
- default ContinuousReadSupport createContinuousReadSupport(
- StructType schema,
- String checkpointLocation,
- DataSourceOptions options) {
- return DataSourceV2Utils.failForUserSpecifiedSchema(this);
- }
-
- /**
- * Creates a {@link ContinuousReadSupport} instance to scan the data from this streaming data
- * source, which is called by Spark at the beginning of each continuous streaming query.
- *
- * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure
- * recovery. Readers for the same logical source in the same query
- * will be given the same checkpointLocation.
- * @param options the options for the returned data source reader, which is an immutable
- * case-insensitive string-to-string map.
- */
- ContinuousReadSupport createContinuousReadSupport(
- String checkpointLocation,
- DataSourceOptions options);
-}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java
index 6e31e84bf6c7..257586a4a135 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java
@@ -23,7 +23,7 @@
* The base interface for data source v2. Implementations must have a public, 0-arg constructor.
*
* Note that this is an empty interface. Data source implementations must mix in interfaces such as
- * {@link BatchReadSupportProvider} or {@link BatchWriteSupportProvider}, which can provide
+ * {@link SupportsBatchRead} or {@link BatchWriteSupportProvider}, which can provide
* batch or streaming read/write support instances. Otherwise it's just a dummy data source which
* is un-readable/writable.
*
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Format.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Format.java
new file mode 100644
index 000000000000..6b54007ba8c2
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Format.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2;
+
+import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.sources.DataSourceRegister;
+import org.apache.spark.sql.types.StructType;
+
+/**
+ * The base interface for data source v2. Implementations must have a public, 0-arg constructor.
+ *
+ * The major responsibility of this interface is to return a {@link Table} for read/write.
+ */
+@InterfaceStability.Evolving
+public interface Format extends DataSourceV2 {
+
+ /**
+   * Returns a {@link Table} instance for read/write with the user-specified options.
+ *
+ * @param options the user-specified options that can identify a table, e.g. path, table name,
+ * Kafka topic name, etc. It's an immutable case-insensitive string-to-string map.
+ */
+ Table getTable(DataSourceOptions options);
+
+ /**
+   * Returns a {@link Table} instance for read/write with a user-specified schema and options.
+   *
+   * By default this method throws {@link UnsupportedOperationException}; implementations should
+   * override it to handle a user-specified schema.
+ *
+ * @param options the user-specified options that can identify a table, e.g. path, table name,
+ * Kafka topic name, etc. It's an immutable case-insensitive string-to-string map.
+ * @param schema the user-specified schema.
+ */
+ default Table getTable(DataSourceOptions options, StructType schema) {
+ String name;
+ if (this instanceof DataSourceRegister) {
+ name = ((DataSourceRegister) this).shortName();
+ } else {
+ name = this.getClass().getName();
+ }
+ throw new UnsupportedOperationException(
+ name + " source does not support user-specified schema");
+ }
+}
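To make the new entry point concrete, here is a minimal Scala sketch (not part of this patch) of a source implementing the proposed Format interface. The class name, short name, and fixed schema are hypothetical; only Format, Table, DataSourceOptions, and the getTable signature above come from this patch.

```scala
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.sources.v2.{DataSourceOptions, Format, Table}
import org.apache.spark.sql.types.StructType

// Hypothetical source: "example" and the single-column schema are made up.
class ExampleSource extends Format with DataSourceRegister {
  override def shortName(): String = "example"

  // Resolve the logical table identified by the options (e.g. a path or topic name).
  override def getTable(options: DataSourceOptions): Table = new Table {
    // A real source would derive this from the underlying storage.
    override def schema(): StructType = new StructType().add("value", "string")
  }
  // The schema-accepting getTable overload keeps its default behavior, so this
  // source rejects user-specified schemas.
}
```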
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java
deleted file mode 100644
index 61c08e7fa89d..000000000000
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java
+++ /dev/null
@@ -1,70 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.sources.v2;
-
-import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils;
-import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReadSupport;
-import org.apache.spark.sql.types.StructType;
-
-/**
- * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to
- * provide data reading ability for micro-batch stream processing.
- *
- * This interface is used to create {@link MicroBatchReadSupport} instances when end users run
- * {@code SparkSession.readStream.format(...).option(...).load()} with a micro-batch trigger.
- */
-@InterfaceStability.Evolving
-public interface MicroBatchReadSupportProvider extends DataSourceV2 {
-
- /**
- * Creates a {@link MicroBatchReadSupport} instance to scan the data from this streaming data
- * source with a user specified schema, which is called by Spark at the beginning of each
- * micro-batch streaming query.
- *
- * By default this method throws {@link UnsupportedOperationException}, implementations should
- * override this method to handle user specified schema.
- *
- * @param schema the user provided schema.
- * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure
- * recovery. Readers for the same logical source in the same query
- * will be given the same checkpointLocation.
- * @param options the options for the returned data source reader, which is an immutable
- * case-insensitive string-to-string map.
- */
- default MicroBatchReadSupport createMicroBatchReadSupport(
- StructType schema,
- String checkpointLocation,
- DataSourceOptions options) {
- return DataSourceV2Utils.failForUserSpecifiedSchema(this);
- }
-
- /**
- * Creates a {@link MicroBatchReadSupport} instance to scan the data from this streaming data
- * source, which is called by Spark at the beginning of each micro-batch streaming query.
- *
- * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure
- * recovery. Readers for the same logical source in the same query
- * will be given the same checkpointLocation.
- * @param options the options for the returned data source reader, which is an immutable
- * case-insensitive string-to-string map.
- */
- MicroBatchReadSupport createMicroBatchReadSupport(
- String checkpointLocation,
- DataSourceOptions options);
-}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java
new file mode 100644
index 000000000000..be2ab028fe77
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2;
+
+import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.sources.v2.reader.BatchScan;
+import org.apache.spark.sql.sources.v2.reader.ScanConfig;
+import org.apache.spark.sql.sources.v2.reader.ScanConfigBuilder;
+
+/**
+ * A mix-in interface for {@link Table}. Table implementations can mix in this interface to
+ * provide data reading ability for batch processing.
+ */
+@InterfaceStability.Evolving
+public interface SupportsBatchRead extends Table {
+
+ /**
+ * Creates a {@link BatchScan} instance with a {@link ScanConfig} and user-specified options.
+ *
+   * @param config a {@link ScanConfig} which may contain operator pushdown information.
+   * @param options the user-specified options, which are the same as those used to create the
+   *                {@link ScanConfigBuilder} that built the given {@link ScanConfig}.
+ */
+ BatchScan createBatchScan(ScanConfig config, DataSourceOptions options);
+}
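For illustration, a hypothetical table that mixes in SupportsBatchRead could look like the sketch below, assuming Scan declares only planInputPartitions() (as the Kafka scans above suggest); the storage-specific pieces are stubbed with `???`.

```scala
import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsBatchRead, Table}
import org.apache.spark.sql.sources.v2.reader.{BatchScan, InputPartition, PartitionReaderFactory, ScanConfig}
import org.apache.spark.sql.types.StructType

// Hypothetical table; only the interface wiring is shown.
class ExampleBatchTable(tableSchema: StructType) extends Table with SupportsBatchRead {
  override def schema(): StructType = tableSchema

  override def createBatchScan(config: ScanConfig, options: DataSourceOptions): BatchScan =
    new BatchScan {
      // One InputPartition per split of the underlying data (elided).
      override def planInputPartitions(): Array[InputPartition] = ???
      // Builds the per-partition readers used on executors (elided).
      override def createReaderFactory(): PartitionReaderFactory = ???
    }
}
```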
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsContinuousRead.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsContinuousRead.java
new file mode 100644
index 000000000000..6773a5b40d8c
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsContinuousRead.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2;
+
+import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.sources.v2.reader.ScanConfig;
+import org.apache.spark.sql.sources.v2.reader.ScanConfigBuilder;
+import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousInputStream;
+
+/**
+ * A mix-in interface for {@link Table}. Table implementations can mix in this interface to
+ * provide data reading ability for continuous stream processing.
+ */
+@InterfaceStability.Evolving
+public interface SupportsContinuousRead extends Table {
+
+ /**
+ * Creates a {@link ContinuousInputStream} instance with a checkpoint location, a
+ * {@link ScanConfig} and user-specified options.
+ *
+ * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure
+ * recovery. Input streams for the same logical source in the same query
+ * will be given the same checkpointLocation.
+   * @param config a {@link ScanConfig} which may contain operator pushdown information.
+   * @param options the user-specified options, which are the same as those used to create the
+   *                {@link ScanConfigBuilder} that built the given {@link ScanConfig}.
+ */
+ ContinuousInputStream createContinuousInputStream(
+ String checkpointLocation,
+ ScanConfig config,
+ DataSourceOptions options);
+}
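The driver-side flow for continuous processing, mirroring the Kafka changes earlier in this patch, might look like the sketch below. This is illustrative only and is not the actual ContinuousExecution code; it assumes the ContinuousInputStream/ContinuousScan methods seen in KafkaContinuousInputStream (initialOffset, createContinuousScan, needsReconfiguration).

```scala
import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsContinuousRead, Table}
import org.apache.spark.sql.sources.v2.reader.InputPartition
import org.apache.spark.sql.sources.v2.reader.streaming.Offset

object ContinuousReadSketch {
  // Illustrative only: plan the long-running partitions for one "epoch" of a continuous query.
  def planContinuousEpoch(
      table: Table with SupportsContinuousRead,
      checkpointLocation: String,
      options: DataSourceOptions): Array[InputPartition] = {
    val config = table.newScanConfigBuilder(options).build()
    val stream = table.createContinuousInputStream(checkpointLocation, config, options)

    // On a fresh query this is the stream's initial offset; on restart the engine would
    // instead call stream.deserializeOffset on the value recovered from the offset log.
    val start: Offset = stream.initialOffset()

    // The stream plans partitions for the current source topology; the engine later polls
    // stream.needsReconfiguration() and re-plans when it returns true (e.g. Kafka partitions
    // were added or removed).
    stream.createContinuousScan(start).planInputPartitions()
  }
}
```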
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsMicroBatchRead.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsMicroBatchRead.java
new file mode 100644
index 000000000000..04818e3a602d
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsMicroBatchRead.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2;
+
+import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.sources.v2.reader.ScanConfig;
+import org.apache.spark.sql.sources.v2.reader.ScanConfigBuilder;
+import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchInputStream;
+
+/**
+ * A mix-in interface for {@link Table}. Table implementations can mix in this interface to
+ * provide data reading ability for micro-batch stream processing.
+ */
+@InterfaceStability.Evolving
+public interface SupportsMicroBatchRead extends Table {
+
+ /**
+ * Creates a {@link MicroBatchInputStream} instance with a checkpoint location, a
+ * {@link ScanConfig} and user-specified options.
+ *
+ * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure
+ * recovery. Input streams for the same logical source in the same query
+ * will be given the same checkpointLocation.
+   * @param config a {@link ScanConfig} which may contain operator pushdown information.
+   * @param options the user-specified options, which are the same as those used to create the
+   *                {@link ScanConfigBuilder} that built the given {@link ScanConfig}.
+ */
+ MicroBatchInputStream createMicroBatchInputStream(
+ String checkpointLocation,
+ ScanConfig config,
+ DataSourceOptions options);
+}
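Analogously, planning a single micro-batch under the new interfaces follows the flow exercised by the updated Kafka test above; the sketch below is illustrative and assumes the createMicroBatchScan(start, end) method seen on KafkaMicroBatchInputStream.

```scala
import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsMicroBatchRead, Table}
import org.apache.spark.sql.sources.v2.reader.InputPartition
import org.apache.spark.sql.sources.v2.reader.streaming.Offset

object MicroBatchReadSketch {
  // Illustrative only: plan the partitions for one batch delimited by the start and end offsets.
  def planOneBatch(
      table: Table with SupportsMicroBatchRead,
      checkpointLocation: String,
      options: DataSourceOptions,
      start: Offset,
      end: Offset): Array[InputPartition] = {
    val config = table.newScanConfigBuilder(options).build()
    val stream = table.createMicroBatchInputStream(checkpointLocation, config, options)
    // The engine would normally derive `end` from the stream's latest available offset and
    // any rate limits; here both offsets are passed in, as in the updated Kafka test.
    stream.createMicroBatchScan(start, end).planInputPartitions()
  }
}
```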
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java
similarity index 52%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java
rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java
index 452ee86675b4..3315306c8aa6 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java
@@ -15,22 +15,34 @@
* limitations under the License.
*/
-package org.apache.spark.sql.sources.v2.reader;
+package org.apache.spark.sql.sources.v2;
import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.execution.datasources.v2.NoopScanConfigBuilder;
+import org.apache.spark.sql.sources.v2.reader.ScanConfig;
+import org.apache.spark.sql.sources.v2.reader.ScanConfigBuilder;
+import org.apache.spark.sql.types.StructType;
/**
- * An interface that defines how to load the data from data source for batch processing.
+ * An interface representing a logical structured data set of a data source. For example, the
+ * implementation can be a directory on the file system, or a table in the catalog, etc.
*
- * The execution engine will get an instance of this interface from a data source provider
- * (e.g. {@link org.apache.spark.sql.sources.v2.BatchReadSupportProvider}) at the start of a batch
- * query, then call {@link #newScanConfigBuilder()} and create an instance of {@link ScanConfig}.
- * The {@link ScanConfigBuilder} can apply operator pushdown and keep the pushdown result in
- * {@link ScanConfig}. The {@link ScanConfig} will be used to create input partitions and reader
- * factory to scan data from the data source with a Spark job.
+ * This interface can mix in the following interfaces to support different operations:
+ *
+ * - {@link SupportsBatchRead}: this table can be read in batch queries.
+ * - {@link SupportsMicroBatchRead}: this table can be read in streaming queries with a
+ *   micro-batch trigger.
+ * - {@link SupportsContinuousRead}: this table can be read in streaming queries with a
+ *   continuous trigger.
+ *
*/
@InterfaceStability.Evolving
-public interface BatchReadSupport extends ReadSupport {
+public interface Table {
+
+ /**
+ * Returns the schema of this table.
+ */
+ StructType schema();
/**
* Returns a builder of {@link ScanConfig}. Spark will call this method and create a
@@ -38,14 +50,8 @@ public interface BatchReadSupport extends ReadSupport {
*
* The builder can take some query specific information to do operators pushdown, and keep these
* information in the created {@link ScanConfig}.
- *
- * This is the first step of the data scan. All other methods in {@link BatchReadSupport} needs
- * to take {@link ScanConfig} as an input.
- */
- ScanConfigBuilder newScanConfigBuilder();
-
- /**
- * Returns a factory, which produces one {@link PartitionReader} for one {@link InputPartition}.
*/
- PartitionReaderFactory createReaderFactory(ScanConfig config);
+ default ScanConfigBuilder newScanConfigBuilder(DataSourceOptions options) {
+ return new NoopScanConfigBuilder(schema());
+ }
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchScan.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchScan.java
new file mode 100644
index 000000000000..c97357dced11
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchScan.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.reader;
+
+import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.sources.v2.DataSourceOptions;
+import org.apache.spark.sql.sources.v2.SupportsBatchRead;
+import org.apache.spark.sql.sources.v2.Table;
+
+/**
+ * A {@link Scan} for batch queries.
+ *
+ * The execution engine will get an instance of {@link Table} first, then call
+ * {@link Table#newScanConfigBuilder(DataSourceOptions)} and create an instance of
+ * {@link ScanConfig}. The {@link ScanConfigBuilder} can apply operator pushdown and keep the
+ * pushdown result in {@link ScanConfig}. Then
+ * {@link SupportsBatchRead#createBatchScan(ScanConfig, DataSourceOptions)} will be called to create
+ * a {@link BatchScan} instance, which will be used to create input partitions and a reader
+ * factory to scan data from the data source with a Spark job.
+ */
+@InterfaceStability.Evolving
+public interface BatchScan extends Scan {
+
+ /**
+ * Returns a factory, which produces one {@link PartitionReader} for one {@link InputPartition}.
+ */
+ PartitionReaderFactory createReaderFactory();
+}
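Continuing the hedged sketch above, a hypothetical BatchScan over two in-memory partitions; the PartitionReader and PartitionReaderFactory member names are assumed from the existing reader API, and a real source would plan partitions from its storage layout rather than hard-coding them.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.reader.{BatchScan, InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.types.UTF8String

// Serializable payload shipped to executors; one partition per pre-baked chunk of values.
case class ExampleInputPartition(values: Seq[String]) extends InputPartition

// `readSchema` is carried only for illustration; this sketch always emits one string column.
class ExampleBatchScan(readSchema: StructType) extends BatchScan {

  // Two fixed partitions. A real source would also consult the pushdown result kept in the
  // ScanConfig it was created with.
  override def planInputPartitions(): Array[InputPartition] =
    Array(ExampleInputPartition(Seq("a", "b")), ExampleInputPartition(Seq("c")))

  override def createReaderFactory(): PartitionReaderFactory = new PartitionReaderFactory {
    override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
      val values = partition.asInstanceOf[ExampleInputPartition].values.iterator
      new PartitionReader[InternalRow] {
        private var current: String = _
        override def next(): Boolean = {
          if (values.hasNext) { current = values.next(); true } else false
        }
        override def get(): InternalRow = InternalRow(UTF8String.fromString(current))
        override def close(): Unit = {}
      }
    }
  }
}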
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java
index 95c30de907e4..cc9ce4694c3f 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java
@@ -23,7 +23,7 @@
/**
* A serializable representation of an input partition returned by
- * {@link ReadSupport#planInputPartitions(ScanConfig)}.
+ * {@link Scan#planInputPartitions()}.
*
* Note that {@link InputPartition} will be serialized and sent to executors, then
* {@link PartitionReader} will be created by
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java
similarity index 68%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java
rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java
index a58ddb288f1e..cf9ee11d93bd 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java
@@ -18,24 +18,19 @@
package org.apache.spark.sql.sources.v2.reader;
import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.types.StructType;
/**
- * The base interface for all the batch and streaming read supports. Data sources should implement
- * concrete read support interfaces like {@link BatchReadSupport}.
+ * The base interface for all batch and streaming scans. Data sources should implement
+ * concrete scan interfaces like {@link BatchScan}.
+ *
+ * A scan is used to create input partitions and a reader factory to scan data from the data
+ * source with a Spark job.
*
* If Spark fails to execute any methods in the implementations of this interface (by throwing an
* exception), the read action will fail and no Spark job will be submitted.
*/
@InterfaceStability.Evolving
-public interface ReadSupport {
-
- /**
- * Returns the full schema of this data source, which is usually the physical schema of the
- * underlying storage. This full schema should not be affected by column pruning or other
- * optimizations.
- */
- StructType fullSchema();
+public interface Scan {
/**
* Returns a list of {@link InputPartition input partitions}. Each {@link InputPartition}
@@ -43,8 +38,8 @@ public interface ReadSupport {
* partitions returned here is the same as the number of RDD partitions this scan outputs.
*
* Note that, this may not be a full scan if the data source supports optimization like filter
- * push-down. Implementations should check the input {@link ScanConfig} and adjust the resulting
- * {@link InputPartition input partitions}.
+ * push-down. Implementations should check the {@link ScanConfig} this scan was created with and
+ * adjust the resulting {@link InputPartition input partitions}.
*/
- InputPartition[] planInputPartitions(ScanConfig config);
+ InputPartition[] planInputPartitions();
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java
index 7462ce282058..495334cb67dd 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java
@@ -22,21 +22,18 @@
/**
* An interface that carries query specific information for the data scanning job, like operator
- * pushdown information and streaming query offsets. This is defined as an empty interface, and data
- * sources should define their own {@link ScanConfig} classes.
+ * pushdown information. This is defined as an empty interface, and data sources should define
+ * their own {@link ScanConfig} classes.
*
- * For APIs that take a {@link ScanConfig} as input, like
- * {@link ReadSupport#planInputPartitions(ScanConfig)},
- * {@link BatchReadSupport#createReaderFactory(ScanConfig)} and
- * {@link SupportsReportStatistics#estimateStatistics(ScanConfig)}, implementations mostly need to
- * cast the input {@link ScanConfig} to the concrete {@link ScanConfig} class of the data source.
+ * {@link Scan} implementations usually need to cast the {@link ScanConfig} they were created
+ * with to the concrete {@link ScanConfig} class of the data source.
*/
@InterfaceStability.Evolving
public interface ScanConfig {
/**
- * Returns the actual schema of this data source reader, which may be different from the physical
- * schema of the underlying storage, as column pruning or other optimizations may happen.
+ * Returns the actual schema of this scan, which may be different from the table schema, as
+ * column pruning or other optimizations may happen.
*
* If this method fails (by throwing an exception), the action will fail and no Spark job will be
* submitted.
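For example, a data source's own ScanConfig is usually a small value class produced by its ScanConfigBuilder; a hypothetical sketch (names invented, not from this patch) that just carries a pruned schema:

import org.apache.spark.sql.sources.v2.reader.{ScanConfig, ScanConfigBuilder}
import org.apache.spark.sql.types.StructType

// Records what the scan should actually produce. A builder that supports pushdown would
// also keep pushed filters and pruned columns here.
case class ExampleScanConfig(prunedSchema: StructType) extends ScanConfig {
  override def readSchema(): StructType = prunedSchema
}

class ExampleScanConfigBuilder(tableSchema: StructType) extends ScanConfigBuilder {
  override def build(): ScanConfig = ExampleScanConfig(tableSchema)
}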
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java
index 44799c7d4913..031c7a73c367 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java
@@ -23,7 +23,7 @@
/**
* An interface to represent statistics for a data source, which is returned by
- * {@link SupportsReportStatistics#estimateStatistics(ScanConfig)}.
+ * {@link SupportsReportStatistics#estimateStatistics()}.
*/
@InterfaceStability.Evolving
public interface Statistics {
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java
index db62cd451536..cdfc8bd22ab3 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java
@@ -21,17 +21,17 @@
import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning;
/**
- * A mix in interface for {@link BatchReadSupport}. Data sources can implement this interface to
+ * A mix in interface for {@link Scan}. Data sources can implement this interface to
* report data partitioning and try to avoid shuffle at Spark side.
*
- * Note that, when a {@link ReadSupport} implementation creates exactly one {@link InputPartition},
+ * Note that, when a {@link Scan} implementation creates exactly one {@link InputPartition},
* Spark may avoid adding a shuffle even if the reader does not implement this interface.
*/
@InterfaceStability.Evolving
-public interface SupportsReportPartitioning extends ReadSupport {
+public interface SupportsReportPartitioning extends Scan {
/**
- * Returns the output data partitioning that this reader guarantees.
+ * Returns the output data partitioning that this scan guarantees.
*/
- Partitioning outputPartitioning(ScanConfig config);
+ Partitioning outputPartitioning();
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java
index 1831488ba096..ab50e3ff4098 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java
@@ -20,7 +20,7 @@
import org.apache.spark.annotation.InterfaceStability;
/**
- * A mix in interface for {@link BatchReadSupport}. Data sources can implement this interface to
+ * A mix in interface for {@link Scan}. Data sources can implement this interface to
* report statistics to Spark.
*
* As of Spark 2.4, statistics are reported to the optimizer before any operator is pushed to the
@@ -28,10 +28,10 @@
* not improve query performance until the planner can push operators before getting stats.
*/
@InterfaceStability.Evolving
-public interface SupportsReportStatistics extends ReadSupport {
+public interface SupportsReportStatistics extends Scan {
/**
- * Returns the estimated statistics of this data source scan.
+ * Returns the estimated statistics of this scan.
*/
- Statistics estimateStatistics(ScanConfig config);
+ Statistics estimateStatistics();
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java
index fb0b6f1df43b..f460f6bfe3bb 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java
@@ -19,13 +19,12 @@
import org.apache.spark.annotation.InterfaceStability;
import org.apache.spark.sql.sources.v2.reader.InputPartition;
-import org.apache.spark.sql.sources.v2.reader.ScanConfig;
import org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning;
/**
* An interface to represent the output data partitioning for a data source, which is returned by
- * {@link SupportsReportPartitioning#outputPartitioning(ScanConfig)}. Note that this should work
- * like a snapshot. Once created, it should be deterministic and always report the same number of
+ * {@link SupportsReportPartitioning#outputPartitioning()}. Note that this should work like a
+ * snapshot. Once created, it should be deterministic and always report the same number of
* partitions and the same "satisfy" result for a certain distribution.
*/
@InterfaceStability.Evolving
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousInputStream.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousInputStream.java
new file mode 100644
index 000000000000..6ff1513a41b6
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousInputStream.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.reader.streaming;
+
+import org.apache.spark.annotation.InterfaceStability;
+
+/**
+ * An {@link InputStream} for a streaming query in continuous mode.
+ */
+@InterfaceStability.Evolving
+public interface ContinuousInputStream extends InputStream {
+
+ /**
+ * Creates a {@link ContinuousScan} instance with a start offset, to scan the data from the
+ * start offset with a long-running Spark job. The job will be terminated if
+ * {@link #needsReconfiguration()} returns true, and the execution engine will then call this
+ * method again, with a different start offset, to launch a new long-running Spark job.
+ */
+ ContinuousScan createContinuousScan(Offset start);
+
+ /**
+ * Merge partitioned offsets coming from {@link ContinuousPartitionReader} instances
+ * for each partition to a single global offset.
+ */
+ Offset mergeOffsets(PartitionOffset[] offsets);
+
+ /**
+ * The execution engine will call this method in every epoch to determine if new input
+ * partitions need to be generated, which may be required if, for example, the underlying
+ * source system has had partitions added or removed.
+ *
+ * If true, the query will be shut down and restarted with a new {@link ContinuousScan}
+ * instance.
+ */
+ default boolean needsReconfiguration() {
+ return false;
+ }
+}
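A hedged sketch of a continuous input stream for a hypothetical single-partition source. Only the members declared in this file plus initialOffset() are visible in this diff; deserializeOffset, commit and stop are assumed to carry over from the old StreamingReadSupport/BaseStreamingSource hierarchy, so treat the overridden set as illustrative. ExampleContinuousScan is sketched after ContinuousScan.java below.

import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputStream, ContinuousScan, Offset, PartitionOffset}

// A simple numeric offset; json() is the only member Offset requires.
case class ExampleOffset(value: Long) extends Offset {
  override def json(): String = value.toString
}

// Per-partition progress reported by each ContinuousPartitionReader.
case class ExamplePartitionOffset(value: Long) extends PartitionOffset

class ExampleContinuousInputStream extends ContinuousInputStream {

  override def initialOffset(): Offset = ExampleOffset(0L)

  // Assumed to be inherited from InputStream (carried over from the old StreamingReadSupport).
  override def deserializeOffset(json: String): Offset = ExampleOffset(json.toLong)
  override def commit(end: Offset): Unit = {}
  override def stop(): Unit = {}

  override def createContinuousScan(start: Offset): ContinuousScan =
    new ExampleContinuousScan(start.asInstanceOf[ExampleOffset].value)

  // With a single partition the global offset is simply the smallest per-partition value.
  override def mergeOffsets(offsets: Array[PartitionOffset]): Offset =
    ExampleOffset(offsets.map(_.asInstanceOf[ExamplePartitionOffset].value).min)

  // The partition set never changes, so the default needsReconfiguration() = false is kept.
}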
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java
deleted file mode 100644
index 9a3ad2eb8a80..000000000000
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java
+++ /dev/null
@@ -1,77 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.sources.v2.reader.streaming;
-
-import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.execution.streaming.BaseStreamingSource;
-import org.apache.spark.sql.sources.v2.reader.InputPartition;
-import org.apache.spark.sql.sources.v2.reader.ScanConfig;
-import org.apache.spark.sql.sources.v2.reader.ScanConfigBuilder;
-
-/**
- * An interface that defines how to load the data from data source for continuous streaming
- * processing.
- *
- * The execution engine will get an instance of this interface from a data source provider
- * (e.g. {@link org.apache.spark.sql.sources.v2.ContinuousReadSupportProvider}) at the start of a
- * streaming query, then call {@link #newScanConfigBuilder(Offset)} and create an instance of
- * {@link ScanConfig} for the duration of the streaming query or until
- * {@link #needsReconfiguration(ScanConfig)} is true. The {@link ScanConfig} will be used to create
- * input partitions and reader factory to scan data with a Spark job for its duration. At the end
- * {@link #stop()} will be called when the streaming execution is completed. Note that a single
- * query may have multiple executions due to restart or failure recovery.
- */
-@InterfaceStability.Evolving
-public interface ContinuousReadSupport extends StreamingReadSupport, BaseStreamingSource {
-
- /**
- * Returns a builder of {@link ScanConfig}. Spark will call this method and create a
- * {@link ScanConfig} for each data scanning job.
- *
- * The builder can take some query specific information to do operators pushdown, store streaming
- * offsets, etc., and keep these information in the created {@link ScanConfig}.
- *
- * This is the first step of the data scan. All other methods in {@link ContinuousReadSupport}
- * needs to take {@link ScanConfig} as an input.
- */
- ScanConfigBuilder newScanConfigBuilder(Offset start);
-
- /**
- * Returns a factory, which produces one {@link ContinuousPartitionReader} for one
- * {@link InputPartition}.
- */
- ContinuousPartitionReaderFactory createContinuousReaderFactory(ScanConfig config);
-
- /**
- * Merge partitioned offsets coming from {@link ContinuousPartitionReader} instances
- * for each partition to a single global offset.
- */
- Offset mergeOffsets(PartitionOffset[] offsets);
-
- /**
- * The execution engine will call this method in every epoch to determine if new input
- * partitions need to be generated, which may be required if for example the underlying
- * source system has had partitions added or removed.
- *
- * If true, the query will be shut down and restarted with a new {@link ContinuousReadSupport}
- * instance.
- */
- default boolean needsReconfiguration(ScanConfig config) {
- return false;
- }
-}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousScan.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousScan.java
new file mode 100644
index 000000000000..9b9090a810ca
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousScan.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.reader.streaming;
+
+import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.sources.v2.DataSourceOptions;
+import org.apache.spark.sql.sources.v2.SupportsContinuousRead;
+import org.apache.spark.sql.sources.v2.Table;
+import org.apache.spark.sql.sources.v2.reader.InputPartition;
+import org.apache.spark.sql.sources.v2.reader.Scan;
+import org.apache.spark.sql.sources.v2.reader.ScanConfig;
+import org.apache.spark.sql.sources.v2.reader.ScanConfigBuilder;
+
+/**
+ * A {@link Scan} for streaming queries in continuous mode.
+ *
+ * The execution engine will get an instance of {@link Table} first, then call
+ * {@link Table#newScanConfigBuilder(DataSourceOptions)} and create an instance of
+ * {@link ScanConfig}. The {@link ScanConfigBuilder} can apply operator pushdown and keep the
+ * pushdown result in {@link ScanConfig}. Then
+ * {@link SupportsContinuousRead#createContinuousInputStream(String, ScanConfig, DataSourceOptions)}
+ * will be called to create a {@link ContinuousInputStream} instance. The
+ * {@link ContinuousInputStream} manages offsets and creates a {@link ContinuousScan} instance for
+ * the duration of the streaming query or until {@link ContinuousInputStream#needsReconfiguration()}
+ * returns true. The {@link ContinuousScan} will be used to create input partitions and a reader
+ * factory to scan data with a Spark job for its duration. {@link InputStream#stop()} will be
+ * called when the streaming execution is completed. Note that a single query may have
+ * multiple executions due to restart or failure recovery.
+ */
+@InterfaceStability.Evolving
+public interface ContinuousScan extends Scan {
+
+ /**
+ * Returns a factory, which produces one {@link ContinuousPartitionReader} for one
+ * {@link InputPartition}.
+ */
+ ContinuousPartitionReaderFactory createContinuousReaderFactory();
+}
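A matching, hypothetical ContinuousScan for the single-partition stream sketched above (it reuses ExamplePartitionOffset from that sketch; the ContinuousPartitionReader and ContinuousPartitionReaderFactory member names are assumed from the existing continuous reader API, and a real reader would poll an external system instead of synthesizing values):

import java.util.concurrent.TimeUnit

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.reader.InputPartition
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousPartitionReader, ContinuousPartitionReaderFactory, ContinuousScan, PartitionOffset}

// Serializable description of the single partition and where it should start.
case class ExampleContinuousPartition(startValue: Long) extends InputPartition

class ExampleContinuousScan(startValue: Long) extends ContinuousScan {

  override def planInputPartitions(): Array[InputPartition] =
    Array(ExampleContinuousPartition(startValue))

  override def createContinuousReaderFactory(): ContinuousPartitionReaderFactory =
    new ContinuousPartitionReaderFactory {
      override def createReader(p: InputPartition): ContinuousPartitionReader[InternalRow] = {
        val start = p.asInstanceOf[ExampleContinuousPartition].startValue
        new ContinuousPartitionReader[InternalRow] {
          private var current = start - 1
          override def next(): Boolean = {
            // A continuous reader blocks until more data is available; this sketch just
            // emits one value per poll interval and never terminates on its own.
            TimeUnit.MILLISECONDS.sleep(10)
            current += 1
            true
          }
          override def get(): InternalRow = InternalRow(current)
          override def getOffset(): PartitionOffset = ExamplePartitionOffset(current)
          override def close(): Unit = {}
        }
      }
    }
}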
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/StreamingReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/InputStream.java
similarity index 76%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/StreamingReadSupport.java
rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/InputStream.java
index 84872d1ebc26..e1b026dcc332 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/StreamingReadSupport.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/InputStream.java
@@ -17,14 +17,18 @@
package org.apache.spark.sql.sources.v2.reader.streaming;
-import org.apache.spark.sql.sources.v2.reader.ReadSupport;
+import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.execution.streaming.BaseStreamingSource;
/**
- * A base interface for streaming read support. This is package private and is invisible to data
- * sources. Data sources should implement concrete streaming read support interfaces:
- * {@link MicroBatchReadSupport} or {@link ContinuousReadSupport}.
+ * An interface representing a readable data stream in a streaming query. It is responsible for
+ * managing the offsets of the streaming source within the query.
+ *
+ * Data sources should implement concrete input stream interfaces: {@link MicroBatchInputStream} and
+ * {@link ContinuousInputStream}.
*/
-interface StreamingReadSupport extends ReadSupport {
+@InterfaceStability.Evolving
+public interface InputStream extends BaseStreamingSource {
/**
* Returns the initial offset for a streaming query to start reading from. Note that the
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchInputStream.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchInputStream.java
new file mode 100644
index 000000000000..2e0e760da7e2
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchInputStream.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.reader.streaming;
+
+import org.apache.spark.annotation.InterfaceStability;
+
+/**
+ * An {@link InputStream} for a streaming query in micro-batch mode.
+ */
+@InterfaceStability.Evolving
+public interface MicroBatchInputStream extends InputStream {
+
+ /**
+ * Creates a {@link MicroBatchScan} instance with a start and end offset, to scan the data within
+ * this offset range with a Spark job.
+ */
+ MicroBatchScan createMicroBatchScan(Offset start, Offset end);
+
+ /**
+ * Returns the most recent offset available.
+ */
+ Offset latestOffset();
+}
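The micro-batch counterpart, sketched under the same assumptions about inherited InputStream members and reusing ExampleOffset from the continuous sketch above; latestOffset() tells the engine how far it may read, and each batch scans the (start, end] range:

import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchInputStream, MicroBatchScan, Offset}

class ExampleMicroBatchInputStream extends MicroBatchInputStream {

  // Pretend 100 new records become available every time the engine asks.
  private var highestAvailable = 0L

  override def initialOffset(): Offset = ExampleOffset(0L)

  // Assumed to be inherited from InputStream, as in the continuous sketch above.
  override def deserializeOffset(json: String): Offset = ExampleOffset(json.toLong)
  override def commit(end: Offset): Unit = {}
  override def stop(): Unit = {}

  override def latestOffset(): Offset = {
    highestAvailable += 100
    ExampleOffset(highestAvailable)
  }

  override def createMicroBatchScan(start: Offset, end: Offset): MicroBatchScan = {
    // Scan construction is elided in this sketch; the scan would plan partitions covering
    // (start.value, end.value] and expose a PartitionReaderFactory, as in the BatchScan sketch.
    throw new UnsupportedOperationException("scan construction elided in this sketch")
  }
}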
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java
deleted file mode 100644
index edb0db11bff2..000000000000
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java
+++ /dev/null
@@ -1,60 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.sources.v2.reader.streaming;
-
-import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.execution.streaming.BaseStreamingSource;
-import org.apache.spark.sql.sources.v2.reader.*;
-
-/**
- * An interface that defines how to scan the data from data source for micro-batch streaming
- * processing.
- *
- * The execution engine will get an instance of this interface from a data source provider
- * (e.g. {@link org.apache.spark.sql.sources.v2.MicroBatchReadSupportProvider}) at the start of a
- * streaming query, then call {@link #newScanConfigBuilder(Offset, Offset)} and create an instance
- * of {@link ScanConfig} for each micro-batch. The {@link ScanConfig} will be used to create input
- * partitions and reader factory to scan a micro-batch with a Spark job. At the end {@link #stop()}
- * will be called when the streaming execution is completed. Note that a single query may have
- * multiple executions due to restart or failure recovery.
- */
-@InterfaceStability.Evolving
-public interface MicroBatchReadSupport extends StreamingReadSupport, BaseStreamingSource {
-
- /**
- * Returns a builder of {@link ScanConfig}. Spark will call this method and create a
- * {@link ScanConfig} for each data scanning job.
- *
- * The builder can take some query specific information to do operators pushdown, store streaming
- * offsets, etc., and keep these information in the created {@link ScanConfig}.
- *
- * This is the first step of the data scan. All other methods in {@link MicroBatchReadSupport}
- * needs to take {@link ScanConfig} as an input.
- */
- ScanConfigBuilder newScanConfigBuilder(Offset start, Offset end);
-
- /**
- * Returns a factory, which produces one {@link PartitionReader} for one {@link InputPartition}.
- */
- PartitionReaderFactory createReaderFactory(ScanConfig config);
-
- /**
- * Returns the most recent offset available.
- */
- Offset latestOffset();
-}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchScan.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchScan.java
new file mode 100644
index 000000000000..45d640af5750
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchScan.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.reader.streaming;
+
+import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.sources.v2.DataSourceOptions;
+import org.apache.spark.sql.sources.v2.SupportsMicroBatchRead;
+import org.apache.spark.sql.sources.v2.Table;
+import org.apache.spark.sql.sources.v2.reader.*;
+
+/**
+ * A {@link Scan} for streaming queries in micro-batch mode.
+ *
+ * The execution engine will get an instance of {@link Table} first, then call
+ * {@link Table#newScanConfigBuilder(DataSourceOptions)} and create an instance of
+ * {@link ScanConfig}. The {@link ScanConfigBuilder} can apply operator pushdown and keep the
+ * pushdown result in {@link ScanConfig}. Then
+ * {@link SupportsMicroBatchRead#createMicroBatchInputStream(String, ScanConfig, DataSourceOptions)}
+ * will be called to create a {@link MicroBatchInputStream} instance. The
+ * {@link MicroBatchInputStream} manages offsets and creates a {@link MicroBatchScan} instance for
+ * each micro-batch. The {@link MicroBatchScan} will be used to create input partitions and a
+ * reader factory to scan a micro-batch with a Spark job. {@link InputStream#stop()} will be
+ * called when the streaming execution is completed. Note that a single query may have
+ * multiple executions due to restart or failure recovery.
+ */
+@InterfaceStability.Evolving
+public interface MicroBatchScan extends Scan {
+
+ /**
+ * Returns a factory, which produces one {@link PartitionReader} for one {@link InputPartition}.
+ */
+ PartitionReaderFactory createReaderFactory();
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java
index 6cf27734867c..d89c96360af2 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java
@@ -20,8 +20,8 @@
import org.apache.spark.annotation.InterfaceStability;
/**
- * An abstract representation of progress through a {@link MicroBatchReadSupport} or
- * {@link ContinuousReadSupport}.
+ * An abstract representation of progress through a {@link MicroBatchInputStream} or
+ * {@link ContinuousInputStream}.
* During execution, offsets provided by the data source implementation will be logged and used as
* restart checkpoints. Each source should provide an offset implementation which the source can use
* to reconstruct a position in the stream up to which data has been seen/processed.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 4f6d8b8a0c34..9c9078dfe4e0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.execution.datasources.jdbc._
import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils
-import org.apache.spark.sql.sources.v2.{BatchReadSupportProvider, DataSourceOptions, DataSourceV2}
+import org.apache.spark.sql.sources.v2.{DataSourceOptions, Format}
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String
@@ -193,21 +193,18 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
}
val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf)
- if (classOf[DataSourceV2].isAssignableFrom(cls)) {
- val ds = cls.newInstance().asInstanceOf[DataSourceV2]
- if (ds.isInstanceOf[BatchReadSupportProvider]) {
- val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
- ds = ds, conf = sparkSession.sessionState.conf)
- val pathsOption = {
- val objectMapper = new ObjectMapper()
- DataSourceOptions.PATHS_KEY -> objectMapper.writeValueAsString(paths.toArray)
- }
- Dataset.ofRows(sparkSession, DataSourceV2Relation.create(
- ds, sessionOptions ++ extraOptions.toMap + pathsOption,
- userSpecifiedSchema = userSpecifiedSchema))
- } else {
- loadV1Source(paths: _*)
+ if (classOf[Format].isAssignableFrom(cls)) {
+ val format = cls.newInstance().asInstanceOf[Format]
+ val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
+ ds = format, conf = sparkSession.sessionState.conf)
+ val pathsOption = {
+ val objectMapper = new ObjectMapper()
+ DataSourceOptions.PATHS_KEY -> objectMapper.writeValueAsString(paths.toArray)
}
+ DataSourceV2Relation.create(
+ format, sessionOptions ++ extraOptions.toMap + pathsOption,
+ userSpecifiedSchema = userSpecifiedSchema
+ ).map(Dataset.ofRows(sparkSession, _)).getOrElse(loadV1Source(paths: _*))
} else {
loadV1Source(paths: _*)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 5a28870f5d3c..cdee3de261ea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -252,7 +252,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
val options = sessionOptions ++ extraOptions
if (mode == SaveMode.Append) {
- val relation = DataSourceV2Relation.create(source, options)
+ val relation = DataSourceV2Relation.create(source, options).getOrElse {
+ throw new AnalysisException(s"data source $source does not support append.")
+ }
runCommand(df.sparkSession, "save") {
AppendData.byName(relation, df.logicalPlan)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
index f7e29593a635..8a593daf4b58 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
@@ -24,11 +24,12 @@ import scala.collection.JavaConverters._
import org.apache.spark.sql.{AnalysisException, SaveMode}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation}
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.sources.DataSourceRegister
-import org.apache.spark.sql.sources.v2.{BatchReadSupportProvider, BatchWriteSupportProvider, DataSourceOptions, DataSourceV2}
-import org.apache.spark.sql.sources.v2.reader.{BatchReadSupport, ReadSupport, ScanConfigBuilder, SupportsReportStatistics}
+import org.apache.spark.sql.sources.v2._
+import org.apache.spark.sql.sources.v2.reader.{Scan, SupportsReportStatistics}
+import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputStream, InputStream, MicroBatchInputStream, Offset}
import org.apache.spark.sql.sources.v2.writer.BatchWriteSupport
import org.apache.spark.sql.types.StructType
@@ -41,7 +42,7 @@ import org.apache.spark.sql.types.StructType
*/
case class DataSourceV2Relation(
source: DataSourceV2,
- readSupport: BatchReadSupport,
+ table: SupportsBatchRead,
output: Seq[AttributeReference],
options: Map[String, String],
tableIdent: Option[TableIdentifier] = None,
@@ -60,12 +61,16 @@ case class DataSourceV2Relation(
def newWriteSupport(): BatchWriteSupport = source.createWriteSupport(options, schema)
- override def computeStats(): Statistics = readSupport match {
- case r: SupportsReportStatistics =>
- val statistics = r.estimateStatistics(readSupport.newScanConfigBuilder().build())
- Statistics(sizeInBytes = statistics.sizeInBytes().orElse(conf.defaultSizeInBytes))
- case _ =>
- Statistics(sizeInBytes = conf.defaultSizeInBytes)
+ override def computeStats(): Statistics = {
+ val dsOptions = new DataSourceOptions(options.asJava)
+ val config = table.newScanConfigBuilder(dsOptions).build()
+ table.createBatchScan(config, dsOptions) match {
+ case r: SupportsReportStatistics =>
+ val statistics = r.estimateStatistics()
+ Statistics(sizeInBytes = statistics.sizeInBytes().orElse(conf.defaultSizeInBytes))
+ case _ =>
+ Statistics(sizeInBytes = conf.defaultSizeInBytes)
+ }
}
override def newInstance(): DataSourceV2Relation = {
@@ -81,11 +86,12 @@ case class DataSourceV2Relation(
* after we figure out how to apply operator push-down for streaming data sources.
*/
case class StreamingDataSourceV2Relation(
- output: Seq[AttributeReference],
+ output: Seq[Attribute],
source: DataSourceV2,
options: Map[String, String],
- readSupport: ReadSupport,
- scanConfigBuilder: ScanConfigBuilder)
+ stream: InputStream,
+ startOffset: Option[Offset] = None,
+ endOffset: Option[Offset] = None)
extends LeafNode with MultiInstanceRelation with DataSourceV2StringFormat {
override def isStreaming: Boolean = true
@@ -99,8 +105,8 @@ case class StreamingDataSourceV2Relation(
// TODO: unify the equal/hashCode implementation for all data source v2 query plans.
override def equals(other: Any): Boolean = other match {
case other: StreamingDataSourceV2Relation =>
- output == other.output && readSupport.getClass == other.readSupport.getClass &&
- options == other.options
+ output == other.output && source.getClass == other.source.getClass &&
+ options == other.options && startOffset == other.startOffset && endOffset == other.endOffset
case _ => false
}
@@ -108,24 +114,30 @@ case class StreamingDataSourceV2Relation(
Seq(output, source, options).hashCode()
}
- override def computeStats(): Statistics = readSupport match {
+ def createScan(): Scan = (startOffset, endOffset) match {
+ case (Some(start), Some(end)) =>
+ stream.asInstanceOf[MicroBatchInputStream].createMicroBatchScan(start, end)
+ case (Some(start), None) =>
+ stream.asInstanceOf[ContinuousInputStream].createContinuousScan(start)
+ case _ =>
+ throw new IllegalStateException("[BUG] wrong offsets in StreamingDataSourceV2Relation.")
+ }
+
+ override def computeStats(): Statistics = createScan() match {
case r: SupportsReportStatistics =>
- val statistics = r.estimateStatistics(scanConfigBuilder.build())
+ val statistics = r.estimateStatistics()
Statistics(sizeInBytes = statistics.sizeInBytes().orElse(conf.defaultSizeInBytes))
- case _ =>
- Statistics(sizeInBytes = conf.defaultSizeInBytes)
+ case _ => Statistics(sizeInBytes = conf.defaultSizeInBytes)
}
}
object DataSourceV2Relation {
private implicit class SourceHelpers(source: DataSourceV2) {
- def asReadSupportProvider: BatchReadSupportProvider = {
- source match {
- case provider: BatchReadSupportProvider =>
- provider
- case _ =>
- throw new AnalysisException(s"Data source is not readable: $name")
- }
+
+ def asFormat: Format = source match {
+ case f: Format => f
+ case _ =>
+ throw new AnalysisException(s"Data source is not readable: $name")
}
def asWriteSupportProvider: BatchWriteSupportProvider = {
@@ -146,15 +158,15 @@ object DataSourceV2Relation {
}
}
- def createReadSupport(
+ def getTable(
options: Map[String, String],
- userSpecifiedSchema: Option[StructType]): BatchReadSupport = {
+ userSpecifiedSchema: Option[StructType]): Table = {
val v2Options = new DataSourceOptions(options.asJava)
userSpecifiedSchema match {
case Some(s) =>
- asReadSupportProvider.createBatchReadSupport(s, v2Options)
+ asFormat.getTable(v2Options, s)
case _ =>
- asReadSupportProvider.createBatchReadSupport(v2Options)
+ asFormat.getTable(v2Options)
}
}
@@ -173,12 +185,17 @@ object DataSourceV2Relation {
source: DataSourceV2,
options: Map[String, String],
tableIdent: Option[TableIdentifier] = None,
- userSpecifiedSchema: Option[StructType] = None): DataSourceV2Relation = {
- val readSupport = source.createReadSupport(options, userSpecifiedSchema)
- val output = readSupport.fullSchema().toAttributes
+ userSpecifiedSchema: Option[StructType] = None): Option[DataSourceV2Relation] = {
+ val table = source.getTable(options, userSpecifiedSchema)
+ val output = table.schema().toAttributes
val ident = tableIdent.orElse(tableFromOptions(options))
- DataSourceV2Relation(
- source, readSupport, output, options, ident, userSpecifiedSchema)
+ table match {
+ case batch: SupportsBatchRead =>
+ Some(DataSourceV2Relation(
+ source, batch, output, options, ident, userSpecifiedSchema))
+ case _ =>
+ None
+ }
}
private def tableFromOptions(options: Map[String, String]): Option[TableIdentifier] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
index 04a97735d024..743fd14174df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
@@ -26,18 +26,20 @@ import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeSta
import org.apache.spark.sql.execution.streaming.continuous._
import org.apache.spark.sql.sources.v2.DataSourceV2
import org.apache.spark.sql.sources.v2.reader._
-import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousPartitionReaderFactory, ContinuousReadSupport, MicroBatchReadSupport}
+import org.apache.spark.sql.sources.v2.reader.streaming._
/**
* Physical plan node for scanning data from a data source.
*/
case class DataSourceV2ScanExec(
- output: Seq[AttributeReference],
+ output: Seq[Attribute],
@transient source: DataSourceV2,
@transient options: Map[String, String],
@transient pushedFilters: Seq[Expression],
- @transient readSupport: ReadSupport,
- @transient scanConfig: ScanConfig)
+ @transient scan: Scan,
+ // `ProgressReporter` needs to know which stream a physical scan node is associated with, so
+ // that it can collect metrics for that stream correctly.
+ @transient stream: Option[InputStream] = None)
extends LeafExecNode with DataSourceV2StringFormat with ColumnarBatchScan {
override def simpleString: String = "ScanV2 " + metadataString
@@ -45,33 +47,31 @@ case class DataSourceV2ScanExec(
// TODO: unify the equal/hashCode implementation for all data source v2 query plans.
override def equals(other: Any): Boolean = other match {
case other: DataSourceV2ScanExec =>
- output == other.output && readSupport.getClass == other.readSupport.getClass &&
- options == other.options
+ output == other.output && source.getClass == other.source.getClass && options == other.options
case _ => false
}
override def hashCode(): Int = {
- Seq(output, source, options).hashCode()
+ Seq(output, source.getClass, options).hashCode()
}
- override def outputPartitioning: physical.Partitioning = readSupport match {
+ override def outputPartitioning: physical.Partitioning = scan match {
case _ if partitions.length == 1 =>
SinglePartition
case s: SupportsReportPartitioning =>
- new DataSourcePartitioning(
- s.outputPartitioning(scanConfig), AttributeMap(output.map(a => a -> a.name)))
+ new DataSourcePartitioning(s.outputPartitioning(), AttributeMap(output.map(a => a -> a.name)))
case _ => super.outputPartitioning
}
- private lazy val partitions: Seq[InputPartition] = readSupport.planInputPartitions(scanConfig)
+ private lazy val partitions: Seq[InputPartition] = scan.planInputPartitions()
- private lazy val readerFactory = readSupport match {
- case r: BatchReadSupport => r.createReaderFactory(scanConfig)
- case r: MicroBatchReadSupport => r.createReaderFactory(scanConfig)
- case r: ContinuousReadSupport => r.createContinuousReaderFactory(scanConfig)
- case _ => throw new IllegalStateException("unknown read support: " + readSupport)
+ private lazy val readerFactory = scan match {
+ case scan: BatchScan => scan.createReaderFactory()
+ case scan: MicroBatchScan => scan.createReaderFactory()
+ case scan: ContinuousScan => scan.createContinuousReaderFactory()
+ case _ => throw new IllegalStateException("unknown scan: " + scan)
}
// TODO: clean this up when we have dedicated scan plan for continuous streaming.
@@ -83,8 +83,8 @@ case class DataSourceV2ScanExec(
partitions.exists(readerFactory.supportColumnarReads)
}
- private lazy val inputRDD: RDD[InternalRow] = readSupport match {
- case _: ContinuousReadSupport =>
+ private lazy val inputRDD: RDD[InternalRow] = scan match {
+ case _: ContinuousScan =>
assert(!supportsBatch,
"continuous stream reader does not support columnar read yet.")
EpochCoordinatorRef.get(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index 9a3109e7c199..42b448e80b8b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -17,17 +17,22 @@
package org.apache.spark.sql.execution.datasources.v2
+import scala.collection.JavaConverters._
import scala.collection.mutable
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{sources, Strategy}
-import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, AttributeSet, Expression}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, Repartition}
-import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}
+import org.apache.spark.sql.execution.{FilterExec, LeafExecNode, ProjectExec, SparkPlan}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
+import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, MicroBatchExecutionRelation, StreamingExecutionRelation}
import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec}
+import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.sources.v2.reader._
-import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReadSupport
+import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputStream, MicroBatchInputStream}
object DataSourceV2Strategy extends Strategy {
@@ -102,7 +107,8 @@ object DataSourceV2Strategy extends Strategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case PhysicalOperation(project, filters, relation: DataSourceV2Relation) =>
- val configBuilder = relation.readSupport.newScanConfigBuilder()
+ val dsOptions = new DataSourceOptions(relation.options.asJava)
+ val configBuilder = relation.table.newScanConfigBuilder(dsOptions)
// `pushedFilters` will be pushed down and evaluated in the underlying data sources.
// `postScanFilters` need to be evaluated after the scan.
// `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter.
@@ -121,8 +127,7 @@ object DataSourceV2Strategy extends Strategy {
relation.source,
relation.options,
pushedFilters,
- relation.readSupport,
- config)
+ relation.table.createBatchScan(config, dsOptions))
val filterCondition = postScanFilters.reduceLeftOption(And)
val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan)
@@ -130,13 +135,41 @@ object DataSourceV2Strategy extends Strategy {
// always add the projection, which will produce unsafe rows required by some operators
ProjectExec(project, withFilter) :: Nil
+ // Ideally `StreamingExecutionRelation`, `MicroBatchExecutionRelation` and
+ // `ContinuousExecutionRelation` would be temporary and we wouldn't need to handle them in
+ // strategy rules. However, the current streaming framework keeps a base logical plan instead
+ // of a physical plan, so we need to do a temporary round of query planning at the beginning
+ // to get the operator pushdown result. Here we catch these temporary logical plans and return
+ // fake physical plans that report the operator pushdown result.
+ case r: StreamingExecutionRelation =>
+ FakeStreamingScanExec(r.output) :: Nil
+
+ case r: MicroBatchExecutionRelation =>
+ val options = new DataSourceOptions(r.options.asJava)
+ val configBuilder = r.table.newScanConfigBuilder(options)
+ // TODO: operator pushdown
+ val config = configBuilder.build()
+ val stream = r.table.createMicroBatchInputStream(r.metadataPath, config, options)
+ FakeMicroBatchExec(r, stream, config.readSchema().toAttributes) :: Nil
+
+ case r: ContinuousExecutionRelation =>
+ val options = new DataSourceOptions(r.options.asJava)
+ val configBuilder = r.table.newScanConfigBuilder(options)
+ // TODO: operator pushdown
+ val config = configBuilder.build()
+ val stream = r.table.createContinuousInputStream(r.metadataPath, config, options)
+ FakeContinuousExec(r, stream, config.readSchema().toAttributes) :: Nil
+
case r: StreamingDataSourceV2Relation =>
- // TODO: support operator pushdown for streaming data sources.
- val scanConfig = r.scanConfigBuilder.build()
// ensure there is a projection, which will produce unsafe rows required by some operators
ProjectExec(r.output,
DataSourceV2ScanExec(
- r.output, r.source, r.options, r.pushedFilters, r.readSupport, scanConfig)) :: Nil
+ r.output,
+ r.source,
+ r.options,
+ r.pushedFilters,
+ r.createScan(),
+ Some(r.stream))) :: Nil
case WriteToDataSourceV2(writer, query) =>
WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil
@@ -149,7 +182,7 @@ object DataSourceV2Strategy extends Strategy {
case Repartition(1, false, child) =>
val isContinuous = child.find {
- case s: StreamingDataSourceV2Relation => s.readSupport.isInstanceOf[ContinuousReadSupport]
+ case s: StreamingDataSourceV2Relation => s.stream.isInstanceOf[ContinuousInputStream]
case _ => false
}.isDefined
@@ -162,3 +195,27 @@ object DataSourceV2Strategy extends Strategy {
case _ => Nil
}
}
+
+case class FakeStreamingScanExec(output: Seq[Attribute]) extends LeafExecNode {
+ override protected def doExecute(): RDD[InternalRow] = {
+ throw new IllegalStateException("cannot execute FakeStreamingScanExec")
+ }
+}
+
+case class FakeMicroBatchExec(
+ relation: MicroBatchExecutionRelation,
+ stream: MicroBatchInputStream,
+ output: Seq[Attribute]) extends LeafExecNode {
+ override protected def doExecute(): RDD[InternalRow] = {
+ throw new IllegalStateException("cannot execute FakeMicroBatchExec")
+ }
+}
+
+case class FakeContinuousExec(
+ relation: ContinuousExecutionRelation,
+ stream: ContinuousInputStream,
+ output: Seq[Attribute]) extends LeafExecNode {
+ override protected def doExecute(): RDD[InternalRow] = {
+ throw new IllegalStateException("cannot execute FakeContinuousExec")
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SimpleStreamingScanConfigBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/NoopScanConfigBuilder.scala
similarity index 62%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SimpleStreamingScanConfigBuilder.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/NoopScanConfigBuilder.scala
index 1be071614d92..56a5477d0c8f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SimpleStreamingScanConfigBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/NoopScanConfigBuilder.scala
@@ -15,26 +15,15 @@
* limitations under the License.
*/
-package org.apache.spark.sql.execution.streaming
+package org.apache.spark.sql.execution.datasources.v2
import org.apache.spark.sql.sources.v2.reader.{ScanConfig, ScanConfigBuilder}
import org.apache.spark.sql.types.StructType
-/**
- * A very simple [[ScanConfigBuilder]] implementation that creates a simple [[ScanConfig]] to
- * carry schema and offsets for streaming data sources.
- */
-class SimpleStreamingScanConfigBuilder(
- schema: StructType,
- start: Offset,
- end: Option[Offset] = None)
- extends ScanConfigBuilder {
-
- override def build(): ScanConfig = SimpleStreamingScanConfig(schema, start, end)
+class NoopScanConfigBuilder(schema: StructType) extends ScanConfigBuilder {
+ override def build(): ScanConfig = new NoopScanConfig(schema)
}
-case class SimpleStreamingScanConfig(
- readSchema: StructType,
- start: Offset,
- end: Option[Offset])
- extends ScanConfig
+class NoopScanConfig(schema: StructType) extends ScanConfig {
+ override def readSchema(): StructType = schema
+}
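A quick usage note: this class backs the default Table#newScanConfigBuilder, so a table that does not opt into pushdown simply reports its own schema as the read schema. For example, in a spark-shell session (illustrative only):

import org.apache.spark.sql.execution.datasources.v2.NoopScanConfigBuilder
import org.apache.spark.sql.types.{StringType, StructType}

val schema = new StructType().add("value", StringType)
val config = new NoopScanConfigBuilder(schema).build()
assert(config.readSchema() == schema)  // no pruning, no pushdown applied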
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
index 2cac86599ef1..a99da5bd81bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
@@ -17,18 +17,20 @@
package org.apache.spark.sql.execution.streaming
+import java.util.IdentityHashMap
+
import scala.collection.JavaConverters._
import scala.collection.mutable.{Map => MutableMap}
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.catalyst.encoders.RowEncoder
-import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp}
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp}
+import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.execution.SQLExecution
-import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2}
-import org.apache.spark.sql.execution.streaming.sources.{MicroBatchWritSupport, RateControlMicroBatchReadSupport}
+import org.apache.spark.sql.execution.datasources.v2.{FakeMicroBatchExec, StreamingDataSourceV2Relation, WriteToDataSourceV2}
+import org.apache.spark.sql.execution.streaming.sources.{MicroBatchWritSupport, RateControlMicroBatchInputStream}
import org.apache.spark.sql.sources.v2._
-import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset => OffsetV2}
+import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchInputStream, Offset => OffsetV2}
import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger}
import org.apache.spark.util.{Clock, Utils}
@@ -49,9 +51,6 @@ class MicroBatchExecution(
@volatile protected var sources: Seq[BaseStreamingSource] = Seq.empty
- private val readSupportToDataSourceMap =
- MutableMap.empty[MicroBatchReadSupport, (DataSourceV2, Map[String, String])]
-
private val triggerExecutor = trigger match {
case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock)
case OneTimeTrigger => OneTimeExecutor()
@@ -67,6 +66,7 @@ class MicroBatchExecution(
var nextSourceId = 0L
val toExecutionRelationMap = MutableMap[StreamingRelation, StreamingExecutionRelation]()
val v2ToExecutionRelationMap = MutableMap[StreamingRelationV2, StreamingExecutionRelation]()
+ val v2ToMicroBatchExecutionMap = MutableMap[StreamingRelationV2, MicroBatchExecutionRelation]()
// We transform each distinct streaming relation into a StreamingExecutionRelation, keeping a
// map as we go to ensure each identical relation gets the same StreamingExecutionRelation
// object. For each microbatch, the StreamingExecutionRelation will be replaced with a logical
@@ -89,21 +89,18 @@ class MicroBatchExecution(
StreamingExecutionRelation(source, output)(sparkSession)
})
case s @ StreamingRelationV2(
- dataSourceV2: MicroBatchReadSupportProvider, sourceName, options, output, _) if
- !disabledSources.contains(dataSourceV2.getClass.getCanonicalName) =>
- v2ToExecutionRelationMap.getOrElseUpdate(s, {
- // Materialize source to avoid creating it in every batch
+ sourceName, ds, table: SupportsMicroBatchRead, options, output, _)
+ if !disabledSources.contains(ds.getClass.getCanonicalName) =>
+ v2ToMicroBatchExecutionMap.getOrElseUpdate(s, {
val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId"
- val readSupport = dataSourceV2.createMicroBatchReadSupport(
- metadataPath,
- new DataSourceOptions(options.asJava))
nextSourceId += 1
- readSupportToDataSourceMap(readSupport) = dataSourceV2 -> options
- logInfo(s"Using MicroBatchReadSupport [$readSupport] from " +
- s"DataSourceV2 named '$sourceName' [$dataSourceV2]")
- StreamingExecutionRelation(readSupport, output)(sparkSession)
+ logInfo(s"Reading table [$table] from " +
+ s"DataSourceV2 named '$sourceName' [$ds]")
+ MicroBatchExecutionRelation(
+ sourceName, ds, table, output, metadataPath, options)(sparkSession)
})
- case s @ StreamingRelationV2(dataSourceV2, sourceName, _, output, v1Relation) =>
+ case s @ StreamingRelationV2(
+ sourceName, ds, _, _, output, v1Relation) =>
v2ToExecutionRelationMap.getOrElseUpdate(s, {
// Materialize source to avoid creating it in every batch
val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId"
@@ -113,13 +110,58 @@ class MicroBatchExecution(
}
val source = v1Relation.get.dataSource.createSource(metadataPath)
nextSourceId += 1
- logInfo(s"Using Source [$source] from DataSourceV2 named '$sourceName' [$dataSourceV2]")
+ logInfo(s"Using Source [$source] from DataSourceV2 named '$sourceName' [$ds]")
StreamingExecutionRelation(source, output)(sparkSession)
})
}
- sources = _logicalPlan.collect { case s: StreamingExecutionRelation => s.source }
+
+ // This is a temporary round of query planning, to get the operator pushdown result of v2 sources.
+ // TODO: update the streaming engine to do query planning only once.
+ val relationToStream = new IdentityHashMap[MicroBatchExecutionRelation, MicroBatchInputStream]
+ createExecution(_logicalPlan, sparkSession).sparkPlan.foreach {
+ case exec: FakeMicroBatchExec =>
+ if (relationToStream.containsKey(exec.relation)) {
+ // This is a self-union/self-join. Don't apply operator pushdown, since we want to keep
+ // one stream instance for the self-unioned/self-joined source.
+ // TODO: we can push down shared operators to the self-unioned/self-joined sources.
+ val options = new DataSourceOptions(exec.relation.options.asJava)
+ val configBuilder = exec.relation.table.newScanConfigBuilder(options)
+ val config = configBuilder.build()
+ val stream = exec.relation.table.createMicroBatchInputStream(
+ exec.relation.metadataPath, config, options)
+ relationToStream.put(exec.relation, stream)
+ } else {
+ relationToStream.put(exec.relation, exec.stream)
+ }
+
+ case _ =>
+ }
+
+ val finalPlan = _logicalPlan.transform {
+ case r: MicroBatchExecutionRelation =>
+ val stream = relationToStream.get(r)
+ assert(stream != null)
+ StreamingDataSourceV2Relation(r.output, r.ds, r.options, stream)
+ }
+
+ sources = finalPlan.collect {
+ case r: StreamingExecutionRelation => r.source
+ case r: StreamingDataSourceV2Relation => r.stream
+ }
uniqueSources = sources.distinct
- _logicalPlan
+
+ finalPlan
+ }
+
+ private def createExecution(plan: LogicalPlan, session: SparkSession): IncrementalExecution = {
+ new IncrementalExecution(
+ session,
+ plan,
+ outputMode,
+ checkpointFile("state"),
+ runId,
+ currentBatchId,
+ offsetSeqMetadata)
}
/**
@@ -341,7 +383,7 @@ class MicroBatchExecution(
reportTimeTaken("getOffset") {
(s, s.getOffset)
}
- case s: RateControlMicroBatchReadSupport =>
+ case s: RateControlMicroBatchInputStream =>
updateStatusMessage(s"Getting offsets from $s")
reportTimeTaken("latestOffset") {
val startOffset = availableOffsets
@@ -349,7 +391,7 @@ class MicroBatchExecution(
.getOrElse(s.initialOffset())
(s, Option(s.latestOffset(startOffset)))
}
- case s: MicroBatchReadSupport =>
+ case s: MicroBatchInputStream =>
updateStatusMessage(s"Getting offsets from $s")
reportTimeTaken("latestOffset") {
(s, Option(s.latestOffset()))
@@ -393,8 +435,8 @@ class MicroBatchExecution(
if (prevBatchOff.isDefined) {
prevBatchOff.get.toStreamProgress(sources).foreach {
case (src: Source, off) => src.commit(off)
- case (readSupport: MicroBatchReadSupport, off) =>
- readSupport.commit(readSupport.deserializeOffset(off.json))
+ case (stream: MicroBatchInputStream, off) =>
+ stream.commit(stream.deserializeOffset(off.json))
case (src, _) =>
throw new IllegalArgumentException(
s"Unknown source is found at constructNextBatch: $src")
@@ -439,39 +481,29 @@ class MicroBatchExecution(
logDebug(s"Retrieving data from $source: $current -> $available")
Some(source -> batch.logicalPlan)
- // TODO(cloud-fan): for data source v2, the new batch is just a new `ScanConfigBuilder`, but
- // to be compatible with streaming source v1, we return a logical plan as a new batch here.
- case (readSupport: MicroBatchReadSupport, available)
- if committedOffsets.get(readSupport).map(_ != available).getOrElse(true) =>
- val current = committedOffsets.get(readSupport).map {
- off => readSupport.deserializeOffset(off.json)
+ case (stream: MicroBatchInputStream, available)
+ if committedOffsets.get(stream).map(_ != available).getOrElse(true) =>
+ val current = committedOffsets.get(stream).map {
+ off => stream.deserializeOffset(off.json)
}
val endOffset: OffsetV2 = available match {
- case v1: SerializedOffset => readSupport.deserializeOffset(v1.json)
+ case v1: SerializedOffset => stream.deserializeOffset(v1.json)
case v2: OffsetV2 => v2
}
- val startOffset = current.getOrElse(readSupport.initialOffset)
- val scanConfigBuilder = readSupport.newScanConfigBuilder(startOffset, endOffset)
- logDebug(s"Retrieving data from $readSupport: $current -> $endOffset")
-
- val (source, options) = readSupport match {
- // `MemoryStream` is special. It's for test only and doesn't have a `DataSourceV2`
- // implementation. We provide a fake one here for explain.
- case _: MemoryStream[_] => MemoryStreamDataSource -> Map.empty[String, String]
- // Provide a fake value here just in case something went wrong, e.g. the reader gives
- // a wrong `equals` implementation.
- case _ => readSupportToDataSourceMap.getOrElse(readSupport, {
- FakeDataSourceV2 -> Map.empty[String, String]
- })
- }
- Some(readSupport -> StreamingDataSourceV2Relation(
- readSupport.fullSchema().toAttributes, source, options, readSupport, scanConfigBuilder))
+ val startOffset = current.getOrElse(stream.initialOffset)
+ logInfo(s"Retrieving data from $stream: $startOffset -> $endOffset")
+
+ // To be compatible with v1 sources, `newData` is represented as a logical plan, while
+ // for v2 sources `newData` is just the start and end offsets. Here we return a fake
+ // logical plan to carry the offsets.
+ Some(stream -> OffsetHolder(startOffset, endOffset))
case _ => None
}
}
// Replace sources in the logical plan with data that has arrived since the last batch.
val newBatchesPlan = logicalPlan transform {
+ // For v1 sources.
case StreamingExecutionRelation(source, output) =>
newData.get(source).map { dataPlan =>
assert(output.size == dataPlan.output.size,
@@ -485,6 +517,15 @@ class MicroBatchExecution(
}.getOrElse {
LocalRelation(output, isStreaming = true)
}
+
+ // For v2 sources.
+ case r: StreamingDataSourceV2Relation =>
+ newData.get(r.stream).map {
+ case OffsetHolder(start, end) =>
+ r.copy(startOffset = Some(start), endOffset = Some(end))
+ }.getOrElse {
+ LocalRelation(r.output, isStreaming = true)
+ }
}
// Rewire the plan to use the new attributes that were returned by the source.
@@ -497,7 +538,7 @@ class MicroBatchExecution(
cd.dataType, cd.timeZoneId)
}
- val triggerLogicalPlan = sink match {
+ val planWithSink = sink match {
case _: Sink => newAttributePlan
case s: StreamingWriteSupportProvider =>
val writer = s.createStreamingWriteSupport(
@@ -515,14 +556,7 @@ class MicroBatchExecution(
StreamExecution.IS_CONTINUOUS_PROCESSING, false.toString)
reportTimeTaken("queryPlanning") {
- lastExecution = new IncrementalExecution(
- sparkSessionToRunBatch,
- triggerLogicalPlan,
- outputMode,
- checkpointFile("state"),
- runId,
- currentBatchId,
- offsetSeqMetadata)
+ lastExecution = createExecution(planWithSink, sparkSessionToRunBatch)
lastExecution.executedPlan // Force the lazy generation of execution plan
}
@@ -563,6 +597,6 @@ object MicroBatchExecution {
val BATCH_ID_KEY = "streaming.sql.batchId"
}
-object MemoryStreamDataSource extends DataSourceV2
-
-object FakeDataSourceV2 extends DataSourceV2
+case class OffsetHolder(start: OffsetV2, end: OffsetV2) extends LeafNode {
+ override def output: Seq[Attribute] = Nil
+}
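
Aside: the MicroBatchExecution change above keys relationToStream by reference identity rather than value equality, so that a self-unioned or self-joined relation (the same plan node object appearing twice) resolves to a single stream instance. A minimal, self-contained sketch of that distinction, using a hypothetical Relation case class instead of Spark's plan nodes:

import java.util.IdentityHashMap

object IdentityDedupSketch {
  // Hypothetical stand-in for a streaming relation plan node.
  case class Relation(name: String)

  def main(args: Array[String]): Unit = {
    val r = Relation("kafka")         // one node object referenced twice, i.e. a self-union
    val lookalike = Relation("kafka") // equal by value, but a distinct object

    // A regular Scala Map uses the case class's value equality, so two distinct but equal
    // relations would share one entry. IdentityHashMap only collapses keys that are the
    // same object, which is exactly the self-union/self-join case handled above.
    val byIdentity = new IdentityHashMap[Relation, String]
    byIdentity.put(r, "stream-1")

    println(byIdentity.containsKey(r))          // true: same object, reuse the stream
    println(byIdentity.containsKey(lookalike))  // false: distinct object, separate source

    val byValue = Map(r -> "stream-1")
    println(byValue.contains(lookalike))        // true: value equality would conflate them
  }
}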
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
index 392229bcb5f5..78d50e6111e1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
@@ -28,8 +28,8 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.QueryExecution
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec
-import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReadSupport
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, StreamingDataSourceV2Relation}
+import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchInputStream
import org.apache.spark.sql.streaming._
import org.apache.spark.sql.streaming.StreamingQueryListener.QueryProgressEvent
import org.apache.spark.util.Clock
@@ -245,10 +245,12 @@ trait ProgressReporter extends Logging {
}
val onlyDataSourceV2Sources = {
- // Check whether the streaming query's logical plan has only V2 data sources
- val allStreamingLeaves =
- logicalPlan.collect { case s: StreamingExecutionRelation => s }
- allStreamingLeaves.forall { _.source.isInstanceOf[MicroBatchReadSupport] }
+ // Check whether the streaming query's logical plan has only V2 micro-batch data sources
+ val allStreamingLeaves = logicalPlan.collect {
+ case s: StreamingDataSourceV2Relation => s.stream.isInstanceOf[MicroBatchInputStream]
+ case _: StreamingExecutionRelation => false
+ }
+ allStreamingLeaves.forall(_ == true)
}
if (onlyDataSourceV2Sources) {
@@ -256,9 +258,9 @@ trait ProgressReporter extends Logging {
// (can happen with self-unions or self-joins). This means the source is scanned multiple
// times in the query, we should count the numRows for each scan.
val sourceToInputRowsTuples = lastExecution.executedPlan.collect {
- case s: DataSourceV2ScanExec if s.readSupport.isInstanceOf[BaseStreamingSource] =>
+ case s: DataSourceV2ScanExec if s.stream.isDefined =>
val numRows = s.metrics.get("numOutputRows").map(_.value).getOrElse(0L)
- val source = s.readSupport.asInstanceOf[BaseStreamingSource]
+ val source = s.stream.get
source -> numRows
}
logDebug("Source -> # input rows\n\t" + sourceToInputRowsTuples.mkString("\n\t"))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
index 4b696dfa5735..b603f3d057b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
@@ -25,7 +25,8 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.execution.LeafExecNode
import org.apache.spark.sql.execution.datasources.DataSource
-import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceV2}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.sources.v2._
object StreamingRelation {
def apply(dataSource: DataSource): StreamingRelation = {
@@ -81,6 +82,54 @@ case class StreamingExecutionRelation(
override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance()))(session)
}
+case class MicroBatchExecutionRelation(
+ source: String,
+ ds: DataSourceV2,
+ table: SupportsMicroBatchRead,
+ output: Seq[Attribute],
+ metadataPath: String,
+ options: Map[String, String])(session: SparkSession)
+ extends LeafNode with MultiInstanceRelation {
+
+ override def otherCopyArgs: Seq[AnyRef] = session :: Nil
+ override def isStreaming: Boolean = true
+ override def toString: String = source
+
+ // There's no sensible value here. On the execution path, this relation will be swapped out with
+ // `StreamingDataSourceV2Relation`. But some dataframe operations (in particular explain) do lead
+ // to this node surviving analysis. So we satisfy the LeafNode contract with the session default
+ // value.
+ override def computeStats(): Statistics = Statistics(
+ sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes)
+ )
+
+ override def newInstance(): LogicalPlan = copy(output = output.map(_.newInstance()))(session)
+}
+
+case class ContinuousExecutionRelation(
+ source: String,
+ ds: DataSourceV2,
+ table: SupportsContinuousRead,
+ output: Seq[Attribute],
+ metadataPath: String,
+ options: Map[String, String])(session: SparkSession)
+ extends LeafNode with MultiInstanceRelation {
+
+ override def otherCopyArgs: Seq[AnyRef] = session :: Nil
+ override def isStreaming: Boolean = true
+ override def toString: String = source
+
+ // There's no sensible value here. On the execution path, this relation will be swapped out with
+ // `StreamingDataSourceV2Relation`. But some dataframe operations (in particular explain) do lead
+ // to this node surviving analysis. So we satisfy the LeafNode contract with the session default
+ // value.
+ override def computeStats(): Statistics = Statistics(
+ sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes)
+ )
+
+ override def newInstance(): LogicalPlan = copy(output = output.map(_.newInstance()))(session)
+}
+
// We have to pack in the V1 data source as a shim, for the case when a source implements
// continuous processing (which is always V2) but only has V1 microbatch support. We don't
// know at read time whether the query is continuous or not, so we need to be able to
@@ -92,8 +141,9 @@ case class StreamingExecutionRelation(
* and should be converted before passing to [[StreamExecution]].
*/
case class StreamingRelationV2(
- dataSource: DataSourceV2,
sourceName: String,
+ dataSource: DataSourceV2,
+ table: Table,
extraOptions: Map[String, String],
output: Seq[Attribute],
v1Relation: Option[StreamingRelation])(session: SparkSession)
@@ -109,30 +159,6 @@ case class StreamingRelationV2(
override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance()))(session)
}
-/**
- * Used to link a [[DataSourceV2]] into a continuous processing execution.
- */
-case class ContinuousExecutionRelation(
- source: ContinuousReadSupportProvider,
- extraOptions: Map[String, String],
- output: Seq[Attribute])(session: SparkSession)
- extends LeafNode with MultiInstanceRelation {
-
- override def otherCopyArgs: Seq[AnyRef] = session :: Nil
- override def isStreaming: Boolean = true
- override def toString: String = source.toString
-
- // There's no sensible value here. On the execution path, this relation will be
- // swapped out with microbatches. But some dataframe operations (in particular explain) do lead
- // to this node surviving analysis. So we satisfy the LeafNode contract with the session default
- // value.
- override def computeStats(): Statistics = Statistics(
- sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes)
- )
-
- override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance()))(session)
-}
-
/**
* A dummy physical plan for [[StreamingRelation]] to support
* [[org.apache.spark.sql.Dataset.explain]]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
index f009c52449ad..edeb189886c8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
@@ -17,25 +17,26 @@
package org.apache.spark.sql.execution.streaming.continuous
+import java.util.IdentityHashMap
import java.util.UUID
import java.util.concurrent.TimeUnit
import java.util.function.UnaryOperator
import scala.collection.JavaConverters._
-import scala.collection.mutable.{ArrayBuffer, Map => MutableMap}
+import scala.collection.mutable.{Map => MutableMap}
import org.apache.spark.SparkEnv
import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentDate, CurrentTimestamp}
+import org.apache.spark.sql.catalyst.expressions.{CurrentDate, CurrentTimestamp}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SQLExecution
-import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, StreamingDataSourceV2Relation}
-import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _}
+import org.apache.spark.sql.execution.datasources.v2.{FakeContinuousExec, StreamingDataSourceV2Relation}
+import org.apache.spark.sql.execution.streaming.{StreamingRelationV2, _}
import org.apache.spark.sql.sources.v2
-import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, StreamingWriteSupportProvider}
-import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, PartitionOffset}
+import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamingWriteSupportProvider, SupportsContinuousRead}
+import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputStream, InputStream, PartitionOffset}
import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger}
-import org.apache.spark.util.{Clock, Utils}
+import org.apache.spark.util.Clock
class ContinuousExecution(
sparkSession: SparkSession,
@@ -52,25 +53,74 @@ class ContinuousExecution(
sparkSession, name, checkpointRoot, analyzedPlan, sink,
trigger, triggerClock, outputMode, deleteCheckpointOnStop) {
- @volatile protected var continuousSources: Seq[ContinuousReadSupport] = Seq()
- override protected def sources: Seq[BaseStreamingSource] = continuousSources
+ @volatile protected var sources: Seq[InputStream] = Seq.empty
// For use only in test harnesses.
private[sql] var currentEpochCoordinatorId: String = _
override val logicalPlan: LogicalPlan = {
- val toExecutionRelationMap = MutableMap[StreamingRelationV2, ContinuousExecutionRelation]()
- analyzedPlan.transform {
- case r @ StreamingRelationV2(
- source: ContinuousReadSupportProvider, _, extraReaderOptions, output, _) =>
- // TODO: shall we create `ContinuousReadSupport` here instead of each reconfiguration?
- toExecutionRelationMap.getOrElseUpdate(r, {
- ContinuousExecutionRelation(source, extraReaderOptions, output)(sparkSession)
+ val v2ToContinuousExecutionMap = MutableMap[StreamingRelationV2, ContinuousExecutionRelation]()
+ var nextSourceId = 0
+ val _logicalPlan = analyzedPlan.transform {
+ case s @ StreamingRelationV2(
+ sourceName, ds, table: SupportsContinuousRead, options, output, _) =>
+ v2ToContinuousExecutionMap.getOrElseUpdate(s, {
+ val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId"
+ nextSourceId += 1
+ ContinuousExecutionRelation(
+ sourceName, ds, table, output, metadataPath, options)(sparkSession)
})
- case StreamingRelationV2(_, sourceName, _, _, _) =>
+ case r: StreamingRelationV2 =>
throw new UnsupportedOperationException(
- s"Data source $sourceName does not support continuous processing.")
+ s"Data source ${r.sourceName} does not support continuous processing.")
+ }
+
+ // This is a temporary round of query planning, to get the operator pushdown result of v2 sources.
+ // TODO: update the streaming engine to do query planning only once.
+ val relationToStream = new IdentityHashMap[ContinuousExecutionRelation, ContinuousInputStream]
+ createExecution(_logicalPlan, sparkSession).sparkPlan.foreach {
+ case exec: FakeContinuousExec =>
+ if (relationToStream.containsKey(exec.relation)) {
+ // This is a self-union/self-join. Don't apply operator pushdown, since we want to keep
+ // one stream instance for the self-unioned/self-joined source.
+ // TODO: we can push down shared operators to the self-unioned/self-joined sources.
+ val options = new DataSourceOptions(exec.relation.options.asJava)
+ val configBuilder = exec.relation.table.newScanConfigBuilder(options)
+ val config = configBuilder.build()
+ val stream = exec.relation.table.createContinuousInputStream(
+ exec.relation.metadataPath, config, options)
+ relationToStream.put(exec.relation, stream)
+ } else {
+ relationToStream.put(exec.relation, exec.stream)
+ }
+
+ case _ =>
}
+
+ val finalPlan = _logicalPlan.transform {
+ case r: ContinuousExecutionRelation =>
+ val stream = relationToStream.get(r)
+ assert(stream != null)
+ StreamingDataSourceV2Relation(r.output, r.ds, r.options, stream)
+ }
+
+ sources = finalPlan.collect {
+ case r: StreamingDataSourceV2Relation => r.stream
+ }
+ uniqueSources = sources.distinct
+
+ finalPlan
+ }
+
+ private def createExecution(plan: LogicalPlan, session: SparkSession): IncrementalExecution = {
+ new IncrementalExecution(
+ session,
+ plan,
+ outputMode,
+ checkpointFile("state"),
+ runId,
+ currentBatchId,
+ offsetSeqMetadata)
}
private val triggerExecutor = trigger match {
@@ -90,6 +140,8 @@ class ContinuousExecution(
do {
runContinuous(sparkSessionForStream)
} while (state.updateAndGet(stateUpdate) == ACTIVE)
+
+ stopSources()
}
/**
@@ -130,7 +182,7 @@ class ContinuousExecution(
// We are starting this stream for the first time. Offsets are all None.
logInfo(s"Starting new streaming query.")
currentBatchId = 0
- OffsetSeq.fill(continuousSources.map(_ => null): _*)
+ OffsetSeq.fill(sources.map(_ => null): _*)
}
}
@@ -139,47 +191,17 @@ class ContinuousExecution(
* @param sparkSessionForQuery Isolated [[SparkSession]] to run the continuous query with.
*/
private def runContinuous(sparkSessionForQuery: SparkSession): Unit = {
- // A list of attributes that will need to be updated.
- val replacements = new ArrayBuffer[(Attribute, Attribute)]
- // Translate from continuous relation to the underlying data source.
- var nextSourceId = 0
- continuousSources = logicalPlan.collect {
- case ContinuousExecutionRelation(dataSource, extraReaderOptions, output) =>
- val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId"
- nextSourceId += 1
-
- dataSource.createContinuousReadSupport(
- metadataPath,
- new DataSourceOptions(extraReaderOptions.asJava))
- }
- uniqueSources = continuousSources.distinct
-
val offsets = getStartOffsets(sparkSessionForQuery)
- var insertedSourceId = 0
- val withNewSources = logicalPlan transform {
- case ContinuousExecutionRelation(source, options, output) =>
- val readSupport = continuousSources(insertedSourceId)
- insertedSourceId += 1
- val newOutput = readSupport.fullSchema().toAttributes
-
- assert(output.size == newOutput.size,
- s"Invalid reader: ${Utils.truncatedString(output, ",")} != " +
- s"${Utils.truncatedString(newOutput, ",")}")
- replacements ++= output.zip(newOutput)
-
+ val withNewSources: LogicalPlan = logicalPlan transform {
+ case relation: StreamingDataSourceV2Relation =>
val loggedOffset = offsets.offsets(0)
- val realOffset = loggedOffset.map(off => readSupport.deserializeOffset(off.json))
- val startOffset = realOffset.getOrElse(readSupport.initialOffset)
- val scanConfigBuilder = readSupport.newScanConfigBuilder(startOffset)
- StreamingDataSourceV2Relation(newOutput, source, options, readSupport, scanConfigBuilder)
+ val realOffset = loggedOffset.map(off => relation.stream.deserializeOffset(off.json))
+ val startOffset = realOffset.getOrElse(relation.stream.initialOffset)
+ relation.copy(startOffset = Some(startOffset))
}
- // Rewire the plan to use the new attributes that were returned by the source.
- val replacementMap = AttributeMap(replacements)
- val triggerLogicalPlan = withNewSources transformAllExpressions {
- case a: Attribute if replacementMap.contains(a) =>
- replacementMap(a).withMetadata(a.metadata)
+ withNewSources transformAllExpressions {
case (_: CurrentTimestamp | _: CurrentDate) =>
throw new IllegalStateException(
"CurrentTimestamp and CurrentDate not yet supported for continuous processing")
@@ -187,26 +209,19 @@ class ContinuousExecution(
val writer = sink.createStreamingWriteSupport(
s"$runId",
- triggerLogicalPlan.schema,
+ withNewSources.schema,
outputMode,
new DataSourceOptions(extraOptions.asJava))
- val withSink = WriteToContinuousDataSource(writer, triggerLogicalPlan)
+ val planWithSink = WriteToContinuousDataSource(writer, withNewSources)
reportTimeTaken("queryPlanning") {
- lastExecution = new IncrementalExecution(
- sparkSessionForQuery,
- withSink,
- outputMode,
- checkpointFile("state"),
- runId,
- currentBatchId,
- offsetSeqMetadata)
+ lastExecution = createExecution(planWithSink, sparkSessionForQuery)
lastExecution.executedPlan // Force the lazy generation of execution plan
}
- val (readSupport, scanConfig) = lastExecution.executedPlan.collect {
- case scan: DataSourceV2ScanExec if scan.readSupport.isInstanceOf[ContinuousReadSupport] =>
- scan.readSupport.asInstanceOf[ContinuousReadSupport] -> scan.scanConfig
+ val stream = planWithSink.collect {
+ case relation: StreamingDataSourceV2Relation =>
+ relation.stream.asInstanceOf[ContinuousInputStream]
}.head
sparkSessionForQuery.sparkContext.setLocalProperty(
@@ -226,16 +241,14 @@ class ContinuousExecution(
// Use the parent Spark session for the endpoint since it's where this query ID is registered.
val epochEndpoint =
EpochCoordinatorRef.create(
- writer, readSupport, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get)
+ writer, stream, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get)
val epochUpdateThread = new Thread(new Runnable {
override def run: Unit = {
try {
triggerExecutor.execute(() => {
startTrigger()
- val shouldReconfigure = readSupport.needsReconfiguration(scanConfig) &&
- state.compareAndSet(ACTIVE, RECONFIGURING)
- if (shouldReconfigure) {
+ if (stream.needsReconfiguration && state.compareAndSet(ACTIVE, RECONFIGURING)) {
if (queryExecutionThread.isAlive) {
queryExecutionThread.interrupt()
}
@@ -276,7 +289,6 @@ class ContinuousExecution(
epochUpdateThread.interrupt()
epochUpdateThread.join()
- stopSources()
sparkSession.sparkContext.cancelJobGroup(runId.toString)
}
}
@@ -286,11 +298,11 @@ class ContinuousExecution(
*/
def addOffset(
epoch: Long,
- readSupport: ContinuousReadSupport,
+ stream: ContinuousInputStream,
partitionOffsets: Seq[PartitionOffset]): Unit = {
- assert(continuousSources.length == 1, "only one continuous source supported currently")
+ assert(sources.length == 1, "only one continuous source supported currently")
- val globalOffset = readSupport.mergeOffsets(partitionOffsets.toArray)
+ val globalOffset = stream.mergeOffsets(partitionOffsets.toArray)
val oldOffset = synchronized {
offsetLog.add(epoch, OffsetSeq.fill(globalOffset))
offsetLog.get(epoch - 1)
@@ -314,7 +326,7 @@ class ContinuousExecution(
* before this is called.
*/
def commit(epoch: Long): Unit = {
- assert(continuousSources.length == 1, "only one continuous source supported currently")
+ assert(sources.length == 1, "only one continuous source supported currently")
assert(offsetLog.get(epoch).isDefined, s"offset for epoch $epoch not reported before commit")
synchronized {
@@ -323,9 +335,9 @@ class ContinuousExecution(
if (queryExecutionThread.isAlive) {
commitLog.add(epoch, CommitMetadata())
val offset =
- continuousSources(0).deserializeOffset(offsetLog.get(epoch).get.offsets(0).get.json)
- committedOffsets ++= Seq(continuousSources(0) -> offset)
- continuousSources(0).commit(offset.asInstanceOf[v2.reader.streaming.Offset])
+ sources(0).deserializeOffset(offsetLog.get(epoch).get.offsets(0).get.json)
+ committedOffsets ++= Seq(sources(0) -> offset)
+ sources(0).commit(offset.asInstanceOf[v2.reader.streaming.Offset])
} else {
return
}
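
Aside: commit(epoch) above recovers the offset for an epoch by deserializing the JSON that was written to the offset log and handing it back to the stream. A minimal sketch of that round trip, with hypothetical LongOffset and Stream stand-ins rather than the real v2 API:

object OffsetRoundTripSketch {
  // Hypothetical stand-ins for a v2 offset and the part of an input stream used by commit().
  case class LongOffset(value: Long) { def json: String = value.toString }

  class Stream {
    def deserializeOffset(json: String): LongOffset = LongOffset(json.toLong)
    def commit(end: LongOffset): Unit = println(s"committed up to ${end.value}")
  }

  def main(args: Array[String]): Unit = {
    val stream = new Stream
    val logged = LongOffset(42).json             // what the offset log persists for the epoch
    val recovered = stream.deserializeOffset(logged)
    stream.commit(recovered)                     // prints "committed up to 42"
  }
}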
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
index a6cde2b8a710..3b6201049d4a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
@@ -22,17 +22,16 @@ import org.json4s.jackson.Serialization
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.streaming.{RateStreamOffset, SimpleStreamingScanConfig, SimpleStreamingScanConfigBuilder, ValueRunTimeMsPair}
+import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair}
import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider
import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming._
-import org.apache.spark.sql.types.StructType
case class RateStreamPartitionOffset(
partition: Int, currentValue: Long, currentTimeMs: Long) extends PartitionOffset
-class RateStreamContinuousReadSupport(options: DataSourceOptions) extends ContinuousReadSupport {
+class RateStreamContinuousInputStream(options: DataSourceOptions) extends ContinuousInputStream {
implicit val defaultFormats: DefaultFormats = DefaultFormats
val creationTime = System.currentTimeMillis()
@@ -54,18 +53,36 @@ class RateStreamContinuousReadSupport(options: DataSourceOptions) extends Contin
RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json))
}
- override def fullSchema(): StructType = RateStreamProvider.SCHEMA
+ override def initialOffset: Offset = createInitialOffset(numPartitions, creationTime)
- override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = {
- new SimpleStreamingScanConfigBuilder(fullSchema(), start)
+ override def createContinuousScan(start: Offset): ContinuousScan = {
+ new RateStreamContinuousScan(numPartitions, perPartitionRate, start)
}
- override def initialOffset: Offset = createInitialOffset(numPartitions, creationTime)
+ override def commit(end: Offset): Unit = {}
+ override def stop(): Unit = {}
- override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
- val startOffset = config.asInstanceOf[SimpleStreamingScanConfig].start
+ private def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = {
+ RateStreamOffset(Range(0, numPartitions).map { i =>
+ // Note that the starting offset is exclusive, so we have to decrement the starting value by
+ // the increment that will later be applied. The first row output in each partition will have
+ // a value equal to the partition index.
+ (i, ValueRunTimeMsPair((i - numPartitions).toLong, creationTimeMs))
+ }.toMap)
+ }
+}
+
+class RateStreamContinuousScan(
+ numPartitions: Int,
+ perPartitionRate: Double,
+ start: Offset) extends ContinuousScan {
+
+ override def createContinuousReaderFactory(): ContinuousPartitionReaderFactory = {
+ RateStreamContinuousReaderFactory
+ }
- val partitionStartMap = startOffset match {
+ override def planInputPartitions(): Array[InputPartition] = {
+ val partitionStartMap = start match {
case off: RateStreamOffset => off.partitionToValueAndRunTimeMs
case off =>
throw new IllegalArgumentException(
@@ -74,8 +91,8 @@ class RateStreamContinuousReadSupport(options: DataSourceOptions) extends Contin
if (partitionStartMap.keySet.size != numPartitions) {
throw new IllegalArgumentException(
s"The previous run contained ${partitionStartMap.keySet.size} partitions, but" +
- s" $numPartitions partitions are currently configured. The numPartitions option" +
- " cannot be changed.")
+ s" $numPartitions partitions are currently configured. The numPartitions option" +
+ " cannot be changed.")
}
Range(0, numPartitions).map { i =>
@@ -90,28 +107,6 @@ class RateStreamContinuousReadSupport(options: DataSourceOptions) extends Contin
perPartitionRate)
}.toArray
}
-
- override def createContinuousReaderFactory(
- config: ScanConfig): ContinuousPartitionReaderFactory = {
- RateStreamContinuousReaderFactory
- }
-
- override def commit(end: Offset): Unit = {}
- override def stop(): Unit = {}
-
- private def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = {
- RateStreamOffset(
- Range(0, numPartitions).map { i =>
- // Note that the starting offset is exclusive, so we have to decrement the starting value
- // by the increment that will later be applied. The first row output in each
- // partition will have a value equal to the partition index.
- (i,
- ValueRunTimeMsPair(
- (i - numPartitions).toLong,
- creationTimeMs))
- }.toMap)
- }
-
}
case class RateStreamContinuousInputPartition(
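
Aside: a worked example of the exclusive-start arithmetic described in the createInitialOffset comment above. With numPartitions = 3, partition i starts at i - 3, so after one increment of numPartitions the first emitted value equals the partition index (a standalone sketch, not the Spark code):

object RateStreamInitialOffsetSketch {
  def main(args: Array[String]): Unit = {
    val numPartitions = 3
    // Exclusive starting values, one per partition: 0 -> -3, 1 -> -2, 2 -> -1.
    val initial = (0 until numPartitions).map(i => i -> (i - numPartitions).toLong).toMap
    println(initial)
    // Each partition advances in steps of numPartitions, so the first emitted value is
    // start + numPartitions, i.e. the partition index: 0 -> 0, 1 -> 1, 2 -> 2.
    println(initial.map { case (p, start) => p -> (start + numPartitions) })
  }
}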
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala
index 28ab2448a663..38b66a172e5a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala
@@ -38,20 +38,18 @@ import org.apache.spark.sql.execution.streaming.sources.TextSocketReader
import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming._
-import org.apache.spark.sql.types.StructType
import org.apache.spark.util.RpcUtils
-
/**
- * A ContinuousReadSupport that reads text lines through a TCP socket, designed only for tutorials
- * and debugging. This ContinuousReadSupport will *not* work in production applications due to
+ * A ContinuousInputStream that reads text lines through a TCP socket, designed only for tutorials
+ * and debugging. This ContinuousInputStream will *not* work in production applications due to
* multiple reasons, including no support for fault recovery.
*
* The driver maintains a socket connection to the host-port, keeps the received messages in
* buckets and serves the messages to the executors via a RPC endpoint.
*/
-class TextSocketContinuousReadSupport(options: DataSourceOptions)
- extends ContinuousReadSupport with Logging {
+class TextSocketContinuousInputStream(options: DataSourceOptions)
+ extends ContinuousInputStream with Logging {
implicit val defaultFormats: DefaultFormats = DefaultFormats
@@ -60,7 +58,7 @@ class TextSocketContinuousReadSupport(options: DataSourceOptions)
assert(SparkSession.getActiveSession.isDefined)
private val spark = SparkSession.getActiveSession.get
- private val numPartitions = spark.sparkContext.defaultParallelism
+ private val numPartitions: Int = spark.sparkContext.defaultParallelism
@GuardedBy("this")
private var socket: Socket = _
@@ -101,21 +99,8 @@ class TextSocketContinuousReadSupport(options: DataSourceOptions)
startOffset
}
- override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = {
- new SimpleStreamingScanConfigBuilder(fullSchema(), start)
- }
-
- override def fullSchema(): StructType = {
- if (includeTimestamp) {
- TextSocketReader.SCHEMA_TIMESTAMP
- } else {
- TextSocketReader.SCHEMA_REGULAR
- }
- }
-
- override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
- val startOffset = config.asInstanceOf[SimpleStreamingScanConfig]
- .start.asInstanceOf[TextSocketOffset]
+ override def createContinuousScan(start: Offset): ContinuousScan = {
+ val startOffset = start.asInstanceOf[TextSocketOffset]
recordEndpoint.setStartOffsets(startOffset.offsets)
val endpointName = s"TextSocketContinuousReaderEndpoint-${java.util.UUID.randomUUID()}"
endpointRef = recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint)
@@ -134,15 +119,12 @@ class TextSocketContinuousReadSupport(options: DataSourceOptions)
" cannot be changed.")
}
- startOffset.offsets.zipWithIndex.map {
+ val partitions: Array[InputPartition] = startOffset.offsets.zipWithIndex.map {
case (offset, i) =>
TextSocketContinuousInputPartition(endpointName, i, offset, includeTimestamp)
}.toArray
- }
- override def createContinuousReaderFactory(
- config: ScanConfig): ContinuousPartitionReaderFactory = {
- TextSocketReaderFactory
+ new TextSocketContinuousScan(partitions)
}
override def commit(end: Offset): Unit = synchronized {
@@ -157,7 +139,7 @@ class TextSocketContinuousReadSupport(options: DataSourceOptions)
val max = startOffset.offsets(partition) + buckets(partition).size
if (offset > max) {
throw new IllegalStateException("Invalid offset " + offset + " to commit" +
- " for partition " + partition + ". Max valid offset: " + max)
+ " for partition " + partition + ". Max valid offset: " + max)
}
val n = offset - startOffset.offsets(partition)
buckets(partition).trimStart(n)
@@ -197,7 +179,7 @@ class TextSocketContinuousReadSupport(options: DataSourceOptions)
logWarning(s"Stream closed by $host:$port")
return
}
- TextSocketContinuousReadSupport.this.synchronized {
+ TextSocketContinuousInputStream.this.synchronized {
currentOffset += 1
val newData = (line,
Timestamp.valueOf(
@@ -218,9 +200,20 @@ class TextSocketContinuousReadSupport(options: DataSourceOptions)
override def toString: String = s"TextSocketContinuousReader[host: $host, port: $port]"
private def includeTimestamp: Boolean = options.getBoolean("includeTimestamp", false)
+}
+
+class TextSocketContinuousScan(partitions: Array[InputPartition]) extends ContinuousScan {
+
+ override def createContinuousReaderFactory(): ContinuousPartitionReaderFactory = {
+ TextSocketContinuousReaderFactory
+ }
+ override def planInputPartitions(): Array[InputPartition] = {
+ partitions
+ }
}
+
/**
* Continuous text socket input partition.
*/
@@ -231,7 +224,7 @@ case class TextSocketContinuousInputPartition(
includeTimestamp: Boolean) extends InputPartition
-object TextSocketReaderFactory extends ContinuousPartitionReaderFactory {
+object TextSocketContinuousReaderFactory extends ContinuousPartitionReaderFactory {
override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = {
val p = partition.asInstanceOf[TextSocketContinuousInputPartition]
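
Aside: the commit path above trims each in-memory bucket by the number of records between the partition's start offset and the committed offset, after validating the offset against the maximum buffered position. A standalone sketch of that bookkeeping with a plain ListBuffer (the real code also advances the stored start offset):

import scala.collection.mutable.ListBuffer

object CommitTrimSketch {
  def main(args: Array[String]): Unit = {
    val startOffset = 10L                     // offset of the first buffered record
    val bucket = ListBuffer("a", "b", "c", "d")
    val committed = 12L                       // everything before this offset is consumed

    val max = startOffset + bucket.size
    require(committed <= max, s"Invalid offset $committed to commit. Max valid offset: $max")

    bucket.trimStart((committed - startOffset).toInt)
    println(bucket)                           // ListBuffer(c, d)
  }
}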
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
index 2238ce26e7b4..e4ceeef28ca1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
@@ -23,7 +23,7 @@ import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, PartitionOffset}
+import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputStream, PartitionOffset}
import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage
import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport
import org.apache.spark.util.RpcUtils
@@ -83,14 +83,14 @@ private[sql] object EpochCoordinatorRef extends Logging {
*/
def create(
writeSupport: StreamingWriteSupport,
- readSupport: ContinuousReadSupport,
+ inputStream: ContinuousInputStream,
query: ContinuousExecution,
epochCoordinatorId: String,
startEpoch: Long,
session: SparkSession,
env: SparkEnv): RpcEndpointRef = synchronized {
val coordinator = new EpochCoordinator(
- writeSupport, readSupport, query, startEpoch, session, env.rpcEnv)
+ writeSupport, inputStream, query, startEpoch, session, env.rpcEnv)
val ref = env.rpcEnv.setupEndpoint(endpointName(epochCoordinatorId), coordinator)
logInfo("Registered EpochCoordinator endpoint")
ref
@@ -116,7 +116,7 @@ private[sql] object EpochCoordinatorRef extends Logging {
*/
private[continuous] class EpochCoordinator(
writeSupport: StreamingWriteSupport,
- readSupport: ContinuousReadSupport,
+ inputStream: ContinuousInputStream,
query: ContinuousExecution,
startEpoch: Long,
session: SparkSession,
@@ -220,7 +220,7 @@ private[continuous] class EpochCoordinator(
partitionOffsets.collect { case ((e, _), o) if e == epoch => o }
if (thisEpochOffsets.size == numReaderPartitions) {
logDebug(s"Epoch $epoch has offsets reported from all partitions: $thisEpochOffsets")
- query.addOffset(epoch, readSupport, thisEpochOffsets.toSeq)
+ query.addOffset(epoch, inputStream, thisEpochOffsets.toSeq)
resolveCommitsAtEpoch(epoch)
}
}
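
Aside: the EpochCoordinator logic above keeps partition offsets keyed by (epoch, partition) and only reports the epoch's offsets to the query once every reader partition has checked in. A minimal sketch of that bookkeeping with plain Scala collections (hypothetical names, not the RPC endpoint):

object EpochOffsetSketch {
  def main(args: Array[String]): Unit = {
    val numReaderPartitions = 2
    // (epoch, partition) -> reported offset, mirroring the coordinator's bookkeeping.
    var partitionOffsets = Map.empty[(Long, Int), Long]

    def report(epoch: Long, partition: Int, offset: Long): Unit = {
      partitionOffsets = partitionOffsets.updated((epoch, partition), offset)
      val thisEpoch = partitionOffsets.collect { case ((e, _), o) if e == epoch => o }
      if (thisEpoch.size == numReaderPartitions) {
        println(s"epoch $epoch has offsets from all partitions: ${thisEpoch.toSeq.sorted}")
      }
    }

    report(0, 0, 5L)   // nothing reported yet: only one partition has checked in
    report(0, 1, 7L)   // prints: epoch 0 has offsets from all partitions: List(5, 7)
  }
}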
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index adf52aba21a0..03549e72b625 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -28,11 +28,12 @@ import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
-import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
+import org.apache.spark.sql.sources.v2._
import org.apache.spark.sql.sources.v2.reader._
-import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset => OffsetV2}
+import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputStream, InputStream, MicroBatchInputStream, MicroBatchScan, Offset => OffsetV2}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils
@@ -45,11 +46,17 @@ object MemoryStream {
new MemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext)
}
+// This object is used to identify the memory stream data source. We don't actually use it, as
+// the memory stream is for tests only and we never look it up by name.
+object MemoryStreamSource extends DataSourceV2
+
/**
* A base class for memory stream implementations. Supports adding data and resetting.
*/
-abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends BaseStreamingSource {
- protected val encoder = encoderFor[A]
+abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext)
+ extends BaseStreamingSource with InputStream {
+
+ val encoder = encoderFor[A]
protected val attributes = encoder.schema.toAttributes
def toDS(): Dataset[A] = {
@@ -64,24 +71,44 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Bas
addData(data.toTraversable)
}
- def fullSchema(): StructType = encoder.schema
-
- protected def logicalPlan: LogicalPlan
+ protected val logicalPlan = StreamingRelationV2(
+ "memory",
+ MemoryStreamSource,
+ new MemoryStreamTable(this),
+ Map.empty,
+ attributes,
+ None)(sqlContext.sparkSession)
def addData(data: TraversableOnce[A]): Offset
}
+class MemoryStreamTable(val stream: MemoryStreamBase[_]) extends Table
+ with SupportsMicroBatchRead with SupportsContinuousRead {
+
+ override def schema(): StructType = stream.encoder.schema
+
+ override def createMicroBatchInputStream(
+ checkpointLocation: String,
+ config: ScanConfig,
+ options: DataSourceOptions): MicroBatchInputStream = {
+ stream.asInstanceOf[MicroBatchInputStream]
+ }
+
+ override def createContinuousInputStream(
+ checkpointLocation: String,
+ config: ScanConfig,
+ options: DataSourceOptions): ContinuousInputStream = {
+ stream.asInstanceOf[ContinuousInputStream]
+ }
+}
+
/**
* A [[Source]] that produces value stored in memory as they are added by the user. This [[Source]]
* is intended for use in unit tests as it can only replay data when the object is still
* available.
*/
case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
- extends MemoryStreamBase[A](sqlContext) with MicroBatchReadSupport with Logging {
-
- protected val logicalPlan: LogicalPlan =
- StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession)
- protected val output = logicalPlan.output
+ extends MemoryStreamBase[A](sqlContext) with MicroBatchInputStream with Logging {
/**
* All batches from `lastCommittedOffset + 1` to `currentOffset`, inclusive.
@@ -117,7 +144,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
}
}
- override def toString: String = s"MemoryStream[${Utils.truncatedString(output, ",")}]"
+ override def toString: String = s"MemoryStream[${Utils.truncatedString(logicalPlan.output, ",")}]"
override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong)
@@ -127,15 +154,10 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
if (currentOffset.offset == -1) null else currentOffset
}
- override def newScanConfigBuilder(start: OffsetV2, end: OffsetV2): ScanConfigBuilder = {
- new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end))
- }
-
- override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
- val sc = config.asInstanceOf[SimpleStreamingScanConfig]
- val startOffset = sc.start.asInstanceOf[LongOffset]
- val endOffset = sc.end.get.asInstanceOf[LongOffset]
- synchronized {
+ override def createMicroBatchScan(start: OffsetV2, end: OffsetV2): MicroBatchScan = {
+ val startOffset = start.asInstanceOf[LongOffset]
+ val endOffset = end.asInstanceOf[LongOffset]
+ val partitions: Array[InputPartition] = synchronized {
// Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal)
val startOrdinal = startOffset.offset.toInt + 1
val endOrdinal = endOffset.offset.toInt + 1
@@ -154,10 +176,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
new MemoryStreamInputPartition(block)
}.toArray
}
- }
-
- override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = {
- MemoryStreamReaderFactory
+ new MemoryStreamMicroBatchScan(partitions)
}
private def generateDebugString(
@@ -199,6 +218,16 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
}
}
+class MemoryStreamMicroBatchScan(partitions: Array[InputPartition]) extends MicroBatchScan {
+
+ override def createReaderFactory(): PartitionReaderFactory = {
+ MemoryStreamReaderFactory
+ }
+
+ override def planInputPartitions(): Array[InputPartition] = {
+ partitions
+ }
+}
class MemoryStreamInputPartition(val records: Array[UnsafeRow]) extends InputPartition
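
Aside: createMicroBatchScan above converts inclusive start/end offsets into the half-open range of batch ordinals to replay, [startOffset + 1, endOffset + 1). A small standalone sketch of that computation (ignoring the adjustment for already-committed batches that the real code also applies):

object BatchOrdinalSketch {
  def main(args: Array[String]): Unit = {
    val batches = Vector("b0", "b1", "b2", "b3")   // all batches added so far
    val startOffset = 0L                           // already processed through batch 0
    val endOffset = 2L                             // this micro-batch ends at batch 2

    // Offsets are inclusive, so the ordinals to replay form the half-open range
    // [startOffset + 1, endOffset + 1).
    val startOrdinal = startOffset.toInt + 1
    val endOrdinal = endOffset.toInt + 1
    println(batches.slice(startOrdinal, endOrdinal)) // Vector(b1, b2)
  }
}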
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
index dbcc4483e577..097461976172 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
@@ -30,9 +30,10 @@ import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.sql.{Encoder, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.streaming.{Offset => _, _}
-import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions}
-import org.apache.spark.sql.sources.v2.reader.{InputPartition, ScanConfig, ScanConfigBuilder}
+import org.apache.spark.sql.sources.v2.{DataSourceOptions, Format, SupportsContinuousRead, Table}
+import org.apache.spark.sql.sources.v2.reader.{InputPartition, ScanConfig}
import org.apache.spark.sql.sources.v2.reader.streaming._
+import org.apache.spark.sql.types.StructType
import org.apache.spark.util.RpcUtils
/**
@@ -44,16 +45,10 @@ import org.apache.spark.util.RpcUtils
* the specified offset within the list, or null if that offset doesn't yet have a record.
*/
class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPartitions: Int = 2)
- extends MemoryStreamBase[A](sqlContext)
- with ContinuousReadSupportProvider with ContinuousReadSupport {
+ extends MemoryStreamBase[A](sqlContext) with ContinuousInputStream with Format {
private implicit val formats = Serialization.formats(NoTypeHints)
- protected val logicalPlan =
- StreamingRelationV2(this, "memory", Map(), attributes, None)(sqlContext.sparkSession)
-
- // ContinuousReader implementation
-
@GuardedBy("this")
private val records = Seq.fill(numPartitions)(new ListBuffer[A])
@@ -86,14 +81,9 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa
)
}
- override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = {
- new SimpleStreamingScanConfigBuilder(fullSchema(), start)
- }
-
- override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
- val startOffset = config.asInstanceOf[SimpleStreamingScanConfig]
- .start.asInstanceOf[ContinuousMemoryStreamOffset]
- synchronized {
+ override def createContinuousScan(start: Offset): ContinuousScan = {
+ val startOffset = start.asInstanceOf[ContinuousMemoryStreamOffset]
+ val partitions: Array[InputPartition] = synchronized {
val endpointName = s"ContinuousMemoryStreamRecordEndpoint-${java.util.UUID.randomUUID()}-$id"
endpointRef =
recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint)
@@ -102,11 +92,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa
case (part, index) => ContinuousMemoryStreamInputPartition(endpointName, part, index)
}.toArray
}
- }
-
- override def createContinuousReaderFactory(
- config: ScanConfig): ContinuousPartitionReaderFactory = {
- ContinuousMemoryStreamReaderFactory
+ new MemoryStreamContinuousScan(partitions)
}
override def stop(): Unit = {
@@ -115,11 +101,33 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa
override def commit(end: Offset): Unit = {}
- // ContinuousReadSupportProvider implementation
+ // Format implementation
// This is necessary because of how StreamTest finds the source for AddDataMemory steps.
- override def createContinuousReadSupport(
- checkpointLocation: String,
- options: DataSourceOptions): ContinuousReadSupport = this
+ override def getTable(options: DataSourceOptions): Table = {
+ new Table with SupportsContinuousRead {
+ override def schema(): StructType = {
+ ContinuousMemoryStream.this.encoder.schema
+ }
+
+ def createContinuousInputStream(
+ checkpointLocation: String,
+ config: ScanConfig,
+ options: DataSourceOptions): ContinuousInputStream = {
+ ContinuousMemoryStream.this
+ }
+ }
+ }
+}
+
+class MemoryStreamContinuousScan(partitions: Array[InputPartition]) extends ContinuousScan {
+
+ override def createContinuousReaderFactory(): ContinuousPartitionReaderFactory = {
+ ContinuousMemoryStreamReaderFactory
+ }
+
+ override def planInputPartitions(): Array[InputPartition] = {
+ partitions
+ }
}
object ContinuousMemoryStream {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchInputStream.scala
similarity index 87%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchReadSupport.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchInputStream.scala
index 90680ea38fbd..d7c32de10968 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchReadSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchInputStream.scala
@@ -17,10 +17,10 @@
package org.apache.spark.sql.execution.streaming.sources
-import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset}
+import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchInputStream, Offset}
-// A special `MicroBatchReadSupport` that can get latestOffset with a start offset.
-trait RateControlMicroBatchReadSupport extends MicroBatchReadSupport {
+// A special `MicroBatchInputStream` that can compute latestOffset relative to a given start offset.
+trait RateControlMicroBatchInputStream extends MicroBatchInputStream {
override def latestOffset(): Offset = {
throw new IllegalAccessException(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchInputStream.scala
similarity index 84%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReadSupport.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchInputStream.scala
index f5364047adff..3488a5f7d887 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReadSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchInputStream.scala
@@ -31,12 +31,11 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.sources.v2.reader._
-import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchInputStream, MicroBatchScan, Offset}
import org.apache.spark.util.{ManualClock, SystemClock}
-class RateStreamMicroBatchReadSupport(options: DataSourceOptions, checkpointLocation: String)
- extends MicroBatchReadSupport with Logging {
+class RateStreamMicroBatchInputStream(options: DataSourceOptions, checkpointLocation: String)
+ extends MicroBatchInputStream with Logging {
import RateStreamProvider._
private[sources] val clock = {
@@ -60,6 +59,14 @@ class RateStreamMicroBatchReadSupport(options: DataSourceOptions, checkpointLoca
s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.")
}
+ private val numPartitions = {
+ val activeSession = SparkSession.getActiveSession
+ require(activeSession.isDefined)
+ Option(options.get(NUM_PARTITIONS).orElse(null.asInstanceOf[String]))
+ .map(_.toInt)
+ .getOrElse(activeSession.get.sparkContext.defaultParallelism)
+ }
+
private[sources] val creationTimeMs = {
val session = SparkSession.getActiveSession.orElse(SparkSession.getDefaultSession)
require(session.isDefined)
@@ -70,7 +77,7 @@ class RateStreamMicroBatchReadSupport(options: DataSourceOptions, checkpointLoca
val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8))
writer.write("v" + VERSION + "\n")
writer.write(metadata.json)
- writer.flush
+ writer.flush()
}
override def deserialize(in: InputStream): LongOffset = {
@@ -117,16 +124,44 @@ class RateStreamMicroBatchReadSupport(options: DataSourceOptions, checkpointLoca
LongOffset(json.toLong)
}
- override def fullSchema(): StructType = SCHEMA
+ override def createMicroBatchScan(start: Offset, end: Offset): MicroBatchScan = {
+ new RateStreamMicroBatchScan(
+ maxSeconds,
+ rowsPerSecond,
+ creationTimeMs,
+ rampUpTimeSeconds,
+ numPartitions,
+ start.asInstanceOf[LongOffset], end.asInstanceOf[LongOffset])
+ }
+
+ override def commit(end: Offset): Unit = {}
+
+ override def stop(): Unit = {}
+
+ override def toString: String = s"RateStreamV2[rowsPerSecond=$rowsPerSecond, " +
+ s"rampUpTimeSeconds=$rampUpTimeSeconds, " +
+ s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}"
+}
+
+class RateStreamMicroBatchScan(
+ maxSeconds: Long,
+ rowsPerSecond: Long,
+ creationTimeMs: Long,
+ rampUpTimeSeconds: Long,
+ numPartitions: Int,
+ start: LongOffset,
+ end: LongOffset) extends MicroBatchScan with Logging {
+ import RateStreamProvider._
+
+ @volatile private var lastTimeMs: Long = creationTimeMs
- override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder = {
- new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end))
+ override def createReaderFactory(): PartitionReaderFactory = {
+ RateStreamMicroBatchReaderFactory
}
- override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
- val sc = config.asInstanceOf[SimpleStreamingScanConfig]
- val startSeconds = sc.start.asInstanceOf[LongOffset].offset
- val endSeconds = sc.end.get.asInstanceOf[LongOffset].offset
+ override def planInputPartitions(): Array[InputPartition] = {
+ val startSeconds = start.offset
+ val endSeconds = end.offset
assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)")
if (endSeconds > maxSeconds) {
throw new ArithmeticException("Integer overflow. Max offset with " +
@@ -148,31 +183,12 @@ class RateStreamMicroBatchReadSupport(options: DataSourceOptions, checkpointLoca
val localStartTimeMs = creationTimeMs + TimeUnit.SECONDS.toMillis(startSeconds)
val relativeMsPerValue =
TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart)
- val numPartitions = {
- val activeSession = SparkSession.getActiveSession
- require(activeSession.isDefined)
- Option(options.get(NUM_PARTITIONS).orElse(null.asInstanceOf[String]))
- .map(_.toInt)
- .getOrElse(activeSession.get.sparkContext.defaultParallelism)
- }
(0 until numPartitions).map { p =>
- new RateStreamMicroBatchInputPartition(
+ RateStreamMicroBatchInputPartition(
p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue)
}.toArray
}
-
- override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = {
- RateStreamMicroBatchReaderFactory
- }
-
- override def commit(end: Offset): Unit = {}
-
- override def stop(): Unit = {}
-
- override def toString: String = s"RateStreamV2[rowsPerSecond=$rowsPerSecond, " +
- s"rampUpTimeSeconds=$rampUpTimeSeconds, " +
- s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}"
}
case class RateStreamMicroBatchInputPartition(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala
index 6942dfbfe0ec..bfb2a7ad2afb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala
@@ -18,10 +18,11 @@
package org.apache.spark.sql.execution.streaming.sources
import org.apache.spark.network.util.JavaUtils
-import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReadSupport
+import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousInputStream
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.sources.v2._
-import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, MicroBatchReadSupport}
+import org.apache.spark.sql.sources.v2.reader.ScanConfig
+import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputStream, MicroBatchInputStream, MicroBatchScan}
import org.apache.spark.sql.types._
/**
@@ -38,14 +39,16 @@ import org.apache.spark.sql.types._
* generated rows. The source will try its best to reach `rowsPerSecond`, but the query may
* be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed.
*/
-class RateStreamProvider extends DataSourceV2
- with MicroBatchReadSupportProvider with ContinuousReadSupportProvider with DataSourceRegister {
+class RateStreamProvider extends Format with DataSourceRegister {
import RateStreamProvider._
- override def createMicroBatchReadSupport(
- checkpointLocation: String,
- options: DataSourceOptions): MicroBatchReadSupport = {
- if (options.get(ROWS_PER_SECOND).isPresent) {
+ override def getTable(options: DataSourceOptions): Table = {
+ validateOptions(options)
+ RateStreamTable
+ }
+
+ private def validateOptions(options: DataSourceOptions): Unit = {
+ if (options.get(ROWS_PER_SECOND).isPresent) {
val rowsPerSecond = options.get(ROWS_PER_SECOND).get().toLong
if (rowsPerSecond <= 0) {
throw new IllegalArgumentException(
@@ -69,17 +72,29 @@ class RateStreamProvider extends DataSourceV2
s"Invalid value '$numPartitions'. The option 'numPartitions' must be positive")
}
}
-
- new RateStreamMicroBatchReadSupport(options, checkpointLocation)
}
- override def createContinuousReadSupport(
- checkpointLocation: String,
- options: DataSourceOptions): ContinuousReadSupport = {
- new RateStreamContinuousReadSupport(options)
+ override def shortName(): String = "rate"
+}
+
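+// The single table exposed by the rate source; it creates micro-batch or continuous input
+// streams depending on how the query is executed.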
+object RateStreamTable extends Table
+ with SupportsMicroBatchRead with SupportsContinuousRead {
+
+ override def schema(): StructType = RateStreamProvider.SCHEMA
+
+ override def createMicroBatchInputStream(
+ checkpointLocation: String,
+ config: ScanConfig,
+ options: DataSourceOptions): MicroBatchInputStream = {
+ new RateStreamMicroBatchInputStream(options, checkpointLocation)
}
- override def shortName(): String = "rate"
+ override def createContinuousInputStream(
+ checkpointLocation: String,
+ config: ScanConfig,
+ options: DataSourceOptions): ContinuousInputStream = {
+ new RateStreamContinuousInputStream(options)
+ }
}
object RateStreamProvider {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketMicroBatchInputStream.scala
similarity index 62%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketMicroBatchInputStream.scala
index b2a573eae504..aa8c4f430a4a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketMicroBatchInputStream.scala
@@ -19,41 +19,29 @@ package org.apache.spark.sql.execution.streaming.sources
import java.io.{BufferedReader, InputStreamReader, IOException}
import java.net.Socket
-import java.text.SimpleDateFormat
-import java.util.{Calendar, Locale}
+import java.util.Calendar
import java.util.concurrent.atomic.AtomicBoolean
import javax.annotation.concurrent.GuardedBy
import scala.collection.mutable.ListBuffer
-import scala.util.{Failure, Success, Try}
import org.apache.spark.internal.Logging
-import org.apache.spark.sql._
+import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.streaming.{LongOffset, SimpleStreamingScanConfig, SimpleStreamingScanConfigBuilder}
-import org.apache.spark.sql.execution.streaming.continuous.TextSocketContinuousReadSupport
-import org.apache.spark.sql.sources.DataSourceRegister
-import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, DataSourceV2, MicroBatchReadSupportProvider}
+import org.apache.spark.sql.execution.streaming.LongOffset
+import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.sources.v2.reader._
-import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, MicroBatchReadSupport, Offset}
-import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType}
+import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchInputStream, MicroBatchScan, Offset}
import org.apache.spark.unsafe.types.UTF8String
-object TextSocketReader {
- val SCHEMA_REGULAR = StructType(StructField("value", StringType) :: Nil)
- val SCHEMA_TIMESTAMP = StructType(StructField("value", StringType) ::
- StructField("timestamp", TimestampType) :: Nil)
- val DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US)
-}
-
/**
- * A MicroBatchReadSupport that reads text lines through a TCP socket, designed only for tutorials
- * and debugging. This MicroBatchReadSupport will *not* work in production applications due to
+ * A MicroBatchInputStream that reads text lines through a TCP socket, designed only for tutorials
+ * and debugging. This MicroBatchInputStream will *not* work in production applications, for
* multiple reasons, including no support for fault recovery.
*/
-class TextSocketMicroBatchReadSupport(options: DataSourceOptions)
- extends MicroBatchReadSupport with Logging {
+class TextSocketMicroBatchInputStream(options: DataSourceOptions)
+ extends MicroBatchInputStream with Logging {
private val host: String = options.get("host").get()
private val port: Int = options.get("port").get().toInt
@@ -99,7 +87,7 @@ class TextSocketMicroBatchReadSupport(options: DataSourceOptions)
logWarning(s"Stream closed by $host:$port")
return
}
- TextSocketMicroBatchReadSupport.this.synchronized {
+ TextSocketMicroBatchInputStream.this.synchronized {
val newData = (
UTF8String.fromString(line),
DateTimeUtils.fromMillis(Calendar.getInstance().getTimeInMillis)
@@ -124,22 +112,9 @@ class TextSocketMicroBatchReadSupport(options: DataSourceOptions)
LongOffset(json.toLong)
}
- override def fullSchema(): StructType = {
- if (options.getBoolean("includeTimestamp", false)) {
- TextSocketReader.SCHEMA_TIMESTAMP
- } else {
- TextSocketReader.SCHEMA_REGULAR
- }
- }
-
- override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder = {
- new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end))
- }
-
- override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
- val sc = config.asInstanceOf[SimpleStreamingScanConfig]
- val startOrdinal = sc.start.asInstanceOf[LongOffset].offset.toInt + 1
- val endOrdinal = sc.end.get.asInstanceOf[LongOffset].offset.toInt + 1
+ override def createMicroBatchScan(start: Offset, end: Offset): MicroBatchScan = {
+ val startOrdinal = start.asInstanceOf[LongOffset].offset.toInt + 1
+ val endOrdinal = end.asInstanceOf[LongOffset].offset.toInt + 1
// Internal buffer only holds the batches after lastOffsetCommitted
val rawList = synchronized {
@@ -161,29 +136,7 @@ class TextSocketMicroBatchReadSupport(options: DataSourceOptions)
slices(idx % numPartitions).append(r)
}
- slices.map(TextSocketInputPartition)
- }
-
- override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = {
- new PartitionReaderFactory {
- override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
- val slice = partition.asInstanceOf[TextSocketInputPartition].slice
- new PartitionReader[InternalRow] {
- private var currentIdx = -1
-
- override def next(): Boolean = {
- currentIdx += 1
- currentIdx < slice.size
- }
-
- override def get(): InternalRow = {
- InternalRow(slice(currentIdx)._1, slice(currentIdx)._2)
- }
-
- override def close(): Unit = {}
- }
- }
- }
+ new TextSocketMicroBatchScan(slices.map(TextSocketInputPartition))
}
override def commit(end: Offset): Unit = synchronized {
@@ -219,44 +172,33 @@ class TextSocketMicroBatchReadSupport(options: DataSourceOptions)
override def toString: String = s"TextSocketV2[host: $host, port: $port]"
}
-case class TextSocketInputPartition(slice: ListBuffer[(UTF8String, Long)]) extends InputPartition
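+// A scan over the slices of buffered socket data computed by the input stream; each slice
+// becomes one input partition.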
+class TextSocketMicroBatchScan(partitions: Array[InputPartition]) extends MicroBatchScan {
-class TextSocketSourceProvider extends DataSourceV2
- with MicroBatchReadSupportProvider with ContinuousReadSupportProvider
- with DataSourceRegister with Logging {
+ override def createReaderFactory(): PartitionReaderFactory = {
+ new PartitionReaderFactory {
+ override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
+ val slice = partition.asInstanceOf[TextSocketInputPartition].slice
+ new PartitionReader[InternalRow] {
+ private var currentIdx = -1
- private def checkParameters(params: DataSourceOptions): Unit = {
- logWarning("The socket source should not be used for production applications! " +
- "It does not support recovery.")
- if (!params.get("host").isPresent) {
- throw new AnalysisException("Set a host to read from with option(\"host\", ...).")
- }
- if (!params.get("port").isPresent) {
- throw new AnalysisException("Set a port to read from with option(\"port\", ...).")
- }
- Try {
- params.get("includeTimestamp").orElse("false").toBoolean
- } match {
- case Success(_) =>
- case Failure(_) =>
- throw new AnalysisException("includeTimestamp must be set to either \"true\" or \"false\"")
- }
- }
+ override def next(): Boolean = {
+ currentIdx += 1
+ currentIdx < slice.size
+ }
- override def createMicroBatchReadSupport(
- checkpointLocation: String,
- options: DataSourceOptions): MicroBatchReadSupport = {
- checkParameters(options)
- new TextSocketMicroBatchReadSupport(options)
- }
+ override def get(): InternalRow = {
+ InternalRow(slice(currentIdx)._1, slice(currentIdx)._2)
+ }
- override def createContinuousReadSupport(
- checkpointLocation: String,
- options: DataSourceOptions): ContinuousReadSupport = {
- checkParameters(options)
- new TextSocketContinuousReadSupport(options)
+ override def close(): Unit = {}
+ }
+ }
+ }
}
- /** String that represents the format that this data source provider uses. */
- override def shortName(): String = "socket"
+ override def planInputPartitions(): Array[InputPartition] = {
+ partitions
+ }
}
+
+case class TextSocketInputPartition(slice: ListBuffer[(UTF8String, Long)]) extends InputPartition
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala
new file mode 100644
index 000000000000..e2a4f15b6752
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.sources
+
+import java.text.SimpleDateFormat
+import java.util.Locale
+
+import scala.util.{Failure, Success, Try}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.execution.streaming.continuous.TextSocketContinuousInputStream
+import org.apache.spark.sql.sources.DataSourceRegister
+import org.apache.spark.sql.sources.v2._
+import org.apache.spark.sql.sources.v2.reader._
+import org.apache.spark.sql.sources.v2.reader.streaming._
+import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType}
+
+object TextSocketReader {
+ val SCHEMA_REGULAR = StructType(StructField("value", StringType) :: Nil)
+ val SCHEMA_TIMESTAMP = StructType(StructField("value", StringType) ::
+ StructField("timestamp", TimestampType) :: Nil)
+ val DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US)
+}
+
+class TextSocketSourceProvider extends Format with DataSourceRegister with Logging {
+
+ override def getTable(options: DataSourceOptions): TextSocketTable = {
+ new TextSocketTable(options)
+ }
+
+ /** String that represents the format that this data source provider uses. */
+ override def shortName(): String = "socket"
+}
+
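+// Table for the socket source. The reported schema depends on the 'includeTimestamp' option,
+// and connection parameters are validated before an input stream is created.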
+class TextSocketTable(options: DataSourceOptions) extends Table
+ with SupportsMicroBatchRead with SupportsContinuousRead with Logging {
+
+ override def schema(): StructType = {
+ if (options.getBoolean("includeTimestamp", false)) {
+ TextSocketReader.SCHEMA_TIMESTAMP
+ } else {
+ TextSocketReader.SCHEMA_REGULAR
+ }
+ }
+
+ override def createMicroBatchInputStream(
+ checkpointLocation: String,
+ config: ScanConfig,
+ options: DataSourceOptions): MicroBatchInputStream = {
+ checkParameters(options)
+ new TextSocketMicroBatchInputStream(options)
+ }
+
+ override def createContinuousInputStream(
+ checkpointLocation: String,
+ config: ScanConfig,
+ options: DataSourceOptions): ContinuousInputStream = {
+ checkParameters(options)
+ new TextSocketContinuousInputStream(options)
+ }
+
+ private def checkParameters(params: DataSourceOptions): Unit = {
+ logWarning("The socket source should not be used for production applications! " +
+ "It does not support recovery.")
+ if (!params.get("host").isPresent) {
+ throw new AnalysisException("Set a host to read from with option(\"host\", ...).")
+ }
+ if (!params.get("port").isPresent) {
+ throw new AnalysisException("Set a port to read from with option(\"port\", ...).")
+ }
+ Try {
+ params.get("includeTimestamp").orElse("false").toBoolean
+ } match {
+ case Success(_) =>
+ case Failure(_) =>
+ throw new AnalysisException("includeTimestamp must be set to either \"true\" or \"false\"")
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index 4c7dcedafeea..13255e0b6eed 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -29,10 +29,8 @@ import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils
import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2}
import org.apache.spark.sql.sources.StreamSourceProvider
-import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, MicroBatchReadSupportProvider}
-import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, MicroBatchReadSupport}
+import org.apache.spark.sql.sources.v2._
import org.apache.spark.sql.types.StructType
-import org.apache.spark.util.Utils
/**
* Interface used to load a streaming `Dataset` from external storage systems (e.g. file systems,
@@ -172,60 +170,27 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
case _ => None
}
ds match {
- case s: MicroBatchReadSupportProvider =>
- val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
- ds = s, conf = sparkSession.sessionState.conf)
- val options = sessionOptions ++ extraOptions
- val dataSourceOptions = new DataSourceOptions(options.asJava)
- var tempReadSupport: MicroBatchReadSupport = null
- val schema = try {
- val tmpCheckpointPath = Utils.createTempDir(namePrefix = s"tempCP").getCanonicalPath
- tempReadSupport = if (userSpecifiedSchema.isDefined) {
- s.createMicroBatchReadSupport(
- userSpecifiedSchema.get, tmpCheckpointPath, dataSourceOptions)
- } else {
- s.createMicroBatchReadSupport(tmpCheckpointPath, dataSourceOptions)
- }
- tempReadSupport.fullSchema()
- } finally {
- // Stop tempReader to avoid side-effect thing
- if (tempReadSupport != null) {
- tempReadSupport.stop()
- tempReadSupport = null
- }
+ case f: Format =>
+ val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
+ ds = f, conf = sparkSession.sessionState.conf)
+ val options = sessionOptions ++ extraOptions
+ val dsOptions = new DataSourceOptions(options.asJava)
+ val table = userSpecifiedSchema match {
+ case Some(schema) => f.getTable(dsOptions, schema)
+ case _ => f.getTable(dsOptions)
}
- Dataset.ofRows(
- sparkSession,
- StreamingRelationV2(
- s, source, options,
- schema.toAttributes, v1Relation)(sparkSession))
- case s: ContinuousReadSupportProvider =>
- val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
- ds = s, conf = sparkSession.sessionState.conf)
- val options = sessionOptions ++ extraOptions
- val dataSourceOptions = new DataSourceOptions(options.asJava)
- var tempReadSupport: ContinuousReadSupport = null
- val schema = try {
- val tmpCheckpointPath = Utils.createTempDir(namePrefix = s"tempCP").getCanonicalPath
- tempReadSupport = if (userSpecifiedSchema.isDefined) {
- s.createContinuousReadSupport(
- userSpecifiedSchema.get, tmpCheckpointPath, dataSourceOptions)
- } else {
- s.createContinuousReadSupport(tmpCheckpointPath, dataSourceOptions)
- }
- tempReadSupport.fullSchema()
- } finally {
- // Stop tempReader to avoid side-effect thing
- if (tempReadSupport != null) {
- tempReadSupport.stop()
- tempReadSupport = null
- }
+
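+ // Fall back to the v1 code path when the table supports neither micro-batch nor continuous
+ // reads.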
+ table match {
+ case _: SupportsMicroBatchRead =>
+ case _: SupportsContinuousRead =>
+ case _ => return Dataset.ofRows(sparkSession, StreamingRelation(v1DataSource))
}
+
Dataset.ofRows(
sparkSession,
StreamingRelationV2(
- s, source, options,
- schema.toAttributes, v1Relation)(sparkSession))
+ source, f, table, options,
+ table.schema().toAttributes, v1Relation)(sparkSession))
case _ =>
// Code path for data source v1.
Dataset.ofRows(sparkSession, StreamingRelation(v1DataSource))
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java
index 5602310219a7..c47ef696f62c 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java
@@ -24,27 +24,42 @@
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.sources.Filter;
import org.apache.spark.sql.sources.GreaterThan;
-import org.apache.spark.sql.sources.v2.BatchReadSupportProvider;
-import org.apache.spark.sql.sources.v2.DataSourceOptions;
-import org.apache.spark.sql.sources.v2.DataSourceV2;
+import org.apache.spark.sql.sources.v2.*;
import org.apache.spark.sql.sources.v2.reader.*;
import org.apache.spark.sql.types.StructType;
-public class JavaAdvancedDataSourceV2 implements DataSourceV2, BatchReadSupportProvider {
+public class JavaAdvancedDataSourceV2 implements Format {
+
+ class MyTable implements Table, SupportsBatchRead {
- public class ReadSupport extends JavaSimpleReadSupport {
@Override
- public ScanConfigBuilder newScanConfigBuilder() {
+ public ScanConfigBuilder newScanConfigBuilder(DataSourceOptions options) {
return new AdvancedScanConfigBuilder();
}
@Override
- public InputPartition[] planInputPartitions(ScanConfig config) {
- Filter[] filters = ((AdvancedScanConfigBuilder) config).filters;
- List res = new ArrayList<>();
+ public BatchScan createBatchScan(ScanConfig config, DataSourceOptions options) {
+ return new AdvancedBatchScan((AdvancedScanConfigBuilder) config);
+ }
+
+ @Override
+ public StructType schema() {
+ return new StructType().add("i", "int").add("j", "int");
+ }
+ }
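+ // Batch scan that plans partitions based on the filters pushed into the scan config.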
+ public static class AdvancedBatchScan implements BatchScan {
+ public AdvancedScanConfigBuilder config;
+
+ AdvancedBatchScan(AdvancedScanConfigBuilder config) {
+ this.config = config;
+ }
+
+ @Override
+ public InputPartition[] planInputPartitions() {
+ List res = new ArrayList<>();
Integer lowerBound = null;
- for (Filter filter : filters) {
+ for (Filter filter : config.filters) {
if (filter instanceof GreaterThan) {
GreaterThan f = (GreaterThan) filter;
if ("i".equals(f.attribute()) && f.value() instanceof Integer) {
@@ -68,12 +83,12 @@ public InputPartition[] planInputPartitions(ScanConfig config) {
}
@Override
- public PartitionReaderFactory createReaderFactory(ScanConfig config) {
- StructType requiredSchema = ((AdvancedScanConfigBuilder) config).requiredSchema;
- return new AdvancedReaderFactory(requiredSchema);
+ public PartitionReaderFactory createReaderFactory() {
+ return new AdvancedReaderFactory(config.requiredSchema);
}
}
+
public static class AdvancedScanConfigBuilder implements ScanConfigBuilder, ScanConfig,
SupportsPushDownFilters, SupportsPushDownRequiredColumns {
@@ -166,9 +181,8 @@ public void close() throws IOException {
}
}
-
@Override
- public BatchReadSupport createBatchReadSupport(DataSourceOptions options) {
- return new ReadSupport();
+ public Table getTable(DataSourceOptions options) {
+ return new MyTable();
}
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java
index 28a933039831..df4f0c676591 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java
@@ -21,21 +21,18 @@
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
-import org.apache.spark.sql.sources.v2.BatchReadSupportProvider;
-import org.apache.spark.sql.sources.v2.DataSourceOptions;
-import org.apache.spark.sql.sources.v2.DataSourceV2;
+import org.apache.spark.sql.sources.v2.*;
import org.apache.spark.sql.sources.v2.reader.*;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.sql.vectorized.ColumnarBatch;
+public class JavaColumnarDataSourceV2 implements Format {
-public class JavaColumnarDataSourceV2 implements DataSourceV2, BatchReadSupportProvider {
-
- class ReadSupport extends JavaSimpleReadSupport {
+ class MyTable extends JavaSimpleBatchReadTable {
@Override
- public InputPartition[] planInputPartitions(ScanConfig config) {
+ public InputPartition[] planInputPartitions() {
InputPartition[] partitions = new InputPartition[2];
partitions[0] = new JavaRangeInputPartition(0, 50);
partitions[1] = new JavaRangeInputPartition(50, 90);
@@ -43,7 +40,7 @@ public InputPartition[] planInputPartitions(ScanConfig config) {
}
@Override
- public PartitionReaderFactory createReaderFactory(ScanConfig config) {
+ public PartitionReaderFactory createReaderFactory() {
return new ColumnarReaderFactory();
}
}
@@ -108,7 +105,7 @@ public void close() throws IOException {
}
@Override
- public BatchReadSupport createBatchReadSupport(DataSourceOptions options) {
- return new ReadSupport();
+ public Table getTable(DataSourceOptions options) {
+ return new MyTable();
}
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java
index 18a11dde8219..560c54ac1c84 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java
@@ -28,12 +28,17 @@
import org.apache.spark.sql.sources.v2.reader.partitioning.Distribution;
import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning;
-public class JavaPartitionAwareDataSource implements DataSourceV2, BatchReadSupportProvider {
+public class JavaPartitionAwareDataSource implements Format {
- class ReadSupport extends JavaSimpleReadSupport implements SupportsReportPartitioning {
+ class MyTable extends JavaSimpleBatchReadTable implements SupportsReportPartitioning {
@Override
- public InputPartition[] planInputPartitions(ScanConfig config) {
+ public PartitionReaderFactory createReaderFactory() {
+ return new SpecificReaderFactory();
+ }
+
+ @Override
+ public InputPartition[] planInputPartitions() {
InputPartition[] partitions = new InputPartition[2];
partitions[0] = new SpecificInputPartition(new int[]{1, 1, 3}, new int[]{4, 4, 6});
partitions[1] = new SpecificInputPartition(new int[]{2, 4, 4}, new int[]{6, 2, 2});
@@ -41,12 +46,7 @@ public InputPartition[] planInputPartitions(ScanConfig config) {
}
@Override
- public PartitionReaderFactory createReaderFactory(ScanConfig config) {
- return new SpecificReaderFactory();
- }
-
- @Override
- public Partitioning outputPartitioning(ScanConfig config) {
+ public Partitioning outputPartitioning() {
return new MyPartitioning();
}
}
@@ -108,7 +108,7 @@ public void close() throws IOException {
}
@Override
- public BatchReadSupport createBatchReadSupport(DataSourceOptions options) {
- return new ReadSupport();
+ public Table getTable(DataSourceOptions options) {
+ return new MyTable();
}
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java
index cc9ac04a0dad..2b68f987568d 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java
@@ -17,39 +17,37 @@
package test.org.apache.spark.sql.sources.v2;
-import org.apache.spark.sql.sources.v2.BatchReadSupportProvider;
-import org.apache.spark.sql.sources.v2.DataSourceOptions;
-import org.apache.spark.sql.sources.v2.DataSourceV2;
+import org.apache.spark.sql.sources.v2.*;
import org.apache.spark.sql.sources.v2.reader.*;
import org.apache.spark.sql.types.StructType;
-public class JavaSchemaRequiredDataSource implements DataSourceV2, BatchReadSupportProvider {
+public class JavaSchemaRequiredDataSource implements Format {
- class ReadSupport extends JavaSimpleReadSupport {
+ class MyTable extends JavaSimpleBatchReadTable {
private final StructType schema;
- ReadSupport(StructType schema) {
+ MyTable(StructType schema) {
this.schema = schema;
}
@Override
- public StructType fullSchema() {
- return schema;
+ public StructType schema() {
+ return this.schema;
}
@Override
- public InputPartition[] planInputPartitions(ScanConfig config) {
+ public InputPartition[] planInputPartitions() {
return new InputPartition[0];
}
}
@Override
- public BatchReadSupport createBatchReadSupport(DataSourceOptions options) {
+ public Table getTable(DataSourceOptions options) {
throw new IllegalArgumentException("requires a user-supplied schema");
}
@Override
- public BatchReadSupport createBatchReadSupport(StructType schema, DataSourceOptions options) {
- return new ReadSupport(schema);
+ public Table getTable(DataSourceOptions options, StructType schema) {
+ return new MyTable(schema);
}
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleBatchReadTable.java
similarity index 78%
rename from sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java
rename to sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleBatchReadTable.java
index 685f9b9747e8..bafebc8cdea0 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleBatchReadTable.java
@@ -21,46 +21,29 @@
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
+import org.apache.spark.sql.sources.v2.SupportsBatchRead;
+import org.apache.spark.sql.sources.v2.DataSourceOptions;
import org.apache.spark.sql.sources.v2.reader.*;
import org.apache.spark.sql.types.StructType;
-abstract class JavaSimpleReadSupport implements BatchReadSupport {
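+// Test helper: the table acts as its own batch scan, so subclasses typically only override
+// planInputPartitions().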
+abstract class JavaSimpleBatchReadTable implements SupportsBatchRead, BatchScan {
@Override
- public StructType fullSchema() {
- return new StructType().add("i", "int").add("j", "int");
+ public BatchScan createBatchScan(ScanConfig config, DataSourceOptions options) {
+ return this;
}
@Override
- public ScanConfigBuilder newScanConfigBuilder() {
- return new JavaNoopScanConfigBuilder(fullSchema());
+ public StructType schema() {
+ return new StructType().add("i", "int").add("j", "int");
}
@Override
- public PartitionReaderFactory createReaderFactory(ScanConfig config) {
+ public PartitionReaderFactory createReaderFactory() {
return new JavaSimpleReaderFactory();
}
}
-class JavaNoopScanConfigBuilder implements ScanConfigBuilder, ScanConfig {
-
- private StructType schema;
-
- JavaNoopScanConfigBuilder(StructType schema) {
- this.schema = schema;
- }
-
- @Override
- public ScanConfig build() {
- return this;
- }
-
- @Override
- public StructType readSchema() {
- return schema;
- }
-}
-
class JavaSimpleReaderFactory implements PartitionReaderFactory {
@Override
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java
index 2cdbba84ec4a..23896e29c160 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java
@@ -17,17 +17,14 @@
package test.org.apache.spark.sql.sources.v2;
-import org.apache.spark.sql.sources.v2.BatchReadSupportProvider;
-import org.apache.spark.sql.sources.v2.DataSourceV2;
-import org.apache.spark.sql.sources.v2.DataSourceOptions;
+import org.apache.spark.sql.sources.v2.*;
import org.apache.spark.sql.sources.v2.reader.*;
-public class JavaSimpleDataSourceV2 implements DataSourceV2, BatchReadSupportProvider {
-
- class ReadSupport extends JavaSimpleReadSupport {
+public class JavaSimpleDataSourceV2 implements Format {
+ class MyTable extends JavaSimpleBatchReadTable {
@Override
- public InputPartition[] planInputPartitions(ScanConfig config) {
+ public InputPartition[] planInputPartitions() {
InputPartition[] partitions = new InputPartition[2];
partitions[0] = new JavaRangeInputPartition(0, 5);
partitions[1] = new JavaRangeInputPartition(5, 10);
@@ -36,7 +33,7 @@ public InputPartition[] planInputPartitions(ScanConfig config) {
}
@Override
- public BatchReadSupport createBatchReadSupport(DataSourceOptions options) {
- return new ReadSupport();
+ public Table getTable(DataSourceOptions options) {
+ return new MyTable();
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala
index dd74af873c2e..4bb467350467 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala
@@ -25,15 +25,16 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.DataSource
+import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.continuous._
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, MicroBatchReadSupportProvider}
+import org.apache.spark.sql.sources.v2.{DataSourceOptions, Format, SupportsMicroBatchRead}
import org.apache.spark.sql.sources.v2.reader.streaming.Offset
import org.apache.spark.sql.streaming.StreamTest
import org.apache.spark.util.ManualClock
-class RateSourceSuite extends StreamTest {
+class RateStreamProviderSuite extends StreamTest {
import testImplicits._
@@ -41,7 +42,9 @@ class RateSourceSuite extends StreamTest {
override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = {
assert(query.nonEmpty)
val rateSource = query.get.logicalPlan.collect {
- case StreamingExecutionRelation(source: RateStreamMicroBatchReadSupport, _) => source
+ case r: StreamingDataSourceV2Relation
+ if r.stream.isInstanceOf[RateStreamMicroBatchInputStream] =>
+ r.stream.asInstanceOf[RateStreamMicroBatchInputStream]
}.head
rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds))
@@ -51,27 +54,16 @@ class RateSourceSuite extends StreamTest {
}
}
- test("microbatch in registry") {
- withTempDir { temp =>
- DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match {
- case ds: MicroBatchReadSupportProvider =>
- val readSupport = ds.createMicroBatchReadSupport(
- temp.getCanonicalPath, DataSourceOptions.empty())
- assert(readSupport.isInstanceOf[RateStreamMicroBatchReadSupport])
- case _ =>
- throw new IllegalStateException("Could not find read support for rate")
- }
- }
+ test("RateStreamProvider in registry") {
+ val ds = DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance()
+ assert(ds.isInstanceOf[RateStreamProvider], "Could not find rate source")
}
test("compatible with old path in registry") {
- DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.RateSourceProvider",
- spark.sqlContext.conf).newInstance() match {
- case ds: MicroBatchReadSupportProvider =>
- assert(ds.isInstanceOf[RateStreamProvider])
- case _ =>
- throw new IllegalStateException("Could not find read support for rate")
- }
+ val ds = DataSource.lookupDataSource(
+ "org.apache.spark.sql.execution.streaming.RateSourceProvider",
+ spark.sqlContext.conf).newInstance()
+ assert(ds.isInstanceOf[RateStreamProvider], "Could not find rate source")
}
test("microbatch - basic") {
@@ -141,17 +133,17 @@ class RateSourceSuite extends StreamTest {
test("microbatch - infer offsets") {
withTempDir { temp =>
- val readSupport = new RateStreamMicroBatchReadSupport(
+ val stream = new RateStreamMicroBatchInputStream(
new DataSourceOptions(
Map("numPartitions" -> "1", "rowsPerSecond" -> "100", "useManualClock" -> "true").asJava),
temp.getCanonicalPath)
- readSupport.clock.asInstanceOf[ManualClock].advance(100000)
- val startOffset = readSupport.initialOffset()
+ stream.clock.asInstanceOf[ManualClock].advance(100000)
+ val startOffset = stream.initialOffset()
startOffset match {
case r: LongOffset => assert(r.offset === 0L)
case _ => throw new IllegalStateException("unexpected offset type")
}
- readSupport.latestOffset() match {
+ stream.latestOffset() match {
case r: LongOffset => assert(r.offset >= 100)
case _ => throw new IllegalStateException("unexpected offset type")
}
@@ -160,16 +152,14 @@ class RateSourceSuite extends StreamTest {
test("microbatch - predetermined batch size") {
withTempDir { temp =>
- val readSupport = new RateStreamMicroBatchReadSupport(
+ val stream = new RateStreamMicroBatchInputStream(
new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava),
temp.getCanonicalPath)
- val startOffset = LongOffset(0L)
- val endOffset = LongOffset(1L)
- val config = readSupport.newScanConfigBuilder(startOffset, endOffset).build()
- val tasks = readSupport.planInputPartitions(config)
- val readerFactory = readSupport.createReaderFactory(config)
- assert(tasks.size == 1)
- val dataReader = readerFactory.createReader(tasks(0))
+ val scan = stream.createMicroBatchScan(LongOffset(0L), LongOffset(1L))
+ val partitions = scan.planInputPartitions()
+ val readerFactory = scan.createReaderFactory()
+ assert(partitions.size == 1)
+ val dataReader = readerFactory.createReader(partitions(0))
val data = ArrayBuffer[InternalRow]()
while (dataReader.next()) {
data.append(dataReader.get())
@@ -180,17 +170,15 @@ class RateSourceSuite extends StreamTest {
test("microbatch - data read") {
withTempDir { temp =>
- val readSupport = new RateStreamMicroBatchReadSupport(
+ val stream = new RateStreamMicroBatchInputStream(
new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava),
temp.getCanonicalPath)
- val startOffset = LongOffset(0L)
- val endOffset = LongOffset(1L)
- val config = readSupport.newScanConfigBuilder(startOffset, endOffset).build()
- val tasks = readSupport.planInputPartitions(config)
- val readerFactory = readSupport.createReaderFactory(config)
- assert(tasks.size == 11)
-
- val readData = tasks
+ val scan = stream.createMicroBatchScan(LongOffset(0L), LongOffset(1L))
+ val partitions = scan.planInputPartitions()
+ val readerFactory = scan.createReaderFactory()
+ assert(partitions.size == 11)
+
+ val readData = partitions
.map(readerFactory.createReader)
.flatMap { reader =>
val buf = scala.collection.mutable.ListBuffer[InternalRow]()
@@ -319,29 +307,18 @@ class RateSourceSuite extends StreamTest {
"rate source does not support user-specified schema"))
}
- test("continuous in registry") {
- DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match {
- case ds: ContinuousReadSupportProvider =>
- val readSupport = ds.createContinuousReadSupport(
- "", DataSourceOptions.empty())
- assert(readSupport.isInstanceOf[RateStreamContinuousReadSupport])
- case _ =>
- throw new IllegalStateException("Could not find read support for continuous rate")
- }
- }
-
test("continuous data") {
- val readSupport = new RateStreamContinuousReadSupport(
+ val stream = new RateStreamContinuousInputStream(
new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava))
- val config = readSupport.newScanConfigBuilder(readSupport.initialOffset).build()
- val tasks = readSupport.planInputPartitions(config)
- val readerFactory = readSupport.createContinuousReaderFactory(config)
- assert(tasks.size == 2)
+ val scan = stream.createContinuousScan(stream.initialOffset)
+ val partitions = scan.planInputPartitions()
+ val readerFactory = scan.createContinuousReaderFactory()
+ assert(partitions.size == 2)
val data = scala.collection.mutable.ListBuffer[InternalRow]()
- tasks.foreach {
+ partitions.foreach {
case t: RateStreamContinuousInputPartition =>
- val startTimeMs = readSupport.initialOffset()
+ val startTimeMs = stream.initialOffset()
.asInstanceOf[RateStreamOffset]
.partitionToValueAndRunTimeMs(t.partitionIndex)
.runTimeMs
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala
index 409156e5ebc7..760c6f367d40 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala
@@ -30,10 +30,11 @@ import org.scalatest.BeforeAndAfterEach
import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.execution.datasources.DataSource
+import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.continuous._
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupportProvider}
+import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsMicroBatchRead}
import org.apache.spark.sql.sources.v2.reader.streaming.Offset
import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest}
import org.apache.spark.sql.test.SharedSQLContext
@@ -59,7 +60,9 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
"Cannot add data when there is no query for finding the active socket source")
val sources = query.get.logicalPlan.collect {
- case StreamingExecutionRelation(source: TextSocketMicroBatchReadSupport, _) => source
+ case r: StreamingDataSourceV2Relation
+ if r.stream.isInstanceOf[TextSocketMicroBatchInputStream] =>
+ r.stream.asInstanceOf[TextSocketMicroBatchInputStream]
}
if (sources.isEmpty) {
throw new Exception(
@@ -83,13 +86,10 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
}
test("backward compatibility with old path") {
- DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.TextSocketSourceProvider",
- spark.sqlContext.conf).newInstance() match {
- case ds: MicroBatchReadSupportProvider =>
- assert(ds.isInstanceOf[TextSocketSourceProvider])
- case _ =>
- throw new IllegalStateException("Could not find socket source")
- }
+ val ds = DataSource.lookupDataSource(
+ "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider",
+ spark.sqlContext.conf).newInstance()
+ assert(ds.isInstanceOf[TextSocketSourceProvider], "Could not find socket source")
}
test("basic usage") {
@@ -173,39 +173,37 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
}
test("params not given") {
- val provider = new TextSocketSourceProvider
+ val table = new TextSocketSourceProvider().getTable(DataSourceOptions.empty())
intercept[AnalysisException] {
- provider.createMicroBatchReadSupport(
- "", new DataSourceOptions(Map.empty[String, String].asJava))
+ table.createMicroBatchInputStream(
+ "", null, new DataSourceOptions(Map.empty[String, String].asJava))
}
intercept[AnalysisException] {
- provider.createMicroBatchReadSupport(
- "", new DataSourceOptions(Map("host" -> "localhost").asJava))
+ table.createMicroBatchInputStream(
+ "", null, new DataSourceOptions(Map("host" -> "localhost").asJava))
}
intercept[AnalysisException] {
- provider.createMicroBatchReadSupport(
- "", new DataSourceOptions(Map("port" -> "1234").asJava))
+ table.createMicroBatchInputStream(
+ "", null, new DataSourceOptions(Map("port" -> "1234").asJava))
}
}
test("non-boolean includeTimestamp") {
- val provider = new TextSocketSourceProvider
+ val table = new TextSocketSourceProvider().getTable(DataSourceOptions.empty())
val params = Map("host" -> "localhost", "port" -> "1234", "includeTimestamp" -> "fasle")
intercept[AnalysisException] {
val a = new DataSourceOptions(params.asJava)
- provider.createMicroBatchReadSupport("", a)
+ table.createMicroBatchInputStream("", null, a)
}
}
test("user-specified schema given") {
- val provider = new TextSocketSourceProvider
+ val provider = new TextSocketSourceProvider()
val userSpecifiedSchema = StructType(
StructField("name", StringType) ::
StructField("area", StringType) :: Nil)
- val params = Map("host" -> "localhost", "port" -> "1234")
val exception = intercept[UnsupportedOperationException] {
- provider.createMicroBatchReadSupport(
- userSpecifiedSchema, "", new DataSourceOptions(params.asJava))
+ provider.getTable(DataSourceOptions.empty(), userSpecifiedSchema)
}
assert(exception.getMessage.contains(
"socket source does not support user-specified schema"))
@@ -299,25 +297,24 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
serverThread = new ServerThread()
serverThread.start()
- val readSupport = new TextSocketContinuousReadSupport(
+ val stream = new TextSocketContinuousInputStream(
new DataSourceOptions(Map("numPartitions" -> "2", "host" -> "localhost",
"port" -> serverThread.port.toString).asJava))
-
- val scanConfig = readSupport.newScanConfigBuilder(readSupport.initialOffset()).build()
- val tasks = readSupport.planInputPartitions(scanConfig)
- assert(tasks.size == 2)
+ val scan = stream.createContinuousScan(stream.initialOffset())
+ val partitions = scan.planInputPartitions()
+ assert(partitions.size == 2)
val numRecords = 10
val data = scala.collection.mutable.ListBuffer[Int]()
val offsets = scala.collection.mutable.ListBuffer[Int]()
- val readerFactory = readSupport.createContinuousReaderFactory(scanConfig)
+ val readerFactory = scan.createContinuousReaderFactory()
import org.scalatest.time.SpanSugar._
failAfter(5 seconds) {
// inject rows, read and check the data and offsets
for (i <- 0 until numRecords) {
serverThread.enqueue(i.toString)
}
- tasks.foreach {
+ partitions.foreach {
case t: TextSocketContinuousInputPartition =>
val r = readerFactory.createReader(t).asInstanceOf[TextSocketContinuousPartitionReader]
for (i <- 0 until numRecords / 2) {
@@ -335,15 +332,15 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
data.clear()
case _ => throw new IllegalStateException("Unexpected task type")
}
- assert(readSupport.startOffset.offsets == List(3, 3))
- readSupport.commit(TextSocketOffset(List(5, 5)))
- assert(readSupport.startOffset.offsets == List(5, 5))
+ assert(stream.startOffset.offsets == List(3, 3))
+ stream.commit(TextSocketOffset(List(5, 5)))
+ assert(stream.startOffset.offsets == List(5, 5))
}
def commitOffset(partition: Int, offset: Int): Unit = {
- val offsetsToCommit = readSupport.startOffset.offsets.updated(partition, offset)
- readSupport.commit(TextSocketOffset(offsetsToCommit))
- assert(readSupport.startOffset.offsets == offsetsToCommit)
+ val offsetsToCommit = stream.startOffset.offsets.updated(partition, offset)
+ stream.commit(TextSocketOffset(offsetsToCommit))
+ assert(stream.startOffset.offsets == offsetsToCommit)
}
}
@@ -351,13 +348,13 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
serverThread = new ServerThread()
serverThread.start()
- val readSupport = new TextSocketContinuousReadSupport(
+ val stream = new TextSocketContinuousInputStream(
new DataSourceOptions(Map("numPartitions" -> "2", "host" -> "localhost",
"port" -> serverThread.port.toString).asJava))
- readSupport.startOffset = TextSocketOffset(List(5, 5))
+ stream.startOffset = TextSocketOffset(List(5, 5))
assertThrows[IllegalStateException] {
- readSupport.commit(TextSocketOffset(List(6, 6)))
+ stream.commit(TextSocketOffset(List(6, 6)))
}
}
@@ -365,21 +362,21 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
serverThread = new ServerThread()
serverThread.start()
- val readSupport = new TextSocketContinuousReadSupport(
+ val stream = new TextSocketContinuousInputStream(
new DataSourceOptions(Map("numPartitions" -> "2", "host" -> "localhost",
"includeTimestamp" -> "true",
"port" -> serverThread.port.toString).asJava))
- val scanConfig = readSupport.newScanConfigBuilder(readSupport.initialOffset()).build()
- val tasks = readSupport.planInputPartitions(scanConfig)
- assert(tasks.size == 2)
+ val scan = stream.createContinuousScan(stream.initialOffset())
+ val partitions = scan.planInputPartitions()
+ assert(partitions.size == 2)
val numRecords = 4
// inject rows, read and check the data and offsets
for (i <- 0 until numRecords) {
serverThread.enqueue(i.toString)
}
- val readerFactory = readSupport.createContinuousReaderFactory(scanConfig)
- tasks.foreach {
+ val readerFactory = scan.createContinuousReaderFactory()
+ partitions.foreach {
case t: TextSocketContinuousInputPartition =>
val r = readerFactory.createReader(t).asInstanceOf[TextSocketContinuousPartitionReader]
for (i <- 0 until numRecords / 2) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
index e8f291af13ba..c4607086afdf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
@@ -41,7 +41,7 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
private def getScanConfig(query: DataFrame): AdvancedScanConfigBuilder = {
query.queryExecution.executedPlan.collect {
case d: DataSourceV2ScanExec =>
- d.scanConfig.asInstanceOf[AdvancedScanConfigBuilder]
+ d.scan.asInstanceOf[AdvancedBatchScan].config
}.head
}
@@ -49,7 +49,7 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
query: DataFrame): JavaAdvancedDataSourceV2.AdvancedScanConfigBuilder = {
query.queryExecution.executedPlan.collect {
case d: DataSourceV2ScanExec =>
- d.scanConfig.asInstanceOf[JavaAdvancedDataSourceV2.AdvancedScanConfigBuilder]
+ d.scan.asInstanceOf[JavaAdvancedDataSourceV2.AdvancedBatchScan].config
}.head
}
@@ -374,10 +374,6 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
case class RangeInputPartition(start: Int, end: Int) extends InputPartition
-case class NoopScanConfigBuilder(readSchema: StructType) extends ScanConfigBuilder with ScanConfig {
- override def build(): ScanConfig = this
-}
-
object SimpleReaderFactory extends PartitionReaderFactory {
override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
val RangeInputPartition(start, end) = partition
@@ -396,83 +392,54 @@ object SimpleReaderFactory extends PartitionReaderFactory {
}
}
-abstract class SimpleReadSupport extends BatchReadSupport {
- override def fullSchema(): StructType = new StructType().add("i", "int").add("j", "int")
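+// Test helper: the table doubles as its own BatchScan, so concrete sources only override
+// planInputPartitions() (and, where needed, createReaderFactory() or schema()).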
+abstract class SimpleBatchReadTable extends Table with SupportsBatchRead with BatchScan {
- override def newScanConfigBuilder(): ScanConfigBuilder = {
- NoopScanConfigBuilder(fullSchema())
- }
+ override def createBatchScan(config: ScanConfig, options: DataSourceOptions): BatchScan = this
- override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = {
- SimpleReaderFactory
- }
+ override def schema(): StructType = new StructType().add("i", "int").add("j", "int")
+
+ override def createReaderFactory(): PartitionReaderFactory = SimpleReaderFactory
}
-class SimpleSinglePartitionSource extends DataSourceV2 with BatchReadSupportProvider {
+class SimpleSinglePartitionSource extends Format {
- class ReadSupport extends SimpleReadSupport {
- override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
+ class MyTable extends SimpleBatchReadTable {
+ override def planInputPartitions(): Array[InputPartition] = {
Array(RangeInputPartition(0, 5))
}
}
- override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = {
- new ReadSupport
- }
+ override def getTable(options: DataSourceOptions): Table = new MyTable()
}
// This class is used by pyspark tests. If this class is modified/moved, make sure pyspark
// tests still pass.
-class SimpleDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider {
+class SimpleDataSourceV2 extends Format {
- class ReadSupport extends SimpleReadSupport {
- override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
+ class MyTable extends SimpleBatchReadTable {
+ override def planInputPartitions(): Array[InputPartition] = {
Array(RangeInputPartition(0, 5), RangeInputPartition(5, 10))
}
}
- override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = {
- new ReadSupport
- }
+ override def getTable(options: DataSourceOptions): Table = new MyTable
}
-class AdvancedDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider {
-
- class ReadSupport extends SimpleReadSupport {
- override def newScanConfigBuilder(): ScanConfigBuilder = new AdvancedScanConfigBuilder()
-
- override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
- val filters = config.asInstanceOf[AdvancedScanConfigBuilder].filters
-
- val lowerBound = filters.collectFirst {
- case GreaterThan("i", v: Int) => v
- }
-
- val res = scala.collection.mutable.ArrayBuffer.empty[InputPartition]
-
- if (lowerBound.isEmpty) {
- res.append(RangeInputPartition(0, 5))
- res.append(RangeInputPartition(5, 10))
- } else if (lowerBound.get < 4) {
- res.append(RangeInputPartition(lowerBound.get + 1, 5))
- res.append(RangeInputPartition(5, 10))
- } else if (lowerBound.get < 9) {
- res.append(RangeInputPartition(lowerBound.get + 1, 10))
- }
-
- res.toArray
+class AdvancedDataSourceV2 extends Format {
+ class MyTable extends SupportsBatchRead {
+ override def newScanConfigBuilder(options: DataSourceOptions): ScanConfigBuilder = {
+ new AdvancedScanConfigBuilder()
}
- override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = {
- val requiredSchema = config.asInstanceOf[AdvancedScanConfigBuilder].requiredSchema
- new AdvancedReaderFactory(requiredSchema)
+ override def createBatchScan(config: ScanConfig, options: DataSourceOptions): BatchScan = {
+ new AdvancedBatchScan(config.asInstanceOf[AdvancedScanConfigBuilder])
}
- }
- override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = {
- new ReadSupport
+ override def schema(): StructType = new StructType().add("i", "int").add("j", "int")
}
+
+ override def getTable(options: DataSourceOptions): Table = new MyTable
}
class AdvancedScanConfigBuilder extends ScanConfigBuilder with ScanConfig
@@ -501,6 +468,33 @@ class AdvancedScanConfigBuilder extends ScanConfigBuilder with ScanConfig
override def build(): ScanConfig = this
}
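+// A batch scan for `AdvancedDataSourceV2` that plans its input partitions from the filter on
+// column `i` captured in the scan config, and reads only the required columns.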
+class AdvancedBatchScan(val config: AdvancedScanConfigBuilder) extends BatchScan {
+
+ override def createReaderFactory(): PartitionReaderFactory = {
+ new AdvancedReaderFactory(config.requiredSchema)
+ }
+
+ override def planInputPartitions(): Array[InputPartition] = {
+ val lowerBound = config.filters.collectFirst {
+ case GreaterThan("i", v: Int) => v
+ }
+
+ val res = scala.collection.mutable.ArrayBuffer.empty[InputPartition]
+
+ if (lowerBound.isEmpty) {
+ res.append(RangeInputPartition(0, 5))
+ res.append(RangeInputPartition(5, 10))
+ } else if (lowerBound.get < 4) {
+ res.append(RangeInputPartition(lowerBound.get + 1, 5))
+ res.append(RangeInputPartition(5, 10))
+ } else if (lowerBound.get < 9) {
+ res.append(RangeInputPartition(lowerBound.get + 1, 10))
+ }
+
+ res.toArray
+ }
+}
+
class AdvancedReaderFactory(requiredSchema: StructType) extends PartitionReaderFactory {
override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
val RangeInputPartition(start, end) = partition
@@ -526,40 +520,30 @@ class AdvancedReaderFactory(requiredSchema: StructType) extends PartitionReaderF
}
-class SchemaRequiredDataSource extends DataSourceV2 with BatchReadSupportProvider {
+class SchemaRequiredDataSource extends Format {
- class ReadSupport(val schema: StructType) extends SimpleReadSupport {
- override def fullSchema(): StructType = schema
-
- override def planInputPartitions(config: ScanConfig): Array[InputPartition] =
- Array.empty
+ class MyTable(override val schema: StructType) extends SimpleBatchReadTable {
+ override def planInputPartitions(): Array[InputPartition] = Array.empty
}
- override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = {
+ override def getTable(options: DataSourceOptions): Table = {
throw new IllegalArgumentException("requires a user-supplied schema")
}
- override def createBatchReadSupport(
- schema: StructType, options: DataSourceOptions): BatchReadSupport = {
- new ReadSupport(schema)
- }
+ override def getTable(options: DataSourceOptions, schema: StructType): Table = new MyTable(schema)
}
-class ColumnarDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider {
+class ColumnarDataSourceV2 extends Format {
- class ReadSupport extends SimpleReadSupport {
- override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
+ class MyTable extends SimpleBatchReadTable {
+ override def planInputPartitions(): Array[InputPartition] = {
Array(RangeInputPartition(0, 50), RangeInputPartition(50, 90))
}
- override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = {
- ColumnarReaderFactory
- }
+ override def createReaderFactory(): PartitionReaderFactory = ColumnarReaderFactory
}
- override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = {
- new ReadSupport
- }
+ override def getTable(options: DataSourceOptions): Table = new MyTable
}
object ColumnarReaderFactory extends PartitionReaderFactory {
@@ -608,21 +592,20 @@ object ColumnarReaderFactory extends PartitionReaderFactory {
}
-class PartitionAwareDataSource extends DataSourceV2 with BatchReadSupportProvider {
+class PartitionAwareDataSource extends Format {
+
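+ // Reports its output partitioning so the planner can avoid an unnecessary shuffle.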
+ class MyTable extends SimpleBatchReadTable with SupportsReportPartitioning {
- class ReadSupport extends SimpleReadSupport with SupportsReportPartitioning {
- override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
+ override def createReaderFactory(): PartitionReaderFactory = SpecificReaderFactory
+
+ override def planInputPartitions(): Array[InputPartition] = {
// Note that the same value of column `a` never appears in more than one partition.
Array(
SpecificInputPartition(Array(1, 1, 3), Array(4, 4, 6)),
SpecificInputPartition(Array(2, 4, 4), Array(6, 2, 2)))
}
- override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = {
- SpecificReaderFactory
- }
-
- override def outputPartitioning(config: ScanConfig): Partitioning = new MyPartitioning
+ override def outputPartitioning(): Partitioning = new MyPartitioning
}
class MyPartitioning extends Partitioning {
@@ -634,9 +617,7 @@ class PartitionAwareDataSource extends DataSourceV2 with BatchReadSupportProvide
}
}
- override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = {
- new ReadSupport
- }
+ override def getTable(options: DataSourceOptions): Table = new MyTable
}
case class SpecificInputPartition(i: Array[Int], j: Array[Int]) extends InputPartition
@@ -662,7 +643,7 @@ object SpecificReaderFactory extends PartitionReaderFactory {
class SchemaReadAttemptException(m: String) extends RuntimeException(m)
class SimpleWriteOnlyDataSource extends SimpleWritableDataSource {
- override def fullSchema(): StructType = {
+ override def schema(): StructType = {
// This is a bit hacky: this source implements read support but throws during
// schema retrieval. It might be worth rewriting, but it is kept this way to
// minimise the changes.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala
index a7dfc2d1deac..3b6ca89630fe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.writer._
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration
/**
@@ -39,19 +39,17 @@ import org.apache.spark.util.SerializableConfiguration
* Each job moves files from `target/_temporary/queryId/` to `target`.
*/
class SimpleWritableDataSource extends DataSourceV2
- with BatchReadSupportProvider
- with BatchWriteSupportProvider
- with SessionConfigSupport {
+ with Format with BatchWriteSupportProvider with SessionConfigSupport {
- protected def fullSchema(): StructType = new StructType().add("i", "long").add("j", "long")
+ protected def schema(): StructType = new StructType().add("i", "long").add("j", "long")
override def keyPrefix: String = "simpleWritableDataSource"
- class ReadSupport(path: String, conf: Configuration) extends SimpleReadSupport {
+ class MyTable(path: String, conf: Configuration) extends SimpleBatchReadTable {
- override def fullSchema(): StructType = SimpleWritableDataSource.this.fullSchema()
+ override def schema(): StructType = SimpleWritableDataSource.this.schema()
- override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
+ override def planInputPartitions(): Array[InputPartition] = {
val dataPath = new Path(path)
val fs = dataPath.getFileSystem(conf)
if (fs.exists(dataPath)) {
@@ -66,7 +64,7 @@ class SimpleWritableDataSource extends DataSourceV2
}
}
- override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = {
+ override def createReaderFactory(): PartitionReaderFactory = {
val serializableConf = new SerializableConfiguration(conf)
new CSVReaderFactory(serializableConf)
}
@@ -105,10 +103,10 @@ class SimpleWritableDataSource extends DataSourceV2
}
}
- override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = {
+ override def getTable(options: DataSourceOptions): Table = {
val path = new Path(options.get("path").get())
val conf = SparkContext.getActive.get.hadoopConfiguration
- new ReadSupport(path.toUri.toString, conf)
+ new MyTable(path.toUri.toString, conf)
}
override def createBatchWriteSupport(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index f55ddb5419d2..406d29474776 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -30,20 +30,27 @@ import org.apache.hadoop.conf.Configuration
import org.scalatest.time.SpanSugar._
import org.apache.spark.{SparkConf, SparkContext, TaskContext}
+import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.plans.logical.Range
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericInternalRow, UnsafeProjection}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Range}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
+import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.command.ExplainCommand
import org.apache.spark.sql.execution.streaming._
-import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream
import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProvider}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.StreamSourceProvider
+import org.apache.spark.sql.sources.v2._
+import org.apache.spark.sql.sources.v2.reader._
+import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchInputStream, MicroBatchScan, Offset => OffsetV2}
import org.apache.spark.sql.streaming.util.StreamManualClock
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
class StreamSuite extends StreamTest {
@@ -102,12 +109,10 @@ class StreamSuite extends StreamTest {
}
test("StreamingExecutionRelation.computeStats") {
- val streamingExecutionRelation = MemoryStream[Int].toDF.logicalPlan collect {
- case s: StreamingExecutionRelation => s
- }
- assert(streamingExecutionRelation.nonEmpty, "cannot find StreamingExecutionRelation")
- assert(streamingExecutionRelation.head.computeStats.sizeInBytes
- == spark.sessionState.conf.defaultSizeInBytes)
+ val memoryStream = MemoryStream[Int]
+ val executionRelation = StreamingExecutionRelation(
+ memoryStream, memoryStream.encoder.schema.toAttributes)(memoryStream.sqlContext.sparkSession)
+ assert(executionRelation.computeStats.sizeInBytes == spark.sessionState.conf.defaultSizeInBytes)
}
test("explain join with a normal source") {
@@ -154,21 +159,25 @@ class StreamSuite extends StreamTest {
}
test("SPARK-20432: union one stream with itself") {
- val df = spark.readStream.format(classOf[FakeDefaultSource].getName).load().select("a")
- val unioned = df.union(df)
- withTempDir { outputDir =>
- withTempDir { checkpointDir =>
- val query =
- unioned
- .writeStream.format("parquet")
- .option("checkpointLocation", checkpointDir.getAbsolutePath)
- .start(outputDir.getAbsolutePath)
- try {
- query.processAllAvailable()
- val outputDf = spark.read.parquet(outputDir.getAbsolutePath).as[Long]
- checkDatasetUnorderly[Long](outputDf, (0L to 10L).union((0L to 10L)).toArray: _*)
- } finally {
- query.stop()
+ val v1Source = spark.readStream.format(classOf[FakeDefaultSource].getName).load().select("a")
+ val v2Source = spark.readStream.format(classOf[FakeFormat].getName).load().select("a")
+
+ Seq(v1Source, v2Source).foreach { df =>
+ val unioned = df.union(df)
+ withTempDir { outputDir =>
+ withTempDir { checkpointDir =>
+ val query =
+ unioned
+ .writeStream.format("parquet")
+ .option("checkpointLocation", checkpointDir.getAbsolutePath)
+ .start(outputDir.getAbsolutePath)
+ try {
+ query.processAllAvailable()
+ val outputDf = spark.read.parquet(outputDir.getAbsolutePath).as[Long]
+ checkDatasetUnorderly[Long](outputDf, (0L to 10L).union((0L to 10L)).toArray: _*)
+ } finally {
+ query.stop()
+ }
}
}
}
@@ -381,7 +390,7 @@ class StreamSuite extends StreamTest {
test("insert an extraStrategy") {
try {
- spark.experimental.extraStrategies = TestStrategy :: Nil
+ spark.experimental.extraStrategies = CustomStrategy :: Nil
val inputData = MemoryStream[(String, Int)]
val df = inputData.toDS().map(_._1).toDF("a")
@@ -495,9 +504,9 @@ class StreamSuite extends StreamTest {
val explainWithoutExtended = q.explainInternal(false)
// `extended = false` only displays the physical plan.
- assert("Streaming RelationV2 MemoryStreamDataSource".r
+ assert("Streaming RelationV2 MemoryStreamSource".r
.findAllMatchIn(explainWithoutExtended).size === 0)
- assert("ScanV2 MemoryStreamDataSource".r
+ assert("ScanV2 MemoryStreamSource".r
.findAllMatchIn(explainWithoutExtended).size === 1)
// Use "StateStoreRestore" to verify that it does output a streaming physical plan
assert(explainWithoutExtended.contains("StateStoreRestore"))
@@ -505,9 +514,9 @@ class StreamSuite extends StreamTest {
val explainWithExtended = q.explainInternal(true)
// `extended = true` displays 3 logical plans (Parsed/Analyzed/Optimized) and 1 physical
// plan.
- assert("Streaming RelationV2 MemoryStreamDataSource".r
+ assert("Streaming RelationV2 MemoryStreamSource".r
.findAllMatchIn(explainWithExtended).size === 3)
- assert("ScanV2 MemoryStreamDataSource".r
+ assert("ScanV2 MemoryStreamSource".r
.findAllMatchIn(explainWithExtended).size === 1)
// Use "StateStoreRestore" to verify that it does output a streaming physical plan
assert(explainWithExtended.contains("StateStoreRestore"))
@@ -550,17 +559,17 @@ class StreamSuite extends StreamTest {
val explainWithoutExtended = q.explainInternal(false)
// `extended = false` only displays the physical plan.
- assert("Streaming RelationV2 ContinuousMemoryStream".r
+ assert("Streaming RelationV2 MemoryStreamSource".r
.findAllMatchIn(explainWithoutExtended).size === 0)
- assert("ScanV2 ContinuousMemoryStream".r
+ assert("ScanV2 MemoryStreamSource".r
.findAllMatchIn(explainWithoutExtended).size === 1)
val explainWithExtended = q.explainInternal(true)
// `extended = true` displays 3 logical plans (Parsed/Analyzed/Optimized) and 1 physical
// plan.
- assert("Streaming RelationV2 ContinuousMemoryStream".r
+ assert("Streaming RelationV2 MemoryStreamSource".r
.findAllMatchIn(explainWithExtended).size === 3)
- assert("ScanV2 ContinuousMemoryStream".r
+ assert("ScanV2 MemoryStreamSource".r
.findAllMatchIn(explainWithExtended).size === 1)
} finally {
q.stop()
@@ -1137,6 +1146,67 @@ class FakeDefaultSource extends FakeSource {
}
}
+// Similar to `FakeDefaultSource`, but implemented with the data source v2 API.
+class FakeFormat extends Format {
+ override def getTable(options: DataSourceOptions): Table = {
+ new SupportsMicroBatchRead {
+ override def createMicroBatchInputStream(
+ checkpointLocation: String,
+ config: ScanConfig,
+ options: DataSourceOptions): MicroBatchInputStream = {
+ FakeMicroBatchInputStream
+ }
+
+ override def schema(): StructType = StructType(StructField("a", IntegerType) :: Nil)
+ }
+ }
+
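+ // Offsets always range from 0 to 10; each batch scans a single range of integers.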
+ object FakeMicroBatchInputStream extends MicroBatchInputStream {
+ override def createMicroBatchScan(start: OffsetV2, end: OffsetV2): MicroBatchScan = {
+ val s = start.asInstanceOf[LongOffset].offset.toInt
+ val e = end.asInstanceOf[LongOffset].offset.toInt
+ new FakeMicroBatchScan(s, e)
+ }
+
+ override def latestOffset(): OffsetV2 = LongOffset(10)
+
+ override def initialOffset(): OffsetV2 = LongOffset(0)
+
+ override def deserializeOffset(json: String): OffsetV2 = {
+ LongOffset(json.toLong)
+ }
+
+ override def commit(end: OffsetV2): Unit = {}
+
+ override def stop(): Unit = {}
+ }
+
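+ // Scans the integers from `start` to `end` (both inclusive) as a single range partition.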
+ class FakeMicroBatchScan(start: Int, end: Int) extends MicroBatchScan {
+ override def createReaderFactory(): PartitionReaderFactory = {
+ new PartitionReaderFactory {
+ override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
+ val RangeInputPartition(start, end) = partition
+ new PartitionReader[InternalRow] {
+ var current = start - 1
+ override def next(): Boolean = {
+ current += 1
+ current <= end
+ }
+
+ override def get(): InternalRow = InternalRow(current)
+
+ override def close(): Unit = {}
+ }
+ }
+ }
+ }
+
+ override def planInputPartitions(): Array[InputPartition] = {
+ Array(RangeInputPartition(start, end))
+ }
+ }
+}
+
/** A fake source that throws the same IOException like pre Hadoop 2.8 when it's interrupted. */
class ThrowingIOExceptionLikeHadoop12074 extends FakeSource {
import ThrowingIOExceptionLikeHadoop12074._
@@ -1244,3 +1314,23 @@ object ThrowingExceptionInCreateSource {
@volatile var createSourceLatch: CountDownLatch = null
@volatile var exception: Exception = null
}
+
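+// Plans a projection of the single column `a` as `CustomProjectExec`.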
+object CustomStrategy extends Strategy {
+ def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case Project(Seq(attr), child) if attr.name == "a" =>
+ CustomProjectExec(Seq(attr.toAttribute), planLater(child)) :: Nil
+ case _ => Nil
+ }
+}
+
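+// Maps every input row to a row containing the constant string "so fast".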
+case class CustomProjectExec(output: Seq[Attribute], child: SparkPlan) extends UnaryExecNode {
+ override protected def doExecute(): RDD[InternalRow] = {
+ child.execute().mapPartitions { it =>
+ val str = UTF8String.fromString("so fast")
+ val row = new GenericInternalRow(Array[Any](str))
+ val unsafeProj = UnsafeProjection.create(schema)
+ val unsafeRow = unsafeProj(row)
+ it.map(_ => unsafeRow)
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index d878c345c298..074de90d4fcc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -688,8 +688,14 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
def findSourceIndex(plan: LogicalPlan): Option[Int] = {
plan
.collect {
+ // v1 source
case r: StreamingExecutionRelation => r.source
- case r: StreamingDataSourceV2Relation => r.readSupport
+ // v2 source
+ case r: StreamingDataSourceV2Relation => r.stream
+ // Data can be added to a memory stream before the query starts; the input plan is then
+ // unprocessed by the streaming engine and still contains `StreamingRelationV2`.
+ case r: StreamingRelationV2 if r.sourceName == "memory" =>
+ r.table.asInstanceOf[MemoryStreamTable].stream
}
.zipWithIndex
.find(_._1 == source)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala
index 46eec736d402..13b8866c22b8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala
@@ -24,15 +24,14 @@ import scala.util.Random
import scala.util.control.NonFatal
import org.scalatest.BeforeAndAfter
-import org.scalatest.concurrent.Eventually._
import org.scalatest.concurrent.PatienceConfiguration.Timeout
import org.scalatest.time.Span
import org.scalatest.time.SpanSugar._
import org.apache.spark.SparkException
-import org.apache.spark.sql.{AnalysisException, Dataset}
+import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
import org.apache.spark.sql.execution.streaming._
-import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.util.BlockingSource
import org.apache.spark.util.Utils
@@ -304,8 +303,8 @@ class StreamingQueryManagerSuite extends StreamTest with BeforeAndAfter {
if (withError) {
logDebug(s"Terminating query ${queryToStop.name} with error")
queryToStop.asInstanceOf[StreamingQueryWrapper].streamingQuery.logicalPlan.collect {
- case StreamingExecutionRelation(source, _) =>
- source.asInstanceOf[MemoryStream[Int]].addData(0)
+ case r: StreamingDataSourceV2Relation =>
+ r.stream.asInstanceOf[MemoryStream[Int]].addData(0)
}
} else {
logDebug(s"Stopping query ${queryToStop.name}")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
index c170641372d6..92e9186241cb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
@@ -36,8 +36,7 @@ import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.sources.v2.reader.{InputPartition, ScanConfig}
-import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2}
+import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchScan, Offset => OffsetV2}
import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider, StreamManualClock}
import org.apache.spark.sql.types.StructType
@@ -220,10 +219,10 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
}
// createMicroBatchScan should take 100 ms the first time it is called
- override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
+ override def createMicroBatchScan(start: OffsetV2, end: OffsetV2): MicroBatchScan = {
synchronized {
clock.waitTillTime(1150)
- super.planInputPartitions(config)
+ super.createMicroBatchScan(start, end)
}
}
}
@@ -906,12 +905,12 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
assert(df.logicalPlan.toJSON.contains("StreamingRelationV2"))
testStream(df)(
- AssertOnQuery(_.logicalPlan.toJSON.contains("StreamingExecutionRelation"))
+ AssertOnQuery(_.logicalPlan.toJSON.contains("StreamingDataSourceV2Relation"))
)
testStream(df, useV2Sink = true)(
StartStream(trigger = Trigger.Continuous(100)),
- AssertOnQuery(_.logicalPlan.toJSON.contains("ContinuousExecutionRelation"))
+ AssertOnQuery(_.logicalPlan.toJSON.contains("StreamingDataSourceV2Relation"))
)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala
index d6819eacd07c..286675f68654 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala
@@ -27,7 +27,7 @@ import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.execution.streaming.continuous._
-import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousPartitionReader, ContinuousReadSupport, PartitionOffset}
+import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputStream, ContinuousPartitionReader, PartitionOffset}
import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport
import org.apache.spark.sql.streaming.StreamTest
import org.apache.spark.sql.types.{DataType, IntegerType, StructType}
@@ -44,7 +44,7 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar {
super.beforeEach()
epochEndpoint = EpochCoordinatorRef.create(
mock[StreamingWriteSupport],
- mock[ContinuousReadSupport],
+ mock[ContinuousInputStream],
mock[ContinuousExecution],
coordinatorId,
startEpoch,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
index 3d21bc63e0cc..f54970576b13 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.streaming.continuous
import org.apache.spark.{SparkContext, SparkException}
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart}
import org.apache.spark.sql._
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec
+import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.continuous._
import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream
@@ -40,13 +40,15 @@ class ContinuousSuiteBase extends StreamTest {
query match {
case s: ContinuousExecution =>
assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized")
- val reader = s.lastExecution.executedPlan.collectFirst {
- case DataSourceV2ScanExec(_, _, _, _, r: RateStreamContinuousReadSupport, _) => r
+ val stream = s.lastExecution.logical.collectFirst {
+ case r: StreamingDataSourceV2Relation
+ if r.stream.isInstanceOf[RateStreamContinuousInputStream] =>
+ r.stream.asInstanceOf[RateStreamContinuousInputStream]
}.get
val deltaMs = numTriggers * 1000 + 300
- while (System.currentTimeMillis < reader.creationTime + deltaMs) {
- Thread.sleep(reader.creationTime + deltaMs - System.currentTimeMillis)
+ while (System.currentTimeMillis < stream.creationTime + deltaMs) {
+ Thread.sleep(stream.creationTime + deltaMs - System.currentTimeMillis)
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala
index 3c973d8ebc70..60f58082347f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala
@@ -27,7 +27,7 @@ import org.apache.spark._
import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.sql.LocalSparkSession
import org.apache.spark.sql.execution.streaming.continuous._
-import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, PartitionOffset}
+import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputStream, PartitionOffset}
import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage
import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport
import org.apache.spark.sql.test.TestSparkSession
@@ -45,7 +45,7 @@ class EpochCoordinatorSuite
private var orderVerifier: InOrder = _
override def beforeEach(): Unit = {
- val reader = mock[ContinuousReadSupport]
+ val inputStream = mock[ContinuousInputStream]
writeSupport = mock[StreamingWriteSupport]
query = mock[ContinuousExecution]
orderVerifier = inOrder(writeSupport, query)
@@ -53,7 +53,7 @@ class EpochCoordinatorSuite
spark = new TestSparkSession()
epochCoordinator
- = EpochCoordinatorRef.create(writeSupport, reader, query, "test", 1, spark, SparkEnv.get)
+ = EpochCoordinatorRef.create(writeSupport, inputStream, query, "test", 1, spark, SparkEnv.get)
}
test("single epoch") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala
index 3a0e780a7391..b99dd32f5b22 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala
@@ -24,50 +24,49 @@ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider}
import org.apache.spark.sql.sources.v2._
-import org.apache.spark.sql.sources.v2.reader.{InputPartition, PartitionReaderFactory, ScanConfig, ScanConfigBuilder}
+import org.apache.spark.sql.sources.v2.reader.ScanConfig
import org.apache.spark.sql.sources.v2.reader.streaming._
import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport
import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery, StreamTest, Trigger}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils
-case class FakeReadSupport() extends MicroBatchReadSupport with ContinuousReadSupport {
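+// Supports both micro-batch and continuous execution, but throws if a scan is actually created.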
+class FakeInputStream extends MicroBatchInputStream with ContinuousInputStream {
override def deserializeOffset(json: String): Offset = RateStreamOffset(Map())
override def commit(end: Offset): Unit = {}
override def stop(): Unit = {}
override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = RateStreamOffset(Map())
- override def fullSchema(): StructType = StructType(Seq())
- override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder = null
override def initialOffset(): Offset = RateStreamOffset(Map())
override def latestOffset(): Offset = RateStreamOffset(Map())
- override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = null
- override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = {
+ override def createMicroBatchScan(start: Offset, end: Offset): MicroBatchScan = {
throw new IllegalStateException("fake source - cannot actually read")
}
- override def createContinuousReaderFactory(
- config: ScanConfig): ContinuousPartitionReaderFactory = {
- throw new IllegalStateException("fake source - cannot actually read")
- }
- override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
+ override def createContinuousScan(start: Offset): ContinuousScan = {
throw new IllegalStateException("fake source - cannot actually read")
}
}
-trait FakeMicroBatchReadSupportProvider extends MicroBatchReadSupportProvider {
- override def createMicroBatchReadSupport(
+trait FakeMicroBatchReadTable extends Table with SupportsMicroBatchRead {
+ override def schema(): StructType = StructType(Seq())
+
+ override def createMicroBatchInputStream(
checkpointLocation: String,
- options: DataSourceOptions): MicroBatchReadSupport = {
+ config: ScanConfig,
+ options: DataSourceOptions): MicroBatchInputStream = {
LastReadOptions.options = options
- FakeReadSupport()
+ new FakeInputStream
}
}
-trait FakeContinuousReadSupportProvider extends ContinuousReadSupportProvider {
- override def createContinuousReadSupport(
+trait FakeContinuousReadTable extends Table with SupportsContinuousRead {
+ override def schema(): StructType = StructType(Seq())
+
+ override def createContinuousInputStream(
checkpointLocation: String,
- options: DataSourceOptions): ContinuousReadSupport = {
+ config: ScanConfig,
+ options: DataSourceOptions): ContinuousInputStream = {
LastReadOptions.options = options
- FakeReadSupport()
+ new FakeInputStream
}
}
@@ -82,31 +81,43 @@ trait FakeStreamingWriteSupportProvider extends StreamingWriteSupportProvider {
}
}
-class FakeReadMicroBatchOnly
- extends DataSourceRegister
- with FakeMicroBatchReadSupportProvider
- with SessionConfigSupport {
+class FakeReadMicroBatchOnly extends Format with DataSourceRegister with SessionConfigSupport {
override def shortName(): String = "fake-read-microbatch-only"
override def keyPrefix: String = shortName()
+
+ override def getTable(options: DataSourceOptions): Table = {
+ new FakeMicroBatchReadTable {}
+ }
}
-class FakeReadContinuousOnly
- extends DataSourceRegister
- with FakeContinuousReadSupportProvider
- with SessionConfigSupport {
+class FakeReadContinuousOnly extends Format with DataSourceRegister with SessionConfigSupport {
override def shortName(): String = "fake-read-continuous-only"
override def keyPrefix: String = shortName()
+
+ override def getTable(options: DataSourceOptions): Table = {
+ new FakeContinuousReadTable {}
+ }
}
-class FakeReadBothModes extends DataSourceRegister
- with FakeMicroBatchReadSupportProvider with FakeContinuousReadSupportProvider {
+class FakeReadBothModes extends Format with DataSourceRegister {
override def shortName(): String = "fake-read-microbatch-continuous"
+
+ override def getTable(options: DataSourceOptions): Table = {
+ new Table
+ with FakeMicroBatchReadTable with FakeContinuousReadTable {}
+ }
}
-class FakeReadNeitherMode extends DataSourceRegister {
+class FakeReadNeitherMode extends Format with DataSourceRegister {
override def shortName(): String = "fake-read-neither-mode"
+
+ override def getTable(options: DataSourceOptions): Table = {
+ new Table {
+ override def schema(): StructType = StructType(Nil)
+ }
+ }
}
class FakeWriteSupportProvider
@@ -299,23 +310,24 @@ class StreamingDataSourceV2Suite extends StreamTest {
for ((read, write, trigger) <- cases) {
testQuietly(s"stream with read format $read, write format $write, trigger $trigger") {
- val readSource = DataSource.lookupDataSource(read, spark.sqlContext.conf).newInstance()
+ val readSource = DataSource.lookupDataSource(read, spark.sqlContext.conf)
+ .newInstance().asInstanceOf[Format].getTable(DataSourceOptions.empty())
val writeSource = DataSource.lookupDataSource(write, spark.sqlContext.conf).newInstance()
(readSource, writeSource, trigger) match {
// Valid microbatch queries.
- case (_: MicroBatchReadSupportProvider, _: StreamingWriteSupportProvider, t)
+ case (_: SupportsMicroBatchRead, _: StreamingWriteSupportProvider, t)
if !t.isInstanceOf[ContinuousTrigger] =>
testPositiveCase(read, write, trigger)
// Valid continuous queries.
- case (_: ContinuousReadSupportProvider, _: StreamingWriteSupportProvider,
+ case (_: SupportsContinuousRead, _: StreamingWriteSupportProvider,
_: ContinuousTrigger) =>
testPositiveCase(read, write, trigger)
// Invalid - can't read at all
case (r, _, _)
- if !r.isInstanceOf[MicroBatchReadSupportProvider]
- && !r.isInstanceOf[ContinuousReadSupportProvider] =>
+ if !r.isInstanceOf[SupportsMicroBatchRead]
+ && !r.isInstanceOf[SupportsContinuousRead] =>
testNegativeCase(read, write, trigger,
s"Data source $read does not support streamed reading")
@@ -326,13 +338,13 @@ class StreamingDataSourceV2Suite extends StreamTest {
// Invalid - trigger is continuous but reader is not
case (r, _: StreamingWriteSupportProvider, _: ContinuousTrigger)
- if !r.isInstanceOf[ContinuousReadSupportProvider] =>
+ if !r.isInstanceOf[SupportsContinuousRead] =>
testNegativeCase(read, write, trigger,
s"Data source $read does not support continuous processing")
// Invalid - trigger is microbatch but reader is not
case (r, _, t)
- if !r.isInstanceOf[MicroBatchReadSupportProvider] &&
+ if !r.isInstanceOf[SupportsMicroBatchRead] &&
!t.isInstanceOf[ContinuousTrigger] =>
testPostCreationNegativeCase(read, write, trigger,
s"Data source $read does not support microbatch processing")