From 26b0e250274fc6e6c86197887c3ced7f56ec83f2 Mon Sep 17 00:00:00 2001
From: "Jungtaek Lim (HeartSaVioR)"
Date: Thu, 12 Dec 2019 10:44:18 +0900
Subject: [PATCH 1/3] [SPARK-30227][SQL] Add close() on DataWriter interface

---
 .../spark/sql/kafka010/KafkaDataWriter.scala   |  9 +-------
 .../spark/sql/connector/write/DataWriter.java  |  8 ++++---
 .../spark/sql/connector/InMemoryTable.scala    |  2 ++
 .../datasources/FileFormatDataWriter.scala     |  2 ++
 .../datasources/FileFormatWriter.scala         |  2 ++
 .../datasources/noop/NoopDataSource.scala      |  1 +
 .../v2/WriteToDataSourceV2Exec.scala           |  2 ++
 .../continuous/ContinuousWriteRDD.scala        |  2 ++
 .../sources/ForeachWriterTable.scala           | 21 +++++--------------
 .../sources/PackedRowWriterFactory.scala       | 10 +++++----
 .../execution/streaming/sources/memory.scala   |  8 ++++---
 .../connector/SimpleWritableDataSource.scala   |  2 ++
 12 files changed, 35 insertions(+), 34 deletions(-)

diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataWriter.scala
index 3f8d3d2da579..cdb8ceb3e611 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataWriter.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataWriter.scala
@@ -63,12 +63,5 @@ private[kafka010] class KafkaDataWriter(
 
   def abort(): Unit = {}
 
-  def close(): Unit = {
-    checkForErrors()
-    if (producer != null) {
-      producer.flush()
-      checkForErrors()
-      CachedKafkaProducer.close(producerParams)
-    }
-  }
+  def close(): Unit = {}
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DataWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DataWriter.java
index eefe784dede4..59c69a18292d 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DataWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DataWriter.java
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.connector.write;
 
+import java.io.Closeable;
 import java.io.IOException;
 
 import org.apache.spark.annotation.Evolving;
@@ -31,8 +32,9 @@
  * the {@link #write(Object)}, {@link #abort()} is called afterwards and the remaining records will
 * not be processed. If all records are successfully written, {@link #commit()} is called.
 *
- * Once a data writer returns successfully from {@link #commit()} or {@link #abort()}, its lifecycle
- * is over and Spark will not use it again.
+ * Once a data writer returns successfully from {@link #commit()} or {@link #abort()}, Spark will
+ * call {@link #close()} to let DataWriter do resource cleanup. After calling {@link #close()},
+ * its lifecycle is over and Spark will not use it again.
 *
 * If this data writer succeeds(all records are successfully written and {@link #commit()}
 * succeeds), a {@link WriterCommitMessage} will be sent to the driver side and pass to
@@ -56,7 +58,7 @@
 * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.catalyst.InternalRow}.
 */
 @Evolving
-public interface DataWriter<T> {
+public interface DataWriter<T> extends Closeable {
 
   /**
    * Writes one record.
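The updated javadoc above pins down the new lifecycle: zero or more write() calls, then commit() or abort(), then always close(). As a reference point only (not part of this patch), a minimal connector-side DataWriter written against the new interface could look like the sketch below; ExternalClient, its methods, and SketchCommitMessage are hypothetical names used purely for illustration.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}

// Hypothetical resource handle, only for this sketch.
trait ExternalClient {
  def buffer(row: InternalRow): Unit
  def flush(): Unit
  def discardBuffered(): Unit
  def release(): Unit
  def recordCount: Long
}

case class SketchCommitMessage(records: Long) extends WriterCommitMessage

class SketchDataWriter(client: ExternalClient) extends DataWriter[InternalRow] {
  override def write(record: InternalRow): Unit = client.buffer(record.copy())

  override def commit(): WriterCommitMessage = {
    // Finalize the output here; resource cleanup moves to close().
    client.flush()
    SketchCommitMessage(client.recordCount)
  }

  override def abort(): Unit = {
    // Best-effort rollback of whatever was written before the failure.
    client.discardBuffered()
  }

  // Called by Spark after commit() or abort(), on both success and failure paths,
  // so releasing resources here runs exactly once per writer.
  override def close(): Unit = client.release()
}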
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
index 201860e5135b..09c4b9ab9d53 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
@@ -252,4 +252,6 @@ private class BufferWriter extends DataWriter[InternalRow] {
   override def commit(): WriterCommitMessage = buffer
 
   override def abort(): Unit = {}
+
+  override def close(): Unit = {}
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
index c1ebc98fb1dd..50c4f6cd57a9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
@@ -86,6 +86,8 @@ abstract class FileFormatDataWriter(
       committer.abortTask(taskAttemptContext)
     }
   }
+
+  override def close(): Unit = {}
 }
 
 /** FileFormatWriteTask for empty partitions */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index f02d9e92acb8..219c778b9164 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -277,6 +277,8 @@ object FileFormatWriter extends Logging {
         // If there is an error, abort the task
         dataWriter.abort()
         logError(s"Job $jobId aborted.")
+      }, finallyBlock = {
+        dataWriter.close()
       })
     } catch {
       case e: FetchFailedException =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala
index 03e5f43a2a0a..dd44651050e6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala
@@ -72,6 +72,7 @@ private[noop] object NoopWriter extends DataWriter[InternalRow] {
   override def write(record: InternalRow): Unit = {}
   override def commit(): WriterCommitMessage = null
   override def abort(): Unit = {}
+  override def close(): Unit = {}
 }
 
 private[noop] object NoopStreamingWrite extends StreamingWrite {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
index 7d8a115c126e..f4c70f7593b5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
@@ -467,6 +467,8 @@ object DataWritingSparkTask extends Logging {
       dataWriter.abort()
       logError(s"Aborted commit for partition $partId (task $taskId, attempt $attemptId, " +
         s"stage $stageId.$stageAttempt)")
+    }, finallyBlock = {
+      dataWriter.close()
     })
   }
 }
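The FileFormatWriter and DataWritingSparkTask hunks above attach close() to the existing Utils.tryWithSafeFinallyAndFailureCallbacks wrapper, so the writer is closed whether the task commits, aborts, or dies on an unrelated error. Ignoring the callback plumbing, the executor-side control flow after this patch is roughly equivalent to the following sketch (WritePathSketch, runWriteTask, and its parameters are illustrative names, not the actual code):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}

// Rough, simplified equivalent of the executor-side write path after this patch.
object WritePathSketch {
  def runWriteTask(
      dataWriter: DataWriter[InternalRow],
      rows: Iterator[InternalRow]): WriterCommitMessage = {
    try {
      rows.foreach(dataWriter.write)
      dataWriter.commit()            // success: hand the commit message back to the driver
    } catch {
      case t: Throwable =>
        dataWriter.abort()           // failure: abort before rethrowing
        throw t
    } finally {
      dataWriter.close()             // new in this patch: always release writer resources
    }
  }
}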
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala
index 558b4313d6d8..909dda57ee58 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala
@@ -80,6 +80,8 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writerFactory: StreamingDat
         logError(s"Writer for partition ${context.partitionId()} is aborting.")
         if (dataWriter != null) dataWriter.abort()
         logError(s"Writer for partition ${context.partitionId()} aborted.")
+      }, finallyBlock = {
+        dataWriter.close()
       })
     }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala
index 53d4bca1a5f7..04d9d5ce44d2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala
@@ -135,34 +135,23 @@ class ForeachDataWriter[T](
 
   // If open returns false, we should skip writing rows.
   private val opened = writer.open(partitionId, epochId)
-  private var closeCalled: Boolean = false
+  private var errorOrNull: Throwable = _
 
   override def write(record: InternalRow): Unit = {
     if (!opened) return
-
-    try {
-      writer.process(rowConverter(record))
-    } catch {
-      case t: Throwable =>
-        closeWriter(t)
-        throw t
-    }
+    writer.process(rowConverter(record))
   }
 
   override def commit(): WriterCommitMessage = {
-    closeWriter(null)
     ForeachWriterCommitMessage
   }
 
   override def abort(): Unit = {
-    closeWriter(new SparkException("Foreach writer has been aborted due to a task failure"))
+    errorOrNull = new SparkException("Foreach writer has been aborted due to a task failure")
   }
 
-  private def closeWriter(errorOrNull: Throwable): Unit = {
-    if (!closeCalled) {
-      closeCalled = true
-      writer.close(errorOrNull)
-    }
+  override def close(): Unit = {
+    writer.close(errorOrNull)
   }
 }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala
index 53f56edc2768..507f860e0452 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala
@@ -56,10 +56,12 @@ class PackedRowDataWriter() extends DataWriter[InternalRow] with Logging {
   override def write(row: InternalRow): Unit = data.append(row.copy())
 
   override def commit(): PackedRowCommitMessage = {
-    val msg = PackedRowCommitMessage(data.toArray)
-    data.clear()
-    msg
+    PackedRowCommitMessage(data.toArray)
   }
 
-  override def abort(): Unit = data.clear()
+  override def abort(): Unit = {}
+
+  override def close(): Unit = {
+    data.clear()
+  }
 }
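For the foreach sink, the change above keeps the user-visible ForeachWriter contract intact: open() once per partition and epoch, process() per row, and close(errorOrNull) exactly once, now driven from DataWriter.close() instead of from commit()/abort(). For illustration only, a user-side writer relying on that contract might look like the sketch below (PrintlnForeachWriter and its println bodies are placeholders, not part of this patch):

import org.apache.spark.sql.ForeachWriter

// Illustrative ForeachWriter: close(errorOrNull) remains the single cleanup point,
// now invoked from ForeachDataWriter.close() rather than from commit()/abort().
class PrintlnForeachWriter extends ForeachWriter[String] {
  override def open(partitionId: Long, epochId: Long): Boolean = {
    // Return false to skip this partition; true to receive process() calls.
    true
  }

  override def process(value: String): Unit = {
    println(s"row: $value")
  }

  override def close(errorOrNull: Throwable): Unit = {
    // errorOrNull is non-null when the task failed or was aborted.
    if (errorOrNull != null) {
      println(s"partition failed: ${errorOrNull.getMessage}")
    }
  }
}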
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala
index a976876b4d8e..e557148124b6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala
@@ -185,12 +185,14 @@ class MemoryDataWriter(partition: Int, schema: StructType)
   }
 
   override def commit(): MemoryWriterCommitMessage = {
-    val msg = MemoryWriterCommitMessage(partition, data.clone())
-    data.clear()
-    msg
+    MemoryWriterCommitMessage(partition, data.clone())
   }
 
   override def abort(): Unit = {}
+
+  override def close(): Unit = {
+    data.clear()
+  }
 }
 
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala
index 306da996e2ca..a0f1a9f9f53f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala
@@ -240,4 +240,6 @@ class CSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[InternalRow]
       fs.delete(file, false)
     }
   }
+
+  override def close(): Unit = {}
 }

From 8058dbf928686242b31133a57daa1bc13b22ce84 Mon Sep 17 00:00:00 2001
From: "Jungtaek Lim (HeartSaVioR)"
Date: Thu, 12 Dec 2019 14:03:46 +0900
Subject: [PATCH 2/3] Fix UT

---
 .../streaming/sources/ForeachWriterTable.scala   | 14 ++++++++++++--
 .../sql/execution/streaming/sources/memory.scala |  8 ++++----
 2 files changed, 16 insertions(+), 6 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala
index 04d9d5ce44d2..4793cb9a9b79 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala
@@ -139,7 +139,15 @@ class ForeachDataWriter[T](
 
   override def write(record: InternalRow): Unit = {
     if (!opened) return
-    writer.process(rowConverter(record))
+
+    try {
+      writer.process(rowConverter(record))
+    } catch {
+      case t: Throwable =>
+        errorOrNull = t
+        throw t
+    }
+
   }
 
   override def commit(): WriterCommitMessage = {
@@ -147,7 +155,9 @@ class ForeachDataWriter[T](
   }
 
   override def abort(): Unit = {
-    errorOrNull = new SparkException("Foreach writer has been aborted due to a task failure")
+    if (errorOrNull == null) {
+      errorOrNull = new SparkException("Foreach writer has been aborted due to a task failure")
+    }
   }
 
   override def close(): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala
index e557148124b6..0cc067fc7675 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala
@@ -185,14 +185,14 @@ class MemoryDataWriter(partition: Int, schema: StructType)
   }
 
   override def commit(): MemoryWriterCommitMessage = {
-    MemoryWriterCommitMessage(partition, data.clone())
+    val msg = MemoryWriterCommitMessage(partition, data.clone())
+    data.clear()
+    msg
   }
 
   override def abort(): Unit = {}
 
-  override def close(): Unit = {
-    data.clear()
-  }
+  override def close(): Unit = {}
 }
 
 

From 21d03e75f9b669ac2cbb42af3315944fb780b2c8 Mon Sep 17 00:00:00 2001
From: "Jungtaek Lim (HeartSaVioR)"
Date: Thu, 12 Dec 2019 23:15:06 +0900
Subject: [PATCH 3/3] Rollback previous implementation of KafkaDataWriter.close() and rename to avoid conflict

---
 .../apache/spark/sql/kafka010/KafkaDataWriter.scala    | 10 ++++++++++
 .../org/apache/spark/sql/kafka010/KafkaSinkSuite.scala |  2 +-
 2 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataWriter.scala
index cdb8ceb3e611..9a2b36993361 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataWriter.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataWriter.scala
@@ -64,4 +64,14 @@ private[kafka010] class KafkaDataWriter(
   def abort(): Unit = {}
 
   def close(): Unit = {}
+
+  /** explicitly invalidate producer from pool. only for testing. */
+  private[kafka010] def invalidateProducer(): Unit = {
+    checkForErrors()
+    if (producer != null) {
+      producer.flush()
+      checkForErrors()
+      CachedKafkaProducer.close(producerParams)
+    }
+  }
 }
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
index e2dcd6200531..ac242ba3d135 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
@@ -370,7 +370,7 @@ class KafkaContinuousSinkSuite extends KafkaSinkStreamingSuiteBase {
         iter.foreach(writeTask.write(_))
         writeTask.commit()
       } finally {
-        writeTask.close()
+        writeTask.invalidateProducer()
       }
     }
   }
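With this last patch, KafkaDataWriter.close() intentionally stays a no-op: the underlying producer comes from the CachedKafkaProducer pool and is shared across tasks, so it must not be flushed and evicted on every writer close. The old cleanup logic survives as the test-only invalidateProducer(), which KafkaSinkSuite now calls instead of close(). As a reading aid only, the test-side driving loop after this change amounts to the sketch below; the enclosing object and method name are made up, and since KafkaDataWriter and invalidateProducer() are private[kafka010], real code has to live in that package, as the suite does.

package org.apache.spark.sql.kafka010

import org.apache.spark.sql.catalyst.InternalRow

// Sketch mirroring the KafkaSinkSuite hunk above; object and method names are illustrative.
object KafkaWriteTaskSketch {
  def writePartition(writeTask: KafkaDataWriter, iter: Iterator[InternalRow]): Unit = {
    try {
      iter.foreach(writeTask.write(_))
      writeTask.commit()
    } finally {
      // close() is a no-op for the Kafka writer; the test flushes and evicts the
      // pooled producer explicitly instead.
      writeTask.invalidateProducer()
    }
  }
}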