diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index e125095cf477..84ef0fcc469a 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -238,6 +238,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) // need to re-fetch it. val storageLevel = StorageLevel.MEMORY_AND_DISK if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) { + Utils.tryClose(obj) throw new SparkException(s"Failed to store $broadcastId in BlockManager") } diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index 4cc5bcb7f9ba..aa6413707daf 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -21,6 +21,7 @@ import java.io.OutputStream import java.nio.ByteBuffer import java.util.LinkedHashMap +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag @@ -384,15 +385,30 @@ private[spark] class MemoryStore( } } + private def maybeReleaseResources(resource: (BlockId, MemoryEntry[_])): Unit = { + maybeReleaseResources(resource._1, resource._2) + } + + private def maybeReleaseResources(blockId: BlockId, entry: MemoryEntry[_]): Unit = { + entry match { + case SerializedMemoryEntry(buffer, _, _) => buffer.dispose() + case DeserializedMemoryEntry(values: Array[Any], _, _) => maybeCloseValues(values, blockId) + case _ => + } + } + + private def maybeCloseValues(values: Array[Any], blockId: BlockId): Unit = { + if (blockId.isBroadcast) { + values.foreach(value => Utils.tryClose(value)) + } + } + def remove(blockId: BlockId): Boolean = memoryManager.synchronized { val entry = entries.synchronized { entries.remove(blockId) } if (entry != null) { - entry match { - case SerializedMemoryEntry(buffer, _, _) => buffer.dispose() - case _ => - } + maybeReleaseResources(blockId, entry) memoryManager.releaseStorageMemory(entry.size, entry.memoryMode) logDebug(s"Block $blockId of size ${entry.size} dropped " + s"from memory (free ${maxMemory - blocksMemoryUsed})") @@ -404,6 +420,7 @@ private[spark] class MemoryStore( def clear(): Unit = memoryManager.synchronized { entries.synchronized { + entries.asScala.foreach(maybeReleaseResources) entries.clear() } onHeapUnrollMemoryMap.clear() diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 13adaa921dc2..3f9d2f98d875 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1930,6 +1930,18 @@ private[spark] object Utils extends Logging { } } + def tryClose(value: Any): Unit = { + value match { + case closable: AutoCloseable => + try { + closable.close() + } catch { + case ex: Exception => logError(s"Failed to close AutoClosable $closable", ex) + } + case _ => + } + } + /** Returns true if the given exception was fatal. See docs for scala.util.control.NonFatal. */ def isFatalError(e: Throwable): Boolean = { e match { diff --git a/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala index 7274072e5049..8a7aa24fa184 100644 --- a/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala @@ -526,4 +526,105 @@ class MemoryStoreSuite } } } + + test("[SPARK-24225]: remove should close AutoCloseable object") { + + val (store, _) = makeMemoryStore(12000) + + val id = BroadcastBlockId(0) + val tracker = new CloseTracker() + store.putIteratorAsValues(id, Iterator(tracker), ClassTag.Any) + assert(store.contains(id)) + store.remove(id) + assert(tracker.getClosed) + } + + test("[SPARK-24225]: remove should not close object if it's not a broadcast variable") { + + val (store, _) = makeMemoryStore(12000) + + val id = "a1" // This will be a test variable + val tracker = new CloseTracker() + store.putIteratorAsValues(id, Iterator(tracker), ClassTag.Any) + assert(store.contains(id)) + store.remove(id) + assert(!tracker.getClosed) + } + + test("[SPARK-24225]: remove should close AutoCloseable objects even if they throw exceptions") { + + val (store, _) = makeMemoryStore(12000) + + val id = BroadcastBlockId(0) + val tracker = new CloseTracker(true) + store.putIteratorAsValues(id, Iterator(tracker), ClassTag.Any) + assert(store.contains(id)) + store.remove(id) + assert(tracker.getClosed) + } + + test("[SPARK-24225]: clear should close AutoCloseable objects") { + + val (store, _) = makeMemoryStore(12000) + + val id = BroadcastBlockId(0) + val tracker = new CloseTracker + store.putIteratorAsValues(id, Iterator(tracker), ClassTag.Any) + assert(store.contains(id)) + store.clear() + assert(tracker.getClosed) + } + + test("[SPARK-24225]: clear should close all AutoCloseable objects put together in an iterator") { + + val (store, _) = makeMemoryStore(12000) + + val id1 = BroadcastBlockId(1) + val tracker2 = new CloseTracker + val tracker1 = new CloseTracker + store.putIteratorAsValues(id1, Iterator(tracker1, tracker2), ClassTag.Any) + assert(store.contains(id1)) + store.clear() + assert(tracker1.getClosed) + assert(tracker2.getClosed) + } + + test("[SPARK-24225]: clear should close AutoCloseable objects even if they throw exceptions") { + + val (store, _) = makeMemoryStore(12000) + + val id1 = BroadcastBlockId(1) + val id2 = BroadcastBlockId(2) + val tracker2 = new CloseTracker(true) + val tracker1 = new CloseTracker(true) + store.putIteratorAsValues(id1, Iterator(tracker1), ClassTag.Any) + store.putIteratorAsValues(id2, Iterator(tracker2), ClassTag.Any) + assert(store.contains(id1)) + assert(store.contains(id2)) + store.clear() + assert(tracker1.getClosed) + assert(tracker1.getExceptionThrown) + assert(tracker2.getClosed) + assert(tracker2.getExceptionThrown) + } +} + +private class CloseTracker (val throwsOnClosed: Boolean = false) extends AutoCloseable { + private var closed = false + private var exceptionThrown = false + + override def close(): Unit = { + closed = true + if (throwsOnClosed) { + exceptionThrown = true + throw new RuntimeException("Throwing") + } + } + def getClosed: Boolean = { + closed + } + + def getExceptionThrown: Boolean = { + exceptionThrown + } }