File: InMemoryTable.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, HoursTransform, IdentityTransform, MonthsTransform, Transform, YearsTransform}
import org.apache.spark.sql.connector.read._
import org.apache.spark.sql.connector.write._
import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
import org.apache.spark.sql.sources.{And, EqualTo, Filter, IsNotNull}
import org.apache.spark.sql.types.{DataType, DateType, StructType, TimestampType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -145,6 +146,7 @@ class InMemoryTable(
override def capabilities: util.Set[TableCapability] = Set(
TableCapability.BATCH_READ,
TableCapability.BATCH_WRITE,
TableCapability.STREAMING_WRITE,
TableCapability.OVERWRITE_BY_FILTER,
TableCapability.OVERWRITE_DYNAMIC,
TableCapability.TRUNCATE).asJava
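
The new STREAMING_WRITE capability is what the DataStreamWriter change later in this diff checks before accepting a table as a streaming sink. A minimal sketch of such a check (not part of the diff; it uses only the public Table.capabilities() API, and the helper name is made up):

import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability}

// Hypothetical helper: accept a table as a streaming sink only if it is writable
// and advertises STREAMING_WRITE; otherwise fail early with a clear message.
def requireStreamingWritable(table: Table): SupportsWrite = table match {
  case t: SupportsWrite if t.capabilities.contains(TableCapability.STREAMING_WRITE) => t
  case t => throw new IllegalArgumentException(
    s"Table ${t.name} doesn't support streaming write")
}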
@@ -169,26 +171,35 @@

new WriteBuilder with SupportsTruncate with SupportsOverwrite with SupportsDynamicOverwrite {
private var writer: BatchWrite = Append
private var streamingWriter: StreamingWrite = StreamingAppend

override def truncate(): WriteBuilder = {
assert(writer == Append)
writer = TruncateAndAppend
streamingWriter = StreamingTruncateAndAppend
this
}

override def overwrite(filters: Array[Filter]): WriteBuilder = {
assert(writer == Append)
writer = new Overwrite(filters)
streamingWriter = new StreamingNotSupportedOperation(s"overwrite ($filters)")
this
}

override def overwriteDynamicPartitions(): WriteBuilder = {
assert(writer == Append)
writer = DynamicOverwrite
streamingWriter = new StreamingNotSupportedOperation("overwriteDynamicPartitions")
this
}

override def buildForBatch(): BatchWrite = writer

override def buildForStreaming(): StreamingWrite = streamingWriter match {
case exc: StreamingNotSupportedOperation => exc.throwsException()
case s => s
}
}
}
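
Roughly, the streaming writer chosen here tracks the batch one: the default is StreamingAppend, truncate() (used for Complete output mode) swaps in StreamingTruncateAndAppend, and the overwrite variants install a StreamingNotSupportedOperation that throws when the streaming write is built. A small sketch of that mapping (not Spark's actual planner code; it assumes the builder came from this table's newWriteBuilder):

import org.apache.spark.sql.connector.write.{SupportsTruncate, WriteBuilder}
import org.apache.spark.sql.connector.write.streaming.StreamingWrite
import org.apache.spark.sql.streaming.OutputMode

// Sketch: pick the streaming write for a given output mode.
def streamingWriteFor(builder: WriteBuilder, mode: OutputMode): StreamingWrite = {
  if (mode == OutputMode.Complete()) {
    // Complete mode truncates the sink on every epoch before appending,
    // which for InMemoryTable yields StreamingTruncateAndAppend.
    builder.asInstanceOf[SupportsTruncate].truncate().buildForStreaming()
  } else {
    // Append mode builds the default writer, i.e. StreamingAppend here.
    builder.buildForStreaming()
  }
}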

@@ -231,6 +242,45 @@ class InMemoryTable(
}
}

private abstract class TestStreamingWrite extends StreamingWrite {
def createStreamingWriterFactory(info: PhysicalWriteInfo): StreamingDataWriterFactory = {
BufferedRowsWriterFactory
}

def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
}

private class StreamingNotSupportedOperation(operation: String) extends TestStreamingWrite {
override def createStreamingWriterFactory(info: PhysicalWriteInfo): StreamingDataWriterFactory =
throwsException()

override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit =
throwsException()

override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit =
throwsException()

def throwsException[T](): T = throw new IllegalStateException("The operation " +
s"${operation} isn't supported for streaming query.")
}

private object StreamingAppend extends TestStreamingWrite {
override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
dataMap.synchronized {
withData(messages.map(_.asInstanceOf[BufferedRows]))
}
}
}

private object StreamingTruncateAndAppend extends TestStreamingWrite {
override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
dataMap.synchronized {
dataMap.clear
withData(messages.map(_.asInstanceOf[BufferedRows]))
}
}
}

override def deleteWhere(filters: Array[Filter]): Unit = dataMap.synchronized {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
dataMap --= InMemoryTable.filtersToKeys(dataMap.keys, partCols.map(_.toSeq.quoted), filters)
@@ -310,10 +360,17 @@ private class BufferedRowsReader(partition: BufferedRows) extends PartitionReade
override def close(): Unit = {}
}

private object BufferedRowsWriterFactory extends DataWriterFactory {
private object BufferedRowsWriterFactory extends DataWriterFactory with StreamingDataWriterFactory {
override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = {
new BufferWriter
}

override def createWriter(
partitionId: Int,
taskId: Long,
epochId: Long): DataWriter[InternalRow] = {
new BufferWriter
}
}

private class BufferWriter extends DataWriter[InternalRow] {
File: DataStreamWriter.scala
@@ -27,7 +27,7 @@ import org.apache.spark.api.java.function.VoidFunction2
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.connector.catalog.{SupportsWrite, TableProvider}
import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableProvider}
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.DataSource
@@ -45,6 +45,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
*/
@Evolving
final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
import DataStreamWriter._

private val df = ds.toDF()

@@ -294,60 +295,75 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
@throws[TimeoutException]
def start(): StreamingQuery = startInternal(None)

/**
* Starts the execution of the streaming query, which will continually output results to the given
* table as new data arrives. The returned [[StreamingQuery]] object can be used to interact with
* the stream.
*
* @since 3.1.0
*/
@throws[TimeoutException]
def saveAsTable(tableName: String): StreamingQuery = {
this.source = SOURCE_NAME_TABLE
this.tableName = tableName
startInternal(None)
}
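
A usage sketch of the new API (not from the PR): the source, table name, and checkpoint path below are placeholders, spark is assumed to be an active SparkSession, and the target table is assumed to already exist with a schema matching the stream and to report STREAMING_WRITE.

import org.apache.spark.sql.streaming.StreamingQuery

val query: StreamingQuery = spark.readStream
  .format("rate")                             // placeholder source
  .load()
  .writeStream
  .outputMode("append")
  .option("checkpointLocation", "/tmp/ckpt")  // hypothetical path
  .saveAsTable("default.sink_table")          // hypothetical table

query.awaitTermination()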

private def startInternal(path: Option[String]): StreamingQuery = {
if (source.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) {
throw new AnalysisException("Hive data source can only be used with tables, you can not " +
"write files of Hive data source directly.")
}

if (source == "memory") {
assertNotPartitioned("memory")
if (source == SOURCE_NAME_TABLE) {
assertNotPartitioned(SOURCE_NAME_TABLE)

import df.sparkSession.sessionState.analyzer.CatalogAndIdentifier

import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
val originalMultipartIdentifier = df.sparkSession.sessionState.sqlParser
.parseMultipartIdentifier(tableName)
val CatalogAndIdentifier(catalog, identifier) = originalMultipartIdentifier

// Currently we don't create a logical streaming writer node in logical plan, so cannot rely
// on analyzer to resolve it. Directly lookup only for temp view to provide clearer message.
// TODO (SPARK-27484): we should add the writing node before the plan is analyzed.
if (df.sparkSession.sessionState.catalog.isTempView(originalMultipartIdentifier)) {
throw new AnalysisException(s"Temporary view $tableName doesn't support streaming write")
}

val tableInstance = catalog.asTableCatalog.loadTable(identifier)

import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
val sink = tableInstance match {
case t: SupportsWrite if t.supports(STREAMING_WRITE) => t
case t => throw new AnalysisException(s"Table $tableName doesn't support streaming " +
s"write - $t")
}

startQuery(sink, extraOptions)
} else if (source == SOURCE_NAME_MEMORY) {
assertNotPartitioned(SOURCE_NAME_MEMORY)
if (extraOptions.get("queryName").isEmpty) {
throw new AnalysisException("queryName must be specified for memory sink")
}
val sink = new MemorySink()
val resultDf = Dataset.ofRows(df.sparkSession, new MemoryPlan(sink, df.schema.toAttributes))
val chkpointLoc = extraOptions.get("checkpointLocation")
val recoverFromChkpoint = outputMode == OutputMode.Complete()
val query = df.sparkSession.sessionState.streamingQueryManager.startQuery(
extraOptions.get("queryName"),
chkpointLoc,
df,
extraOptions.toMap,
sink,
outputMode,
useTempCheckpointLocation = true,
recoverFromCheckpointLocation = recoverFromChkpoint,
trigger = trigger)
val query = startQuery(sink, extraOptions, recoverFromCheckpoint = recoverFromChkpoint)
resultDf.createOrReplaceTempView(query.name)
query
} else if (source == "foreach") {
assertNotPartitioned("foreach")
} else if (source == SOURCE_NAME_FOREACH) {
assertNotPartitioned(SOURCE_NAME_FOREACH)
val sink = ForeachWriterTable[T](foreachWriter, ds.exprEnc)
df.sparkSession.sessionState.streamingQueryManager.startQuery(
extraOptions.get("queryName"),
extraOptions.get("checkpointLocation"),
df,
extraOptions.toMap,
sink,
outputMode,
useTempCheckpointLocation = true,
trigger = trigger)
} else if (source == "foreachBatch") {
assertNotPartitioned("foreachBatch")
startQuery(sink, extraOptions)
} else if (source == SOURCE_NAME_FOREACH_BATCH) {
assertNotPartitioned(SOURCE_NAME_FOREACH_BATCH)
if (trigger.isInstanceOf[ContinuousTrigger]) {
throw new AnalysisException("'foreachBatch' is not supported with continuous trigger")
throw new AnalysisException(s"'$source' is not supported with continuous trigger")
}
val sink = new ForeachBatchSink[T](foreachBatchWriter, ds.exprEnc)
df.sparkSession.sessionState.streamingQueryManager.startQuery(
extraOptions.get("queryName"),
extraOptions.get("checkpointLocation"),
df,
extraOptions.toMap,
sink,
outputMode,
useTempCheckpointLocation = true,
trigger = trigger)
startQuery(sink, extraOptions)
dongjoon-hyun (Member) commented on Oct 7, 2020:

@HeartSaVioR . For CaseInsensitiveMap, def toMap: Map[String, T] = originalMap. It seems that we need toMap explicitly here, as we did on line 385. (cc @cloud-fan)

    startQuery(sink, optionsWithPath.originalMap)

dongjoon-hyun (Member) commented on Oct 7, 2020:

Previously, @cloud-fan and I hit case-sensitivity issues in other JIRAs due to this. Please make sure this PR doesn't re-introduce them, because the PR as-is silently switches extraOptions.toMap -> extraOptions.

Member commented:

If you have already checked that, please add a test case for it. Otherwise, let's just use the old way, extraOptions.toMap, to avoid any side effect.

Contributor (Author) commented:

Ah OK, thanks for pointing that out. Nice finding. I'll just explicitly call .toMap as it was.

Contributor commented:

nice catch @dongjoon-hyun !
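
For context on the concern above, a minimal sketch (not part of the PR) of the behaviour being discussed, assuming CaseInsensitiveMap works as quoted in the comment (toMap returns originalMap and lookups are case-insensitive):

import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap

val opts = CaseInsensitiveMap(Map("checkpointLocation" -> "/tmp/ckpt"))

opts.get("checkpointlocation")  // Some("/tmp/ckpt"): lookups ignore case
opts.toMap.keys                 // original key casing preserved ("checkpointLocation")
opts.keys                       // iterating the wrapper itself exposes lower-cased keys,
                                // the silent change being flagged when .toMap is dropped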

} else {
val cls = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf)
val disabledSources = df.sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",")
@@ -380,19 +396,28 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
createV1Sink(optionsWithPath)
}

df.sparkSession.sessionState.streamingQueryManager.startQuery(
extraOptions.get("queryName"),
extraOptions.get("checkpointLocation"),
df,
optionsWithPath.originalMap,
sink,
outputMode,
useTempCheckpointLocation = source == "console" || source == "noop",
recoverFromCheckpointLocation = true,
trigger = trigger)
startQuery(sink, optionsWithPath)
}
}

private def startQuery(
sink: Table,
newOptions: CaseInsensitiveMap[String],
recoverFromCheckpoint: Boolean = true): StreamingQuery = {
val useTempCheckpointLocation = SOURCES_ALLOW_ONE_TIME_QUERY.contains(source)

df.sparkSession.sessionState.streamingQueryManager.startQuery(
Contributor commented:

We can follow the previous code style:

    ...startQuery(
      newOptions.get("queryName"),
      newOptions.get("checkpointLocation"),
      df,
      newOptions.originalMap,
      ...

Contributor (Author) commented:

OK, let me keep it as it is.

newOptions.get("queryName"),
newOptions.get("checkpointLocation"),
df,
newOptions.originalMap,
sink,
outputMode,
useTempCheckpointLocation = useTempCheckpointLocation,
recoverFromCheckpointLocation = recoverFromCheckpoint,
trigger = trigger)
}

private def createV1Sink(optionsWithPath: CaseInsensitiveMap[String]): Sink = {
val ds = DataSource(
df.sparkSession,
@@ -409,7 +434,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
* @since 2.0.0
*/
def foreach(writer: ForeachWriter[T]): DataStreamWriter[T] = {
this.source = "foreach"
this.source = SOURCE_NAME_FOREACH
this.foreachWriter = if (writer != null) {
ds.sparkSession.sparkContext.clean(writer)
} else {
@@ -433,7 +458,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
*/
@Evolving
def foreachBatch(function: (Dataset[T], Long) => Unit): DataStreamWriter[T] = {
this.source = "foreachBatch"
this.source = SOURCE_NAME_FOREACH_BATCH
if (function == null) throw new IllegalArgumentException("foreachBatch function cannot be null")
this.foreachBatchWriter = function
this
@@ -485,6 +510,8 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {

private var source: String = df.sparkSession.sessionState.conf.defaultDataSourceName

private var tableName: String = null

private var outputMode: OutputMode = OutputMode.Append

private var trigger: Trigger = Trigger.ProcessingTime(0L)
@@ -497,3 +524,16 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {

private var partitioningColumns: Option[Seq[String]] = None
}

object DataStreamWriter {
val SOURCE_NAME_MEMORY = "memory"
val SOURCE_NAME_FOREACH = "foreach"
val SOURCE_NAME_FOREACH_BATCH = "foreachBatch"
val SOURCE_NAME_CONSOLE = "console"
val SOURCE_NAME_TABLE = "table"
val SOURCE_NAME_NOOP = "noop"

// these writer sources are also used for one-time query, hence allow temp checkpoint location
val SOURCES_ALLOW_ONE_TIME_QUERY = Seq(SOURCE_NAME_MEMORY, SOURCE_NAME_FOREACH,
SOURCE_NAME_FOREACH_BATCH, SOURCE_NAME_CONSOLE, SOURCE_NAME_NOOP)
}
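
A usage sketch (not from the PR) of what the SOURCES_ALLOW_ONE_TIME_QUERY list enables: these sinks can start without an explicit checkpointLocation, because startQuery passes useTempCheckpointLocation = true for them. The rate source is just a placeholder and spark is assumed to be an active SparkSession.

import org.apache.spark.sql.streaming.Trigger

// One-time console query: no checkpointLocation needed, a temporary one is used.
val q = spark.readStream
  .format("rate")
  .load()
  .writeStream
  .format("console")
  .trigger(Trigger.Once())
  .start()

q.awaitTermination()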