diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataWriter.scala
index 3f8d3d2da579..9a2b36993361 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataWriter.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataWriter.scala
@@ -63,7 +63,10 @@ private[kafka010] class KafkaDataWriter(
 
   def abort(): Unit = {}
 
-  def close(): Unit = {
+  def close(): Unit = {}
+
+  /** explicitly invalidate producer from pool. only for testing. */
+  private[kafka010] def invalidateProducer(): Unit = {
     checkForErrors()
     if (producer != null) {
       producer.flush()
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
index e2dcd6200531..ac242ba3d135 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
@@ -370,7 +370,7 @@ class KafkaContinuousSinkSuite extends KafkaSinkStreamingSuiteBase {
         iter.foreach(writeTask.write(_))
         writeTask.commit()
       } finally {
-        writeTask.close()
+        writeTask.invalidateProducer()
       }
     }
   }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DataWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DataWriter.java
index eefe784dede4..59c69a18292d 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DataWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DataWriter.java
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.connector.write;
 
+import java.io.Closeable;
 import java.io.IOException;
 
 import org.apache.spark.annotation.Evolving;
@@ -31,8 +32,9 @@
  * the {@link #write(Object)}, {@link #abort()} is called afterwards and the remaining records will
  * not be processed. If all records are successfully written, {@link #commit()} is called.
  *
- * Once a data writer returns successfully from {@link #commit()} or {@link #abort()}, its lifecycle
- * is over and Spark will not use it again.
+ * Once a data writer returns successfully from {@link #commit()} or {@link #abort()}, Spark will
+ * call {@link #close()} to let DataWriter doing resource cleanup. After calling {@link #close()},
+ * its lifecycle is over and Spark will not use it again.
 *
  * If this data writer succeeds(all records are successfully written and {@link #commit()}
  * succeeds), a {@link WriterCommitMessage} will be sent to the driver side and pass to
@@ -56,7 +58,7 @@
 * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.catalyst.InternalRow}.
 */
 @Evolving
-public interface DataWriter<T> {
+public interface DataWriter<T> extends Closeable {
 
   /**
    * Writes one record.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
index 201860e5135b..09c4b9ab9d53 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
@@ -252,4 +252,6 @@ private class BufferWriter extends DataWriter[InternalRow] {
   override def commit(): WriterCommitMessage = buffer
 
   override def abort(): Unit = {}
+
+  override def close(): Unit = {}
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
index c1ebc98fb1dd..50c4f6cd57a9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
@@ -86,6 +86,8 @@ abstract class FileFormatDataWriter(
       committer.abortTask(taskAttemptContext)
     }
   }
+
+  override def close(): Unit = {}
 }
 
 /** FileFormatWriteTask for empty partitions */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index f02d9e92acb8..219c778b9164 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -277,6 +277,8 @@ object FileFormatWriter extends Logging {
         // If there is an error, abort the task
         dataWriter.abort()
         logError(s"Job $jobId aborted.")
+      }, finallyBlock = {
+        dataWriter.close()
       })
     } catch {
       case e: FetchFailedException =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala
index 03e5f43a2a0a..dd44651050e6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala
@@ -72,6 +72,7 @@ private[noop] object NoopWriter extends DataWriter[InternalRow] {
   override def write(record: InternalRow): Unit = {}
   override def commit(): WriterCommitMessage = null
   override def abort(): Unit = {}
+  override def close(): Unit = {}
 }
 
 private[noop] object NoopStreamingWrite extends StreamingWrite {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
index 7d8a115c126e..f4c70f7593b5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
@@ -467,6 +467,8 @@ object DataWritingSparkTask extends Logging {
       dataWriter.abort()
       logError(s"Aborted commit for partition $partId (task $taskId, attempt $attemptId, " +
         s"stage $stageId.$stageAttempt)")
+    }, finallyBlock = {
+      dataWriter.close()
     })
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala
index 558b4313d6d8..909dda57ee58 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala
@@ -80,6 +80,8 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writerFactory: StreamingDat
           logError(s"Writer for partition ${context.partitionId()} is aborting.")
           if (dataWriter != null) dataWriter.abort()
           logError(s"Writer for partition ${context.partitionId()} aborted.")
+        }, finallyBlock = {
+          dataWriter.close()
         })
       }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala
index 53d4bca1a5f7..4793cb9a9b79 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala
@@ -135,7 +135,7 @@ class ForeachDataWriter[T](
 
   // If open returns false, we should skip writing rows.
   private val opened = writer.open(partitionId, epochId)
-  private var closeCalled: Boolean = false
+  private var errorOrNull: Throwable = _
 
   override def write(record: InternalRow): Unit = {
     if (!opened) return
@@ -144,25 +144,24 @@ class ForeachDataWriter[T](
       writer.process(rowConverter(record))
     } catch {
       case t: Throwable =>
-        closeWriter(t)
+        errorOrNull = t
         throw t
     }
+
   }
 
   override def commit(): WriterCommitMessage = {
-    closeWriter(null)
     ForeachWriterCommitMessage
   }
 
   override def abort(): Unit = {
-    closeWriter(new SparkException("Foreach writer has been aborted due to a task failure"))
+    if (errorOrNull == null) {
+      errorOrNull = new SparkException("Foreach writer has been aborted due to a task failure")
+    }
   }
 
-  private def closeWriter(errorOrNull: Throwable): Unit = {
-    if (!closeCalled) {
-      closeCalled = true
-      writer.close(errorOrNull)
-    }
+  override def close(): Unit = {
+    writer.close(errorOrNull)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala
index 53f56edc2768..507f860e0452 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala
@@ -56,10 +56,12 @@ class PackedRowDataWriter() extends DataWriter[InternalRow] with Logging {
   override def write(row: InternalRow): Unit = data.append(row.copy())
 
   override def commit(): PackedRowCommitMessage = {
-    val msg = PackedRowCommitMessage(data.toArray)
-    data.clear()
-    msg
+    PackedRowCommitMessage(data.toArray)
   }
 
-  override def abort(): Unit = data.clear()
+  override def abort(): Unit = {}
+
+  override def close(): Unit = {
+    data.clear()
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala
index a976876b4d8e..0cc067fc7675 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala
@@ -191,6 +191,8 @@ class MemoryDataWriter(partition: Int, schema: StructType)
   }
 
   override def abort(): Unit = {}
+
+  override def close(): Unit = {}
 }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala
index 306da996e2ca..a0f1a9f9f53f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala
@@ -240,4 +240,6 @@ class CSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[InternalRow]
       fs.delete(file, false)
     }
   }
+
+  override def close(): Unit = {}
 }
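
The core of the patch is the contract change in DataWriter.java: once commit() or abort() returns, Spark now always calls close(), so connector authors can keep resource cleanup out of the commit/abort paths. A minimal sketch of a writer under the new contract follows; the class, the local-file output, and the commit message type are hypothetical and only illustrate where cleanup now belongs.

import java.io.{BufferedWriter, FileWriter}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}

// Hypothetical commit message carrying only a row count.
case class RowCountMessage(rows: Long) extends WriterCommitMessage

// Hypothetical writer that appends rows to a local file.
class LocalFileDataWriter(path: String) extends DataWriter[InternalRow] {
  private val out = new BufferedWriter(new FileWriter(path))
  private var rows = 0L

  override def write(record: InternalRow): Unit = {
    out.write(record.toString)  // illustrative only; a real sink would serialize properly
    out.newLine()
    rows += 1
  }

  // Make the attempt's output visible; no resource cleanup here anymore.
  override def commit(): WriterCommitMessage = RowCountMessage(rows)

  // Best-effort rollback of the failed attempt; again, no cleanup here.
  override def abort(): Unit = {}

  // New in this patch: invoked exactly once after commit() or abort(),
  // so releasing the underlying resource belongs here.
  override def close(): Unit = out.close()
}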
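
On the caller side, the FileFormatWriter, WriteToDataSourceV2Exec, and ContinuousWriteRDD hunks each add a finallyBlock that invokes dataWriter.close(). Spelled out without the helper those call sites use (Utils.tryWithSafeFinallyAndFailureCallbacks, judging by the finallyBlock parameter), the intended control flow is roughly the sketch below; the object and method names are made up for illustration.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}

object WriteTaskSketch {
  // Rough equivalent of the executor-side write path after this patch:
  // commit on success, abort on failure, close() unconditionally.
  def runTask(
      writer: DataWriter[InternalRow],
      rows: Iterator[InternalRow]): WriterCommitMessage = {
    try {
      rows.foreach(writer.write)
      writer.commit()
    } catch {
      case t: Throwable =>
        writer.abort() // undo whatever the failed attempt wrote
        throw t
    } finally {
      writer.close() // always release resources, success or failure
    }
  }
}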
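
ForeachDataWriter is the main behavioral beneficiary: instead of tracking closeCalled and invoking the user's ForeachWriter.close() from commit() and abort(), it now only records the failure cause and defers the callback to DataWriter.close(), which the task runs in its finally block. A hypothetical user-side writer, shown only to illustrate what reaches close(errorOrNull) under the new flow:

import org.apache.spark.sql.ForeachWriter

// Hypothetical ForeachWriter: errorOrNull is null after a committed task and
// carries the failure (or the task-abort SparkException) otherwise.
class AuditingForeachWriter extends ForeachWriter[String] {
  override def open(partitionId: Long, epochId: Long): Boolean = true

  override def process(value: String): Unit = println(value)

  override def close(errorOrNull: Throwable): Unit = {
    if (errorOrNull != null) {
      System.err.println(s"partition failed: ${errorOrNull.getMessage}")
    }
  }
}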