@@ -32,6 +32,7 @@ import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, HoursTransform, IdentityTransform, MonthsTransform, Transform, YearsTransform}
import org.apache.spark.sql.connector.read._
import org.apache.spark.sql.connector.write._
import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
import org.apache.spark.sql.sources.{And, EqualTo, Filter, IsNotNull}
import org.apache.spark.sql.types.{DataType, DateType, StructType, TimestampType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -145,6 +146,7 @@ class InMemoryTable(
override def capabilities: util.Set[TableCapability] = Set(
TableCapability.BATCH_READ,
TableCapability.BATCH_WRITE,
TableCapability.STREAMING_WRITE,
TableCapability.OVERWRITE_BY_FILTER,
TableCapability.OVERWRITE_DYNAMIC,
TableCapability.TRUNCATE).asJava
@@ -169,26 +171,35 @@

new WriteBuilder with SupportsTruncate with SupportsOverwrite with SupportsDynamicOverwrite {
private var writer: BatchWrite = Append
private var streamingWriter: StreamingWrite = StreamingAppend

override def truncate(): WriteBuilder = {
assert(writer == Append)
writer = TruncateAndAppend
streamingWriter = StreamingTruncateAndAppend
this
}

override def overwrite(filters: Array[Filter]): WriteBuilder = {
assert(writer == Append)
writer = new Overwrite(filters)
streamingWriter = new StreamingNotSupportedOperation(s"overwrite ($filters)")
this
}

override def overwriteDynamicPartitions(): WriteBuilder = {
assert(writer == Append)
writer = DynamicOverwrite
streamingWriter = new StreamingNotSupportedOperation("overwriteDynamicPartitions")
this
}

override def buildForBatch(): BatchWrite = writer

override def buildForStreaming(): StreamingWrite = streamingWriter match {
case exc: StreamingNotSupportedOperation => exc.throwsException()
case s => s
}
}
}

@@ -231,6 +242,45 @@ class InMemoryTable(
}
}

private abstract class TestStreamingWrite extends StreamingWrite {
def createStreamingWriterFactory(info: PhysicalWriteInfo): StreamingDataWriterFactory = {
BufferedRowsWriterFactory
}

def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
}

private class StreamingNotSupportedOperation(operation: String) extends TestStreamingWrite {
override def createStreamingWriterFactory(info: PhysicalWriteInfo): StreamingDataWriterFactory =
throwsException()

override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit =
throwsException()

override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit =
throwsException()

def throwsException[T](): T = throw new IllegalStateException("The operation " +
s"${operation} isn't supported for streaming query.")
}

private object StreamingAppend extends TestStreamingWrite {
override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
dataMap.synchronized {
withData(messages.map(_.asInstanceOf[BufferedRows]))
}
}
}

private object StreamingTruncateAndAppend extends TestStreamingWrite {
override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
dataMap.synchronized {
dataMap.clear
withData(messages.map(_.asInstanceOf[BufferedRows]))
}
}
}

override def deleteWhere(filters: Array[Filter]): Unit = dataMap.synchronized {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
dataMap --= InMemoryTable.filtersToKeys(dataMap.keys, partCols.map(_.toSeq.quoted), filters)
@@ -310,10 +360,17 @@ private class BufferedRowsReader(partition: BufferedRows) extends PartitionReader
override def close(): Unit = {}
}

private object BufferedRowsWriterFactory extends DataWriterFactory {
private object BufferedRowsWriterFactory extends DataWriterFactory with StreamingDataWriterFactory {
override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = {
new BufferWriter
}

override def createWriter(
partitionId: Int,
taskId: Long,
epochId: Long): DataWriter[InternalRow] = {
new BufferWriter
}
}

private class BufferWriter extends DataWriter[InternalRow] {
@@ -27,7 +27,7 @@ import org.apache.spark.api.java.function.VoidFunction2
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.connector.catalog.{SupportsWrite, TableProvider}
import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableProvider}
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.DataSource
@@ -45,6 +45,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
*/
@Evolving
final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
import DataStreamWriter._

private val df = ds.toDF()

@@ -300,54 +301,55 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
"write files of Hive data source directly.")
}

if (source == "memory") {
assertNotPartitioned("memory")
if (source == SOURCE_NAME_TABLE) {
assertNotPartitioned(SOURCE_NAME_TABLE)

import df.sparkSession.sessionState.analyzer.CatalogAndIdentifier

import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
val originalMultipartIdentifier = df.sparkSession.sessionState.sqlParser
.parseMultipartIdentifier(tableName)
val CatalogAndIdentifier(catalog, identifier) = originalMultipartIdentifier

// Currently we don't create a logical streaming writer node in logical plan, so cannot rely
// on analyzer to resolve it. Directly lookup only for temp view to provide clearer message.
// TODO (SPARK-27484): we should add the writing node before the plan is analyzed.
if (df.sparkSession.sessionState.catalog.isTempView(originalMultipartIdentifier)) {
throw new AnalysisException(s"Temporary view $tableName doesn't support streaming write")
}

val tableInstance = catalog.asTableCatalog.loadTable(identifier)

import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
val sink = tableInstance match {
case t: SupportsWrite if t.supports(STREAMING_WRITE) => t
case t => throw new AnalysisException(s"Table $tableName doesn't support streaming " +
s"write - $t")
}

startQuery(sink, extraOptions)
} else if (source == SOURCE_NAME_MEMORY) {
assertNotPartitioned(SOURCE_NAME_MEMORY)
if (extraOptions.get("queryName").isEmpty) {
throw new AnalysisException("queryName must be specified for memory sink")
}
val sink = new MemorySink()
val resultDf = Dataset.ofRows(df.sparkSession, new MemoryPlan(sink, df.schema.toAttributes))
val chkpointLoc = extraOptions.get("checkpointLocation")
val recoverFromChkpoint = outputMode == OutputMode.Complete()
val query = df.sparkSession.sessionState.streamingQueryManager.startQuery(
extraOptions.get("queryName"),
chkpointLoc,
df,
extraOptions.toMap,
sink,
outputMode,
useTempCheckpointLocation = true,
recoverFromCheckpointLocation = recoverFromChkpoint,
trigger = trigger)
val query = startQuery(sink, extraOptions, recoverFromCheckpoint = recoverFromChkpoint)
resultDf.createOrReplaceTempView(query.name)
query
} else if (source == "foreach") {
assertNotPartitioned("foreach")
} else if (source == SOURCE_NAME_FOREACH) {
assertNotPartitioned(SOURCE_NAME_FOREACH)
val sink = ForeachWriterTable[T](foreachWriter, ds.exprEnc)
df.sparkSession.sessionState.streamingQueryManager.startQuery(
extraOptions.get("queryName"),
extraOptions.get("checkpointLocation"),
df,
extraOptions.toMap,
sink,
outputMode,
useTempCheckpointLocation = true,
trigger = trigger)
} else if (source == "foreachBatch") {
assertNotPartitioned("foreachBatch")
startQuery(sink, extraOptions)
} else if (source == SOURCE_NAME_FOREACH_BATCH) {
assertNotPartitioned(SOURCE_NAME_FOREACH_BATCH)
if (trigger.isInstanceOf[ContinuousTrigger]) {
throw new AnalysisException("'foreachBatch' is not supported with continuous trigger")
throw new AnalysisException(s"'$source' is not supported with continuous trigger")
}
val sink = new ForeachBatchSink[T](foreachBatchWriter, ds.exprEnc)
df.sparkSession.sessionState.streamingQueryManager.startQuery(
extraOptions.get("queryName"),
extraOptions.get("checkpointLocation"),
df,
extraOptions.toMap,
sink,
outputMode,
useTempCheckpointLocation = true,
trigger = trigger)
startQuery(sink, extraOptions)
dongjoon-hyun (Member), Oct 7, 2020:
@HeartSaVioR . For CaseInsensitiveMap, def toMap: Map[String, T] = originalMap. It seems that we need toMap explicitly here, as we did on line 385. (cc @cloud-fan)

    startQuery(sink, optionsWithPath.originalMap)

dongjoon-hyun (Member), Oct 7, 2020:
Previously, @cloud-fan and I hit case-sensitivity issues in other JIRAs because of this. Please make sure this PR doesn't re-introduce them, since the as-is PR silently switches extraOptions.toMap -> extraOptions.

Member:
If you have already checked that, please add a test case for it. Otherwise, we can just use the old way, extraOptions.toMap, to avoid any side effect.

Contributor Author:
Ah OK, thanks for pointing that out. Nice finding. I'll just explicitly call .toMap as it was.

Contributor:
Nice catch @dongjoon-hyun !
} else {
val cls = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf)
val disabledSources = df.sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",")
@@ -380,19 +382,30 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
createV1Sink(optionsWithPath)
}

df.sparkSession.sessionState.streamingQueryManager.startQuery(
extraOptions.get("queryName"),
extraOptions.get("checkpointLocation"),
df,
optionsWithPath.originalMap,
sink,
outputMode,
useTempCheckpointLocation = source == "console" || source == "noop",
recoverFromCheckpointLocation = true,
trigger = trigger)
startQuery(sink, optionsWithPath.originalMap)
}
}

private def startQuery(
sink: Table,
newOptions: Map[String, String],
recoverFromCheckpoint: Boolean = true): StreamingQuery = {
val queryName = extraOptions.get("queryName")
val checkpointLocation = extraOptions.get("checkpointLocation")
val useTempCheckpointLocation = SOURCES_ALLOW_ONE_TIME_QUERY.contains(source)

df.sparkSession.sessionState.streamingQueryManager.startQuery(
Contributor:
We can follow the previous code style:

...startQuery(
  newOptions.get("queryName"),
  newOptions.get("checkpointLocation"),
  df,
  newOptions.originalMap,
  ...

Contributor Author:
OK, let me keep it as it is.
queryName,
checkpointLocation,
df,
newOptions,
sink,
outputMode,
useTempCheckpointLocation = useTempCheckpointLocation,
recoverFromCheckpointLocation = recoverFromCheckpoint,
trigger = trigger)
}

private def createV1Sink(optionsWithPath: CaseInsensitiveMap[String]): Sink = {
val ds = DataSource(
df.sparkSession,
@@ -409,7 +422,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
* @since 2.0.0
*/
def foreach(writer: ForeachWriter[T]): DataStreamWriter[T] = {
this.source = "foreach"
this.source = SOURCE_NAME_FOREACH
this.foreachWriter = if (writer != null) {
ds.sparkSession.sparkContext.clean(writer)
} else {
@@ -433,7 +446,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
*/
@Evolving
def foreachBatch(function: (Dataset[T], Long) => Unit): DataStreamWriter[T] = {
this.source = "foreachBatch"
this.source = SOURCE_NAME_FOREACH_BATCH
if (function == null) throw new IllegalArgumentException("foreachBatch function cannot be null")
this.foreachBatchWriter = function
this
@@ -457,6 +470,17 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
foreachBatch((batchDs: Dataset[T], batchId: Long) => function.call(batchDs, batchId))
}

/**
* Specifies the underlying output table.
*
* @since 3.1.0
*/
def table(tableName: String): DataStreamWriter[T] = {
Contributor:
It's a bit weird to define table as a config method. I think it's better to follow DataFrameWriter.saveAsTable and make it an action. How about:

def table(tableName: String): StreamingQuery = {
  this.source = SOURCE_NAME_TABLE
  this.tableName = tableName
  start()
}

HeartSaVioR (Contributor Author), Oct 6, 2020:
I have a somewhat different view on DataStreamWriter (and probably DataFrameWriter as well):

While we don't restrict the order, I think it's natural to have a flow like: define a sink -> set options on the sink -> set options on the streaming query -> start the query. (A couple of parts can be consolidated, or the sequence can be swapped.)

df.writeStream
   .format("...")
   .option("...")
   .outputMode(...)
   .trigger(...)
   .start()

Right now it looks arbitrary and things get mixed up: checkpointLocation isn't tied to the sink, but we let end users put it into the options, which are also used for the sink. The same goes for queryName.

I intended the table method to define a sink, but if we'd like to treat tables specially, DataFrameWriter.insertInto would match the intention and I can change the method name to insertInto here as well. (I'm also fine with adding it as saveAsTable.)

WDYT?

Contributor Author:
DataFrameWriterV2 enforces the flow perfectly (leaving aside the branch for creating a table): define a sink by providing a table identifier, provide options, and decide which kind of write. The flow is unidirectional and no longer arbitrary.

I feel we should eventually have a DataStreamWriterV2 that enforces the flow as well, but let's do this first and take more time to think about it.

Contributor:
I think we will have DataStreamWriterV2 eventually (after we figure out how to design output mode). For now, it's more important to keep API consistency between batch and stream.

I don't have a strong opinion about the naming; table is fine. cc @xuanyuanking @zsxwing

Contributor Author:
IMHO the exact match for the method on the batch side is insertInto (as it will handle the output mode and simply add/update the data instead of creating a table). Just naming it table doesn't convey that it's an action - in DataFrameWriter it's saveAsTable rather than simply table.

Member:
Yeah, I also agree with the as-is config method table instead of an action.

Contributor:
Then the API works differently between DataFrameWriter and DataStreamWriter, and users may be confused.

I'm fine with saveAsTable if it's more like an action.

Contributor Author:
OK, I'll change the name to saveAsTable and call start() there.

this.source = SOURCE_NAME_TABLE
this.tableName = tableName
this
}
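The thread above converged on an action-style method, while this commit still defines table as a config method (shown immediately above). A minimal, hypothetical sketch of the converged shape, combining the reviewer's snippet with the agreed rename:

// Hypothetical action-style variant from the review discussion (not part of this commit):
// it configures the table sink and immediately starts the query.
def saveAsTable(tableName: String): StreamingQuery = {
  this.source = SOURCE_NAME_TABLE
  this.tableName = tableName
  start()
}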

private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { cols =>
cols.map(normalize(_, "Partition"))
}
@@ -485,6 +509,8 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {

private var source: String = df.sparkSession.sessionState.conf.defaultDataSourceName

private var tableName: String = null

private var outputMode: OutputMode = OutputMode.Append

private var trigger: Trigger = Trigger.ProcessingTime(0L)
@@ -497,3 +523,16 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {

private var partitioningColumns: Option[Seq[String]] = None
}

object DataStreamWriter {
val SOURCE_NAME_MEMORY = "memory"
val SOURCE_NAME_FOREACH = "foreach"
val SOURCE_NAME_FOREACH_BATCH = "foreachBatch"
val SOURCE_NAME_CONSOLE = "console"
val SOURCE_NAME_TABLE = "table"
val SOURCE_NAME_NOOP = "noop"

// these writer sources are also used for one-time query, hence allow temp checkpoint location
val SOURCES_ALLOW_ONE_TIME_QUERY = Seq(SOURCE_NAME_MEMORY, SOURCE_NAME_FOREACH,
SOURCE_NAME_FOREACH_BATCH, SOURCE_NAME_CONSOLE, SOURCE_NAME_NOOP)
}
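A minimal sketch of how the new table sink could be exercised end to end. It assumes a SparkSession named spark and a catalog table that supports STREAMING_WRITE (for example, the InMemoryTable test connector updated above); the source, table name, and checkpoint path are placeholders:

// Hypothetical usage of the streaming table sink added in this PR.
val query = spark.readStream
  .format("rate")                                      // any streaming source works here
  .load()
  .writeStream
  .outputMode("append")
  .option("checkpointLocation", "/tmp/streaming-ckpt") // placeholder path
  .table("testcat.ns.events")                          // routes start() through SOURCE_NAME_TABLE
  .start()

query.awaitTermination()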