diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java
index 66566b67870f..dd7c2061ec95 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java
@@ -28,6 +28,7 @@
 
 import com.google.common.io.ByteStreams;
 import io.netty.channel.DefaultFileRegion;
+import io.netty.handler.stream.ChunkedStream;
 import org.apache.commons.lang3.builder.ToStringBuilder;
 import org.apache.commons.lang3.builder.ToStringStyle;
@@ -137,6 +138,12 @@ public Object convertToNetty() throws IOException {
     }
   }
 
+  @Override
+  public Object convertToNettyForSsl() throws IOException {
+    // Cannot use zero-copy with SSL/TLS
+    return new ChunkedStream(createInputStream(), conf.sslShuffleChunkSize());
+  }
+
   public File getFile() { return file; }
 
   public long getOffset() { return offset; }
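The method added above is the core of the change for file-backed buffers: instead of handing Netty a `DefaultFileRegion` (kernel zero-copy, which would bypass the pipeline's `SslHandler` entirely), it wraps the file's `InputStream` in a `ChunkedStream` so every byte passes through the pipeline and gets encrypted. A minimal sketch of how such a stream is consumed chunk by chunk; the 200 KiB payload and 64 KiB chunk size are illustrative, and Netty's `ChunkedWriteHandler` drives essentially this loop internally:

```java
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.handler.stream.ChunkedStream;

import java.io.ByteArrayInputStream;

// Standalone sketch: pull bounded chunks out of a ChunkedStream the way a
// chunked writer would. Sizes here are illustrative, not taken from Spark.
public class ChunkedStreamDemo {
  public static void main(String[] args) throws Exception {
    byte[] payload = new byte[200 * 1024];
    ChunkedStream stream = new ChunkedStream(new ByteArrayInputStream(payload), 64 * 1024);
    while (!stream.isEndOfInput()) {
      ByteBuf chunk = stream.readChunk(ByteBufAllocator.DEFAULT);
      if (chunk == null) {
        break; // no bytes currently available
      }
      System.out.println("read chunk of " + chunk.readableBytes() + " bytes");
      chunk.release(); // each returned chunk is owned by the caller
    }
    stream.close();
  }
}
```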
diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java
index 4dd8cec2900c..893aa106a3fe 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java
@@ -75,4 +75,18 @@ public abstract class ManagedBuffer {
    * the caller will be responsible for releasing this new reference.
    */
   public abstract Object convertToNetty() throws IOException;
+
+  /**
+   * Convert the buffer into a Netty object, used to write the data out with SSL encryption,
+   * which cannot use zero-copy transfers such as {@link io.netty.channel.FileRegion}.
+   * The return value is either a {@link io.netty.buffer.ByteBuf},
+   * a {@link io.netty.handler.stream.ChunkedStream}, or a {@link java.io.InputStream}.
+   *
+   * If this method returns a ByteBuf, then that buffer's reference count will be incremented and
+   * the caller will be responsible for releasing this new reference.
+   *
+   * Once kernel TLS (kTLS) `sendfile` and OpenSSL's `SSL_sendfile` are more widely adopted (and
+   * supported in Netty), we can potentially deprecate these APIs and just use `convertToNetty`.
+   */
+  public abstract Object convertToNettyForSsl() throws IOException;
 }
diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java
index b42977c7cb7f..a40cfc8bc04b 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java
@@ -68,6 +68,11 @@ public Object convertToNetty() throws IOException {
     return buf.duplicate().retain();
   }
 
+  @Override
+  public Object convertToNettyForSsl() throws IOException {
+    return buf.duplicate().retain();
+  }
+
   @Override
   public String toString() {
     return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java
index 084f89d2611c..6eb8d4e2c731 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java
@@ -66,6 +66,11 @@ public Object convertToNetty() throws IOException {
     return Unpooled.wrappedBuffer(buf);
   }
 
+  @Override
+  public Object convertToNettyForSsl() throws IOException {
+    return Unpooled.wrappedBuffer(buf);
+  }
+
   @Override
   public String toString() {
     return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
index 2794883f3cf3..b8d8f6b85a46 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -249,6 +249,14 @@ public boolean saslServerAlwaysEncrypt() {
     return conf.getBoolean("spark.network.sasl.serverAlwaysEncrypt", false);
   }
 
+  /**
+   * When Secure (SSL/TLS) shuffle is enabled, the chunk size to use for sending shuffle files.
+   */
+  public int sslShuffleChunkSize() {
+    return Ints.checkedCast(JavaUtils.byteStringAsBytes(
+      conf.get("spark.network.ssl.maxEncryptedBlockSize", "64k")));
+  }
+
   /**
    * Flag indicating whether to share the pooled ByteBuf allocators between the different Netty
    * channels. If enabled then only two pooled ByteBuf allocators are created: one where caching
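For reference, `spark.network.ssl.maxEncryptedBlockSize` accepts Spark's usual byte-string syntax and defaults to `64k`, so the SSL path emits 64 KiB chunks unless overridden. A small sketch of how the value resolves, reusing the same helpers the method above calls (`JavaUtils.byteStringAsBytes` and Guava's `Ints.checkedCast`):

```java
import com.google.common.primitives.Ints;
import org.apache.spark.network.util.JavaUtils;

// Illustrative resolution of the SSL chunk size, mirroring sslShuffleChunkSize().
public class SslChunkSizeDemo {
  public static void main(String[] args) {
    String configured = "64k"; // default for spark.network.ssl.maxEncryptedBlockSize
    long bytes = JavaUtils.byteStringAsBytes(configured); // "64k" -> 65536
    int chunkSize = Ints.checkedCast(bytes); // throws if the value exceeds Integer.MAX_VALUE
    System.out.println(chunkSize); // 65536
  }
}
```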
diff --git a/common/network-common/src/test/java/org/apache/spark/network/TestManagedBuffer.java b/common/network-common/src/test/java/org/apache/spark/network/TestManagedBuffer.java
index 83c90f9eff2b..1814634fb92a 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/TestManagedBuffer.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/TestManagedBuffer.java
@@ -80,6 +80,11 @@ public Object convertToNetty() throws IOException {
     return underlying.convertToNetty();
   }
 
+  @Override
+  public Object convertToNettyForSsl() throws IOException {
+    return underlying.convertToNettyForSsl();
+  }
+
   @Override
   public int hashCode() {
     return underlying.hashCode();
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index cccee78aee13..a6962c46243f 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -85,6 +85,13 @@ private[spark] trait BlockData {
    */
   def toNetty(): Object
 
+  /**
+   * Returns a Netty-friendly wrapper for the block's data.
+   *
+   * Please see `ManagedBuffer.convertToNettyForSsl()` for more details.
+   */
+  def toNettyForSsl(): Object
+
   def toChunkedByteBuffer(allocator: Int => ByteBuffer): ChunkedByteBuffer
 
   def toByteBuffer(): ByteBuffer
@@ -103,6 +110,8 @@ private[spark] class ByteBufferBlockData(
 
   override def toNetty(): Object = buffer.toNetty
 
+  override def toNettyForSsl(): AnyRef = buffer.toNettyForSsl
+
   override def toChunkedByteBuffer(allocator: Int => ByteBuffer): ChunkedByteBuffer = {
     buffer.copy(allocator)
   }
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala
index 5c12b5cee4d2..cab11536e146 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala
@@ -51,6 +51,8 @@ private[storage] class BlockManagerManagedBuffer(
 
   override def convertToNetty(): Object = data.toNetty()
 
+  override def convertToNettyForSsl(): Object = data.toNettyForSsl()
+
   override def retain(): ManagedBuffer = {
     refCount.incrementAndGet()
     val locked = blockInfoManager.lockForReading(blockId, blocking = false)
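At this point every buffer and block wrapper exposes the two parallel conversion paths, and the transport has to pick one per channel. A hypothetical sketch of that dispatch; the helper class and its `sslEnabled` flag are illustrative, since in Spark the selection happens inside the Netty message encoding rather than in a standalone helper:

```java
import java.io.IOException;

import org.apache.spark.network.buffer.ManagedBuffer;

// Hypothetical helper: a FileRegion returned by convertToNetty() is written with
// kernel zero-copy and never passes through the SslHandler, so it is only safe
// on an unencrypted channel.
public final class BodySelector {
  private BodySelector() {}

  public static Object nettyBody(ManagedBuffer buffer, boolean sslEnabled) throws IOException {
    return sslEnabled ? buffer.convertToNettyForSsl() : buffer.convertToNetty();
  }
}
```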
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index 1cb5adef5f46..54c5d0b2dce7 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -184,6 +184,14 @@ private class DiskBlockData(
    */
   override def toNetty(): AnyRef = new DefaultFileRegion(file, 0, size)
 
+  /**
+   * Returns a Netty-friendly wrapper for the block's data.
+   *
+   * Please see `ManagedBuffer.convertToNettyForSsl()` for more details.
+   */
+  override def toNettyForSsl(): AnyRef =
+    toChunkedByteBuffer(ByteBuffer.allocate).toNettyForSsl
+
   override def toChunkedByteBuffer(allocator: (Int) => ByteBuffer): ChunkedByteBuffer = {
     Utils.tryWithResource(open()) { channel =>
       var remaining = blockSize
@@ -234,6 +242,9 @@
 
   override def toNetty(): Object = new ReadableChannelFileRegion(open(), blockSize)
 
+  override def toNettyForSsl(): AnyRef =
+    toChunkedByteBuffer(ByteBuffer.allocate).toNettyForSsl
+
   override def toChunkedByteBuffer(allocator: Int => ByteBuffer): ChunkedByteBuffer = {
     val source = open()
     try {
@@ -297,6 +308,8 @@
 
   override def convertToNetty(): AnyRef = blockData.toNetty()
 
+  override def convertToNettyForSsl(): AnyRef = blockData.toNettyForSsl()
+
   override def createInputStream(): InputStream = blockData.toInputStream()
 
   override def retain(): ManagedBuffer = this
diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
index 73e4e72cc5bd..88bd117ba22b 100644
--- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
@@ -23,6 +23,7 @@ import java.nio.channels.WritableByteChannel
 
 import com.google.common.io.ByteStreams
 import com.google.common.primitives.UnsignedBytes
+import io.netty.handler.stream.ChunkedStream
 import org.apache.commons.io.IOUtils
 
 import org.apache.spark.SparkEnv
@@ -131,6 +132,14 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) extends Ex
     new ChunkedByteBufferFileRegion(this, bufferWriteChunkSize)
   }
 
+  /**
+   * Wrap this in a ChunkedStream, which allows us to provide the data in a manner
+   * compatible with SSL encryption.
+   */
+  def toNettyForSsl: ChunkedStream = {
+    new ChunkedStream(toInputStream(), bufferWriteChunkSize)
+  }
+
   /**
    * Copy this buffer into a new byte array.
    *
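Note the trade-off taken above: `DiskBlockData.toNettyForSsl` and `EncryptedBlockData.toNettyForSsl` first materialize the whole block on-heap as a `ChunkedByteBuffer`, giving up the streaming behavior of the `FileRegion` path, and `ChunkedByteBuffer.toNettyForSsl` then exposes the chunks as a single `InputStream` inside a `ChunkedStream`. A rough Java analogue of that composition using `SequenceInputStream`; the chunk contents are illustrative:

```java
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.io.SequenceInputStream;
import java.util.Collections;
import java.util.List;

// Sketch of the job ChunkedByteBuffer.toInputStream() performs: several
// in-memory chunks presented as one continuous stream, which toNettyForSsl
// then wraps in a ChunkedStream for the SSL write path.
public class MultiChunkStreamDemo {
  public static void main(String[] args) throws Exception {
    List<InputStream> chunks = List.of(
        new ByteArrayInputStream("chunk-1,".getBytes()),
        new ByteArrayInputStream("chunk-2".getBytes()));
    InputStream combined = new SequenceInputStream(Collections.enumeration(chunks));
    System.out.println(new String(combined.readAllBytes())); // chunk-1,chunk-2
  }
}
```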
@@ -284,6 +293,17 @@ private[spark] class ChunkedByteBufferInputStream(
     }
   }
 
+  override def available(): Int = {
+    if (currentChunk != null && !currentChunk.hasRemaining && chunks.hasNext) {
+      currentChunk = chunks.next()
+    }
+    if (currentChunk != null && currentChunk.hasRemaining) {
+      currentChunk.remaining
+    } else {
+      0
+    }
+  }
+
   override def read(): Int = {
     if (currentChunk != null && !currentChunk.hasRemaining && chunks.hasNext) {
       currentChunk = chunks.next()
diff --git a/core/src/test/scala/org/apache/spark/network/BlockTransferServiceSuite.scala b/core/src/test/scala/org/apache/spark/network/BlockTransferServiceSuite.scala
index d7e4b9166fa0..f9a1b778b4ea 100644
--- a/core/src/test/scala/org/apache/spark/network/BlockTransferServiceSuite.scala
+++ b/core/src/test/scala/org/apache/spark/network/BlockTransferServiceSuite.scala
@@ -74,6 +74,8 @@
 
       override def release(): ManagedBuffer = this
       override def convertToNetty(): AnyRef = null
+
+      override def convertToNettyForSsl(): AnyRef = null
     }
     listener.onBlockFetchSuccess("block-id-unused", badBuffer)
   }
diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
index 56b8e0b6df3f..9638558e3c93 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
@@ -43,6 +43,7 @@ class RecordingManagedBuffer(underlyingBuffer: NioManagedBuffer) extends Managed
   override def nioByteBuffer(): ByteBuffer = underlyingBuffer.nioByteBuffer()
   override def createInputStream(): InputStream = underlyingBuffer.createInputStream()
   override def convertToNetty(): AnyRef = underlyingBuffer.convertToNetty()
+  override def convertToNettyForSsl(): AnyRef = underlyingBuffer.convertToNettyForSsl()
 
   override def retain(): ManagedBuffer = {
     callsToRetain += 1
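The `available()` override in `ChunkedByteBufferInputStream` matters because Netty's `ChunkedStream.readChunk` consults `InputStream.available()` when sizing each chunk's buffer, and `InputStream`'s default implementation always returns 0, in which case `ChunkedStream` falls back to allocating its full configured chunk size regardless of how little data remains. The override reports the bytes left in the current chunk, advancing to the next chunk when the current one is drained. A small demonstration of the default contract; the wrapper class is illustrative:

```java
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;

// NaiveWrapper inherits InputStream.available(), which always returns 0:
// exactly the situation the new override in ChunkedByteBufferInputStream avoids.
public class AvailableContractDemo {
  static class NaiveWrapper extends InputStream {
    private final InputStream in;
    NaiveWrapper(InputStream in) { this.in = in; }
    @Override public int read() throws IOException { return in.read(); }
    // available() deliberately not overridden
  }

  public static void main(String[] args) throws IOException {
    ByteArrayInputStream raw = new ByteArrayInputStream(new byte[1024]);
    System.out.println(new NaiveWrapper(raw).available()); // 0, despite 1024 buffered bytes
    System.out.println(raw.available()); // 1024
  }
}
```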