51 changes: 35 additions & 16 deletions core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -59,7 +59,7 @@ private[spark] class BlockManager(
executorId: String,
rpcEnv: RpcEnv,
val master: BlockManagerMaster,
defaultSerializer: Serializer,
val defaultSerializer: Serializer,
val conf: SparkConf,
memoryManager: MemoryManager,
mapOutputTracker: MapOutputTracker,
@@ -750,7 +750,7 @@ private[spark] class BlockManager(
// We will drop it to disk later if the memory store can't hold it.
val putSucceeded = if (level.deserialized) {
val values = dataDeserialize(blockId, bytes)
memoryStore.putIterator(blockId, values, level) match {
memoryStore.putIteratorAsValues(blockId, values) match {
case Right(_) => true
case Left(iter) =>
// If putting deserialized values in memory failed, we will put the bytes directly to
@@ -878,21 +878,40 @@ private[spark] class BlockManager(
if (level.useMemory) {
// Put it in memory first, even if it also has useDisk set to true;
// We will drop it to disk later if the memory store can't hold it.
memoryStore.putIterator(blockId, iterator(), level) match {
case Right(s) =>
size = s
case Left(iter) =>
// Not enough space to unroll this block; drop to disk if applicable
if (level.useDisk) {
logWarning(s"Persisting block $blockId to disk instead.")
diskStore.put(blockId) { fileOutputStream =>
dataSerializeStream(blockId, fileOutputStream, iter)
if (level.deserialized) {
memoryStore.putIteratorAsValues(blockId, iterator()) match {
case Right(s) =>
size = s
case Left(iter) =>
// Not enough space to unroll this block; drop to disk if applicable
if (level.useDisk) {
logWarning(s"Persisting block $blockId to disk instead.")
diskStore.put(blockId) { fileOutputStream =>
dataSerializeStream(blockId, fileOutputStream, iter)
}
size = diskStore.getSize(blockId)
} else {
iteratorFromFailedMemoryStorePut = Some(iter)
}
size = diskStore.getSize(blockId)
} else {
iteratorFromFailedMemoryStorePut = Some(iter)
}
}
} else { // !level.deserialized
memoryStore.putIteratorAsBytes(blockId, iterator()) match {
case Right(s) =>
size = s
case Left(partiallySerializedValues) =>
// Not enough space to unroll this block; drop to disk if applicable
if (level.useDisk) {
logWarning(s"Persisting block $blockId to disk instead.")
diskStore.put(blockId) { fileOutputStream =>
partiallySerializedValues.finishWriting(fileOutputStream)
}
size = diskStore.getSize(blockId)
} else {
iteratorFromFailedMemoryStorePut = Some(partiallySerializedValues.valuesIterator)
}
}
}

} else if (level.useDisk) {
diskStore.put(blockId) { fileOutputStream =>
dataSerializeStream(blockId, fileOutputStream, iterator())
@@ -992,7 +1011,7 @@ private[spark] class BlockManager(
// Note: if we had a means to discard the disk iterator, we would do that here.
memoryStore.getValues(blockId).get
} else {
memoryStore.putIterator(blockId, diskIterator, level) match {
memoryStore.putIteratorAsValues(blockId, diskIterator) match {
case Left(iter) =>
// The memory store put() failed, so it returned the iterator back to us:
iter
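Taken together, the BlockManager changes above split the old putIterator call by storage level: deserialized levels go through putIteratorAsValues, serialized levels through putIteratorAsBytes, and each failure handle knows how to spill itself to disk. Below is a condensed sketch of that control flow, not the PR's exact code; it assumes it runs inside BlockManager (where memoryStore, diskStore and dataSerializeStream are in scope) and omits the block-info, replication and iteratorFromFailedMemoryStorePut bookkeeping of the real method.

// Condensed sketch of the new doPutIterator branching (illustrative only).
def putIteratorSketch(
    blockId: BlockId,
    level: StorageLevel,
    iterator: () => Iterator[Any]): Option[Long] = {
  if (level.deserialized) {
    memoryStore.putIteratorAsValues(blockId, iterator()) match {
      case Right(size) => Some(size)                      // fully unrolled as Java objects
      case Left(partiallyUnrolled) if level.useDisk =>
        // Serialize the partially-unrolled values plus the remaining input to disk.
        diskStore.put(blockId) { os => dataSerializeStream(blockId, os, partiallyUnrolled) }
        Some(diskStore.getSize(blockId))
      case Left(partiallyUnrolled) =>
        // The real method hands partiallyUnrolled back to the caller instead of closing it.
        partiallyUnrolled.close()
        None
    }
  } else {
    memoryStore.putIteratorAsBytes(blockId, iterator()) match {
      case Right(size) => Some(size)                      // fully serialized into memory
      case Left(partiallySerialized) if level.useDisk =>
        // Copy the bytes serialized so far, then serialize the rest straight to disk.
        diskStore.put(blockId) { os => partiallySerialized.finishWriting(os) }
        Some(diskStore.getSize(blockId))
      case Left(partiallySerialized) =>
        // The real method exposes partiallySerialized.valuesIterator to the caller instead.
        partiallySerialized.discard()
        None
    }
  }
}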
201 changes: 175 additions & 26 deletions core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
@@ -17,31 +17,37 @@

package org.apache.spark.storage.memory

import java.io.OutputStream
import java.nio.ByteBuffer
import java.util.LinkedHashMap

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import com.google.common.io.ByteStreams

import org.apache.spark.{SparkConf, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.memory.MemoryManager
import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel}
import org.apache.spark.serializer.SerializationStream
import org.apache.spark.storage.{BlockId, BlockManager}
import org.apache.spark.util.{CompletionIterator, SizeEstimator, Utils}
import org.apache.spark.util.collection.SizeTrackingVector
import org.apache.spark.util.io.ChunkedByteBuffer
import org.apache.spark.util.io.{ByteArrayChunkOutputStream, ChunkedByteBuffer}

private sealed trait MemoryEntry {
val size: Long
def size: Long
}
private case class DeserializedMemoryEntry(value: Array[Any], size: Long) extends MemoryEntry
private case class SerializedMemoryEntry(buffer: ChunkedByteBuffer, size: Long) extends MemoryEntry
private case class SerializedMemoryEntry(buffer: ChunkedByteBuffer) extends MemoryEntry {
def size: Long = buffer.size
}

/**
* Stores blocks in memory, either as Arrays of deserialized Java objects or as
* serialized ByteBuffers.
*/
private[spark] class MemoryStore(
conf: SparkConf,
private[spark] class MemoryStore(conf: SparkConf,
blockManager: BlockManager,
memoryManager: MemoryManager)
extends Logging {
@@ -101,7 +107,7 @@ private[spark] class MemoryStore(
// We acquired enough memory for the block, so go ahead and put it
val bytes = _bytes()
assert(bytes.size == size)
val entry = new SerializedMemoryEntry(bytes, size)
val entry = new SerializedMemoryEntry(bytes)
entries.synchronized {
entries.put(blockId, entry)
}
@@ -114,7 +120,7 @@
}

/**
* Attempt to put the given block in memory store.
* Attempt to put the given block in memory store as values.
*
* It's possible that the iterator is too large to materialize and store in memory. To avoid
* OOM exceptions, this method will gradually unroll the iterator while periodically checking
@@ -129,10 +135,9 @@
* iterator or call `close()` on it in order to free the storage memory consumed by the
* partially-unrolled block.
*/
private[storage] def putIterator(
private[storage] def putIteratorAsValues(
blockId: BlockId,
values: Iterator[Any],
level: StorageLevel): Either[PartiallyUnrolledIterator, Long] = {
values: Iterator[Any]): Either[PartiallyUnrolledIterator, Long] = {

require(!contains(blockId), s"Block $blockId is already present in the MemoryStore")

@@ -186,12 +191,7 @@ private[spark] class MemoryStore(
// We successfully unrolled the entirety of this block
val arrayValues = vector.toArray
vector = null
val entry = if (level.deserialized) {
new DeserializedMemoryEntry(arrayValues, SizeEstimator.estimate(arrayValues))
} else {
val bytes = blockManager.dataSerialize(blockId, arrayValues.iterator)
new SerializedMemoryEntry(bytes, bytes.size)
}
val entry = new DeserializedMemoryEntry(arrayValues, SizeEstimator.estimate(arrayValues))
val size = entry.size
def transferUnrollToStorage(amount: Long): Unit = {
// Synchronize so that transfer is atomic
@@ -223,9 +223,8 @@ private[spark] class MemoryStore(
entries.synchronized {
entries.put(blockId, entry)
}
val bytesOrValues = if (level.deserialized) "values" else "bytes"
logInfo("Block %s stored as %s in memory (estimated size %s, free %s)".format(
blockId, bytesOrValues, Utils.bytesToString(size), Utils.bytesToString(blocksMemoryUsed)))
logInfo("Block %s stored as values in memory (estimated size %s, free %s)".format(
blockId, Utils.bytesToString(size), Utils.bytesToString(blocksMemoryUsed)))
Right(size)
} else {
assert(currentUnrollMemoryForThisTask >= currentUnrollMemoryForThisTask,
@@ -244,13 +243,113 @@
}
}

/**
* Attempt to put the given block in memory store as bytes.
*
* It's possible that the iterator is too large to materialize and store in memory. To avoid
* OOM exceptions, this method will gradually unroll the iterator while periodically checking
* whether there is enough free memory. If the block is successfully materialized, then the
* temporary unroll memory used during the materialization is "transferred" to storage memory,
* so we won't acquire more memory than is actually needed to store the block.
*
* @return in case of success, the estimated size of the stored data. In case of
* failure, return a handle which allows the caller to either finish the serialization
* by spilling to disk or to deserialize the partially-serialized block and reconstruct
* the original input iterator. The caller must either fully consume this result
* iterator or call `discard()` on it in order to free the storage memory consumed by the
* partially-unrolled block.
*/
private[storage] def putIteratorAsBytes(
blockId: BlockId,
values: Iterator[Any]): Either[PartiallySerializedBlock, Long] = {

require(!contains(blockId), s"Block $blockId is already present in the MemoryStore")

// Whether there is still enough memory for us to continue unrolling this block
var keepUnrolling = true
// Initial per-task memory to request for unrolling blocks (bytes).
val initialMemoryThreshold = unrollMemoryThreshold
// Keep track of unroll memory used by this particular block / putIterator() operation
var unrollMemoryUsedByThisBlock = 0L
// Underlying buffer for unrolling the block
val redirectableStream = new RedirectableOutputStream
val byteArrayChunkOutputStream = new ByteArrayChunkOutputStream(initialMemoryThreshold.toInt)
redirectableStream.setOutputStream(byteArrayChunkOutputStream)
val serializationStream: SerializationStream = {
val ser = blockManager.defaultSerializer.newInstance()
ser.serializeStream(blockManager.wrapForCompression(blockId, redirectableStream))
}

// Request enough memory to begin unrolling
keepUnrolling = reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold)

if (!keepUnrolling) {
logWarning(s"Failed to reserve initial memory threshold of " +
s"${Utils.bytesToString(initialMemoryThreshold)} for computing block $blockId in memory.")
} else {
unrollMemoryUsedByThisBlock += initialMemoryThreshold
}

def reserveAdditionalMemoryIfNecessary(): Unit = {
if (byteArrayChunkOutputStream.size > unrollMemoryUsedByThisBlock) {
val amountToRequest = byteArrayChunkOutputStream.size - unrollMemoryUsedByThisBlock
keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest)
if (keepUnrolling) {
unrollMemoryUsedByThisBlock += amountToRequest
}
}
}

// Unroll this block safely, checking whether we have exceeded our threshold
while (values.hasNext && keepUnrolling) {
serializationStream.writeObject(values.next())
reserveAdditionalMemoryIfNecessary()
}

if (keepUnrolling) {
serializationStream.close()
reserveAdditionalMemoryIfNecessary()
}

if (keepUnrolling) {
val entry = SerializedMemoryEntry(
new ChunkedByteBuffer(byteArrayChunkOutputStream.toArrays.map(ByteBuffer.wrap)))
// Synchronize so that transfer is atomic
memoryManager.synchronized {
releaseUnrollMemoryForThisTask(unrollMemoryUsedByThisBlock)
val success = memoryManager.acquireStorageMemory(blockId, entry.size)
assert(success, "transferring unroll memory to storage memory failed")
}
entries.synchronized {
entries.put(blockId, entry)
}
logInfo("Block %s stored as bytes in memory (estimated size %s, free %s)".format(
blockId, Utils.bytesToString(entry.size), Utils.bytesToString(blocksMemoryUsed)))
Right(entry.size)
} else {
// We ran out of space while unrolling the values for this block
logUnrollFailureMessage(blockId, byteArrayChunkOutputStream.size)
Left(
new PartiallySerializedBlock(
this,
blockManager,
blockId,
serializationStream,
redirectableStream,
unrollMemoryUsedByThisBlock,
new ChunkedByteBuffer(byteArrayChunkOutputStream.toArrays.map(ByteBuffer.wrap)),
values))
}
}

def getBytes(blockId: BlockId): Option[ChunkedByteBuffer] = {
val entry = entries.synchronized { entries.get(blockId) }
entry match {
case null => None
case e: DeserializedMemoryEntry =>
throw new IllegalArgumentException("should only call getBytes on serialized blocks")
case SerializedMemoryEntry(bytes, _) => Some(bytes)
case SerializedMemoryEntry(bytes) => Some(bytes)
}
}

@@ -343,7 +442,7 @@ private[spark] class MemoryStore(
if (entry != null) {
val data = entry match {
case DeserializedMemoryEntry(values, _) => Left(values)
case SerializedMemoryEntry(buffer, _) => Right(buffer)
case SerializedMemoryEntry(buffer) => Right(buffer)
}
val newEffectiveStorageLevel = blockManager.dropFromMemory(blockId, () => data)
if (newEffectiveStorageLevel.isValid) {
@@ -463,12 +562,13 @@
}

/**
* The result of a failed [[MemoryStore.putIterator()]] call.
* The result of a failed [[MemoryStore.putIteratorAsValues()]] call.
*
* @param memoryStore the memoryStore, used for freeing memory.
* @param memoryStore the memoryStore, used for freeing memory.
* @param unrollMemory the amount of unroll memory used by the values in `unrolled`.
* @param unrolled an iterator for the partially-unrolled values.
* @param rest the rest of the original iterator passed to [[MemoryStore.putIterator()]].
* @param unrolled an iterator for the partially-unrolled values.
* @param rest the rest of the original iterator passed to
* [[MemoryStore.putIteratorAsValues()]].
*/
private[storage] class PartiallyUnrolledIterator(
memoryStore: MemoryStore,
@@ -500,3 +600,52 @@ private[storage] class PartiallyUnrolledIterator(
iter = null
}
}

private class RedirectableOutputStream extends OutputStream {
private[this] var os: OutputStream = _
def setOutputStream(s: OutputStream): Unit = { os = s }
override def write(b: Int): Unit = os.write(b)
override def write(b: Array[Byte]): Unit = os.write(b)
override def write(b: Array[Byte], off: Int, len: Int): Unit = os.write(b, off, len)
override def flush(): Unit = os.flush()
override def close(): Unit = os.close()
}

/**
* The result of a failed [[MemoryStore.putIteratorAsBytes()]] call.
*/
private[storage] class PartiallySerializedBlock(
memoryStore: MemoryStore,
blockManager: BlockManager,
blockId: BlockId,
serializationStream: SerializationStream,
redirectableOutputStream: RedirectableOutputStream,
unrollMemory: Long,
unrolled: ChunkedByteBuffer,
iter: Iterator[Any]) {

def discard(): Unit = {
try {
serializationStream.close()
} finally {
memoryStore.releaseUnrollMemoryForThisTask(unrollMemory)
}
}

def finishWriting(os: OutputStream): Unit = {
ByteStreams.copy(unrolled.toInputStream(), os)
redirectableOutputStream.setOutputStream(os)
while (iter.hasNext) {
serializationStream.writeObject(iter.next())
}
serializationStream.close()
}

def valuesIterator: PartiallyUnrolledIterator = {
new PartiallyUnrolledIterator(
memoryStore,
unrollMemory,
unrolled = blockManager.dataDeserialize(blockId, unrolled),
rest = iter)
}
}
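The key trick in PartiallySerializedBlock.finishWriting() is the RedirectableOutputStream: the bytes serialized so far sit in the in-memory chunks, and on a spill they are copied to the destination stream, after which the still-open SerializationStream is redirected so the remaining records are serialized straight to that destination without re-serializing anything already written. Below is a standalone illustration of that redirect-then-continue pattern; it is not the PR's code, RedirectableOutputStream is private to MemoryStore.scala, and the stream names and file path are illustrative only.

// Illustration of the redirect-then-continue pattern used by finishWriting().
import java.io.{ByteArrayOutputStream, FileOutputStream}

val redirectable = new RedirectableOutputStream
val inMemory = new ByteArrayOutputStream()
redirectable.setOutputStream(inMemory)

redirectable.write("already serialized".getBytes("UTF-8"))   // buffered in memory

val spillTarget = new FileOutputStream("block.tmp")          // hypothetical spill destination
spillTarget.write(inMemory.toByteArray)                      // copy what was written so far
redirectable.setOutputStream(spillTarget)                    // future writes now go to disk
redirectable.write(" and the rest".getBytes("UTF-8"))        // written straight to the file
redirectable.close()                                         // closes the underlying file stream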
core/src/main/scala/org/apache/spark/util/io/ByteBufferOutputStream.scala
@@ -27,6 +27,8 @@ private[spark] class ByteBufferOutputStream(capacity: Int) extends ByteArrayOutputStream

def this() = this(32)

def getCount(): Int = count

def toByteBuffer: ByteBuffer = {
return ByteBuffer.wrap(buf, 0, count)
}
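The new getCount() accessor simply exposes ByteArrayOutputStream's protected count field, i.e. the number of bytes written so far, without materializing a buffer. A minimal sketch of how a caller could use it (the values written are illustrative):

// Sketch: tracking bytes written without copying the underlying buffer.
val bos = new ByteBufferOutputStream()        // default 32-byte initial capacity
bos.write(Array[Byte](1, 2, 3, 4))
val written: Int = bos.getCount()             // 4 bytes so far
val buf = bos.toByteBuffer                    // wraps the same backing array (no copy)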