Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
// need to re-fetch it.
val storageLevel = StorageLevel.MEMORY_AND_DISK
if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
Utils.tryClose(obj)
throw new SparkException(s"Failed to store $broadcastId in BlockManager")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import java.io.OutputStream
import java.nio.ByteBuffer
import java.util.LinkedHashMap

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
Expand Down Expand Up @@ -384,15 +385,30 @@ private[spark] class MemoryStore(
}
}

private def maybeReleaseResources(resource: (BlockId, MemoryEntry[_])): Unit = {
maybeReleaseResources(resource._1, resource._2)
}

private def maybeReleaseResources(blockId: BlockId, entry: MemoryEntry[_]): Unit = {
entry match {
case SerializedMemoryEntry(buffer, _, _) => buffer.dispose()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not just make these case classes Closeable and then you can close them consistently

case DeserializedMemoryEntry(values: Array[Any], _, _) => maybeCloseValues(values, blockId)
case _ =>
}
}

private def maybeCloseValues(values: Array[Any], blockId: BlockId): Unit = {
if (blockId.isBroadcast) {
values.foreach(value => Utils.tryClose(value))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a style thing, but could be values.foreach(Utils.tryClose)

}
}

def remove(blockId: BlockId): Boolean = memoryManager.synchronized {
val entry = entries.synchronized {
entries.remove(blockId)
}
if (entry != null) {
entry match {
case SerializedMemoryEntry(buffer, _, _) => buffer.dispose()
case _ =>
}
maybeReleaseResources(blockId, entry)
memoryManager.releaseStorageMemory(entry.size, entry.memoryMode)
logDebug(s"Block $blockId of size ${entry.size} dropped " +
s"from memory (free ${maxMemory - blocksMemoryUsed})")
Expand All @@ -404,6 +420,7 @@ private[spark] class MemoryStore(

def clear(): Unit = memoryManager.synchronized {
entries.synchronized {
entries.asScala.foreach(maybeReleaseResources)
entries.clear()
}
onHeapUnrollMemoryMap.clear()
Expand Down
12 changes: 12 additions & 0 deletions core/src/main/scala/org/apache/spark/util/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1930,6 +1930,18 @@ private[spark] object Utils extends Logging {
}
}

def tryClose(value: Any): Unit = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should accept at best AnyRef. It doesn't really seem like we need a new global utility method for this. It's a little unusual to try closing things that aren't Closeable and we can try to rationalize that in the callers above if possible.

value match {
case closable: AutoCloseable =>
try {
closable.close()
} catch {
case ex: Exception => logError(s"Failed to close AutoClosable $closable", ex)
}
case _ =>
}
}

/** Returns true if the given exception was fatal. See docs for scala.util.control.NonFatal. */
def isFatalError(e: Throwable): Boolean = {
e match {
Expand Down
101 changes: 101 additions & 0 deletions core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -526,4 +526,105 @@ class MemoryStoreSuite
}
}
}

test("[SPARK-24225]: remove should close AutoCloseable object") {

val (store, _) = makeMemoryStore(12000)

val id = BroadcastBlockId(0)
val tracker = new CloseTracker()
store.putIteratorAsValues(id, Iterator(tracker), ClassTag.Any)
assert(store.contains(id))
store.remove(id)
assert(tracker.getClosed)
}

test("[SPARK-24225]: remove should not close object if it's not a broadcast variable") {

val (store, _) = makeMemoryStore(12000)

val id = "a1" // This will be a test variable
val tracker = new CloseTracker()
store.putIteratorAsValues(id, Iterator(tracker), ClassTag.Any)
assert(store.contains(id))
store.remove(id)
assert(!tracker.getClosed)
}

test("[SPARK-24225]: remove should close AutoCloseable objects even if they throw exceptions") {

val (store, _) = makeMemoryStore(12000)

val id = BroadcastBlockId(0)
val tracker = new CloseTracker(true)
store.putIteratorAsValues(id, Iterator(tracker), ClassTag.Any)
assert(store.contains(id))
store.remove(id)
assert(tracker.getClosed)
}

test("[SPARK-24225]: clear should close AutoCloseable objects") {

val (store, _) = makeMemoryStore(12000)

val id = BroadcastBlockId(0)
val tracker = new CloseTracker
store.putIteratorAsValues(id, Iterator(tracker), ClassTag.Any)
assert(store.contains(id))
store.clear()
assert(tracker.getClosed)
}

test("[SPARK-24225]: clear should close all AutoCloseable objects put together in an iterator") {

val (store, _) = makeMemoryStore(12000)

val id1 = BroadcastBlockId(1)
val tracker2 = new CloseTracker
val tracker1 = new CloseTracker
store.putIteratorAsValues(id1, Iterator(tracker1, tracker2), ClassTag.Any)
assert(store.contains(id1))
store.clear()
assert(tracker1.getClosed)
assert(tracker2.getClosed)
}

test("[SPARK-24225]: clear should close AutoCloseable objects even if they throw exceptions") {

val (store, _) = makeMemoryStore(12000)

val id1 = BroadcastBlockId(1)
val id2 = BroadcastBlockId(2)
val tracker2 = new CloseTracker(true)
val tracker1 = new CloseTracker(true)
store.putIteratorAsValues(id1, Iterator(tracker1), ClassTag.Any)
store.putIteratorAsValues(id2, Iterator(tracker2), ClassTag.Any)
assert(store.contains(id1))
assert(store.contains(id2))
store.clear()
assert(tracker1.getClosed)
assert(tracker1.getExceptionThrown)
assert(tracker2.getClosed)
assert(tracker2.getExceptionThrown)
}
}

private class CloseTracker (val throwsOnClosed: Boolean = false) extends AutoCloseable {
private var closed = false
private var exceptionThrown = false

override def close(): Unit = {
closed = true
if (throwsOnClosed) {
exceptionThrown = true
throw new RuntimeException("Throwing")
}
}
def getClosed: Boolean = {
closed
}

def getExceptionThrown: Boolean = {
exceptionThrown
}
}