diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index c6656341fcd1..95d70479ef01 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -27,7 +27,7 @@ import java.util.concurrent.ConcurrentHashMap
 import scala.collection.mutable.ListBuffer
 
 import com.google.common.io.{ByteStreams, Closeables, Files}
-import io.netty.channel.FileRegion
+import io.netty.channel.{DefaultFileRegion, FileRegion}
 import io.netty.util.AbstractReferenceCounted
 
 import org.apache.spark.{SecurityManager, SparkConf}
@@ -47,6 +47,8 @@ private[spark] class DiskStore(
     securityManager: SecurityManager) extends Logging {
 
   private val minMemoryMapBytes = conf.getSizeAsBytes("spark.storage.memoryMapThreshold", "2m")
+  private val maxMemoryMapBytes = conf.getSizeAsBytes("spark.storage.memoryMapLimitForTests",
+    Int.MaxValue.toString)
   private val blockSizes = new ConcurrentHashMap[String, Long]()
 
   def getSize(blockId: BlockId): Long = blockSizes.get(blockId.name)
@@ -108,25 +110,7 @@ private[spark] class DiskStore(
         new EncryptedBlockData(file, blockSize, conf, key)
 
       case _ =>
-        val channel = new FileInputStream(file).getChannel()
-        if (blockSize < minMemoryMapBytes) {
-          // For small files, directly read rather than memory map.
-          Utils.tryWithSafeFinally {
-            val buf = ByteBuffer.allocate(blockSize.toInt)
-            JavaUtils.readFully(channel, buf)
-            buf.flip()
-            new ByteBufferBlockData(new ChunkedByteBuffer(buf), true)
-          } {
-            channel.close()
-          }
-        } else {
-          Utils.tryWithSafeFinally {
-            new ByteBufferBlockData(
-              new ChunkedByteBuffer(channel.map(MapMode.READ_ONLY, 0, file.length)), true)
-          } {
-            channel.close()
-          }
-        }
+        new DiskBlockData(minMemoryMapBytes, maxMemoryMapBytes, file, blockSize)
     }
   }
 
@@ -165,6 +149,61 @@ private[spark] class DiskStore(
 
 }
 
+private class DiskBlockData(
+    minMemoryMapBytes: Long,
+    maxMemoryMapBytes: Long,
+    file: File,
+    blockSize: Long) extends BlockData {
+
+  override def toInputStream(): InputStream = new FileInputStream(file)
+
+  /**
+   * Returns a Netty-friendly wrapper for the block's data.
+   *
+   * Please see `ManagedBuffer.convertToNetty()` for more details.
+   */
+  override def toNetty(): AnyRef = new DefaultFileRegion(file, 0, size)
+
+  override def toChunkedByteBuffer(allocator: (Int) => ByteBuffer): ChunkedByteBuffer = {
+    Utils.tryWithResource(open()) { channel =>
+      var remaining = blockSize
+      val chunks = new ListBuffer[ByteBuffer]()
+      while (remaining > 0) {
+        val chunkSize = math.min(remaining, maxMemoryMapBytes)
+        val chunk = allocator(chunkSize.toInt)
+        remaining -= chunkSize
+        JavaUtils.readFully(channel, chunk)
+        chunk.flip()
+        chunks += chunk
+      }
+      new ChunkedByteBuffer(chunks.toArray)
+    }
+  }
+
+  override def toByteBuffer(): ByteBuffer = {
+    require(blockSize < maxMemoryMapBytes,
+      s"can't create a byte buffer of size $blockSize" +
+        s" since it exceeds ${Utils.bytesToString(maxMemoryMapBytes)}.")
+    Utils.tryWithResource(open()) { channel =>
+      if (blockSize < minMemoryMapBytes) {
+        // For small files, directly read rather than memory map.
+        val buf = ByteBuffer.allocate(blockSize.toInt)
+        JavaUtils.readFully(channel, buf)
+        buf.flip()
+        buf
+      } else {
+        channel.map(MapMode.READ_ONLY, 0, file.length)
+      }
+    }
+  }
+
+  override def size: Long = blockSize
+
+  override def dispose(): Unit = {}
+
+  private def open() = new FileInputStream(file).getChannel
+}
+
 private class EncryptedBlockData(
     file: File,
     blockSize: Long,
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala
index 67fc084e8a13..36977d8c554a 100644
--- a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala
@@ -50,18 +50,18 @@ class DiskStoreSuite extends SparkFunSuite {
     val diskStoreMapped = new DiskStore(conf.clone().set(confKey, "0"), diskBlockManager,
       securityManager)
     diskStoreMapped.putBytes(blockId, byteBuffer)
-    val mapped = diskStoreMapped.getBytes(blockId).asInstanceOf[ByteBufferBlockData].buffer
+    val mapped = diskStoreMapped.getBytes(blockId).toByteBuffer()
     assert(diskStoreMapped.remove(blockId))
 
     val diskStoreNotMapped = new DiskStore(conf.clone().set(confKey, "1m"), diskBlockManager,
       securityManager)
     diskStoreNotMapped.putBytes(blockId, byteBuffer)
-    val notMapped = diskStoreNotMapped.getBytes(blockId).asInstanceOf[ByteBufferBlockData].buffer
+    val notMapped = diskStoreNotMapped.getBytes(blockId).toByteBuffer()
 
     // Not possible to do isInstanceOf due to visibility of HeapByteBuffer
-    assert(notMapped.getChunks().forall(_.getClass.getName.endsWith("HeapByteBuffer")),
+    assert(notMapped.getClass.getName.endsWith("HeapByteBuffer"),
       "Expected HeapByteBuffer for un-mapped read")
-    assert(mapped.getChunks().forall(_.isInstanceOf[MappedByteBuffer]),
+    assert(mapped.isInstanceOf[MappedByteBuffer],
       "Expected MappedByteBuffer for mapped read")
 
     def arrayFromByteBuffer(in: ByteBuffer): Array[Byte] = {
@@ -70,8 +70,8 @@ class DiskStoreSuite extends SparkFunSuite {
       array
     }
 
-    assert(Arrays.equals(mapped.toArray, bytes))
-    assert(Arrays.equals(notMapped.toArray, bytes))
+    assert(Arrays.equals(new ChunkedByteBuffer(mapped).toArray, bytes))
+    assert(Arrays.equals(new ChunkedByteBuffer(notMapped).toArray, bytes))
   }
 
   test("block size tracking") {
@@ -92,6 +92,44 @@ class DiskStoreSuite extends SparkFunSuite {
     assert(diskStore.getSize(blockId) === 0L)
   }
 
+  test("blocks larger than 2gb") {
+    val conf = new SparkConf()
+      .set("spark.storage.memoryMapLimitForTests", "10k")
+    val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true)
+    val diskStore = new DiskStore(conf, diskBlockManager, new SecurityManager(conf))
+
+    val blockId = BlockId("rdd_1_2")
+    diskStore.put(blockId) { chan =>
+      val arr = new Array[Byte](1024)
+      for {
+        _ <- 0 until 20
+      } {
+        val buf = ByteBuffer.wrap(arr)
+        while (buf.hasRemaining()) {
+          chan.write(buf)
+        }
+      }
+    }
+
+    val blockData = diskStore.getBytes(blockId)
+    assert(blockData.size == 20 * 1024)
+
+    val chunkedByteBuffer = blockData.toChunkedByteBuffer(ByteBuffer.allocate)
+    val chunks = chunkedByteBuffer.chunks
+    assert(chunks.size === 2)
+    for (chunk <- chunks) {
+      assert(chunk.limit === 10 * 1024)
+    }
+
+    val e = intercept[IllegalArgumentException] {
+      blockData.toByteBuffer()
+    }
+
+    assert(e.getMessage ===
+      s"requirement failed: can't create a byte buffer of size ${blockData.size}" +
+      " since it exceeds 10.0 KB.")
+  }
+
   test("block data encryption") {
     val testDir = Utils.createTempDir()
     val testData = new Array[Byte](128 * 1024)
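For context, here is a standalone sketch (not part of the patch) of the chunked-read pattern that the new DiskBlockData.toChunkedByteBuffer follows: the file is read into buffers capped at a configurable size, so a block larger than 2 GB never has to fit into a single ByteBuffer. It uses only JDK and Scala standard library APIs; the object and parameter names (ChunkedReadSketch, readInChunks, maxChunkBytes) are illustrative and not part of Spark.

// Illustrative sketch only: mirrors the chunking idea behind DiskBlockData.toChunkedByteBuffer.
import java.io.{EOFException, FileInputStream}
import java.nio.ByteBuffer

import scala.collection.mutable.ListBuffer

object ChunkedReadSketch {
  // Reads the file at `path` into ByteBuffers of at most `maxChunkBytes` bytes each,
  // playing the role that maxMemoryMapBytes plays in the patch.
  def readInChunks(path: String, maxChunkBytes: Long): Array[ByteBuffer] = {
    val channel = new FileInputStream(path).getChannel
    try {
      var remaining = channel.size()
      val chunks = new ListBuffer[ByteBuffer]()
      while (remaining > 0) {
        val chunkSize = math.min(remaining, maxChunkBytes).toInt
        val chunk = ByteBuffer.allocate(chunkSize)
        // FileChannel.read can return short reads, so loop until this chunk is full.
        while (chunk.hasRemaining) {
          if (channel.read(chunk) < 0) {
            throw new EOFException("unexpected end of file")
          }
        }
        chunk.flip()
        chunks += chunk
        remaining -= chunkSize
      }
      chunks.toArray
    } finally {
      channel.close()
    }
  }
}

In the patch itself the chunk allocator is passed in by the caller (toChunkedByteBuffer(ByteBuffer.allocate) in the test), which is what lets the test cap chunks at 10k via spark.storage.memoryMapLimitForTests and assert on two 10 KB chunks.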