Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import com.google.common.io.ByteStreams;
import io.netty.channel.DefaultFileRegion;
import io.netty.handler.stream.ChunkedStream;
import org.apache.commons.lang3.builder.ToStringBuilder;
import org.apache.commons.lang3.builder.ToStringStyle;

Expand Down Expand Up @@ -137,6 +138,12 @@ public Object convertToNetty() throws IOException {
}
}

/**
 * SSL-path counterpart of {@code convertToNetty()}: wraps this buffer's input stream in a
 * {@link ChunkedStream} whose chunk size is {@code conf.sslShuffleChunkSize()}.
 */
@Override
public Object convertToNettyForSsl() throws IOException {
  // Cannot use zero-copy with HTTPS
  // The stream is opened before the chunk size is read, matching the original argument order.
  final java.io.InputStream source = createInputStream();
  final int chunkSize = conf.sslShuffleChunkSize();
  return new ChunkedStream(source, chunkSize);
}

/** Accessor for the {@code file} field of this buffer. */
public File getFile() {
  return file;
}

/** Accessor for the {@code offset} field of this buffer. */
public long getOffset() {
  return offset;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,18 @@ public abstract class ManagedBuffer {
* the caller will be responsible for releasing this new reference.
*/
public abstract Object convertToNetty() throws IOException;

/**
 * Convert the buffer into a Netty object, used to write the data out with SSL encryption,
 * which cannot use {@link io.netty.channel.FileRegion}.
 * The return value is either a {@link io.netty.buffer.ByteBuf},
 * a {@link io.netty.handler.stream.ChunkedStream}, or an {@link java.io.InputStream}.
 *
 * If this method returns a ByteBuf, then that buffer's reference count will be incremented and
 * the caller will be responsible for releasing this new reference.
 *
 * Once `kernel.ssl.sendfile` and OpenSSL's `ssl_sendfile` are more widely adopted (and supported
 * in Netty), we can potentially deprecate these APIs and just use `convertToNetty`.
 *
 * @return a Netty-writable object of one of the types listed above
 * @throws IOException declared for implementations that need to read the underlying data
 */
public abstract Object convertToNettyForSsl() throws IOException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ public Object convertToNetty() throws IOException {
return buf.duplicate().retain();
}

/**
 * SSL-path conversion: returns the result of {@code buf.duplicate().retain()}, the same
 * expression used by the visible tail of {@code convertToNetty()}.
 */
@Override
public Object convertToNettyForSsl() throws IOException {
  final Object retainedDuplicate = buf.duplicate().retain();
  return retainedDuplicate;
}

@Override
public String toString() {
return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ public Object convertToNetty() throws IOException {
return Unpooled.wrappedBuffer(buf);
}

/**
 * SSL-path conversion: wraps {@code buf} with {@code Unpooled.wrappedBuffer}, the same
 * expression used by the visible tail of {@code convertToNetty()}.
 */
@Override
public Object convertToNettyForSsl() throws IOException {
  final Object wrapped = Unpooled.wrappedBuffer(buf);
  return wrapped;
}

@Override
public String toString() {
return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,14 @@ public boolean saslServerAlwaysEncrypt() {
return conf.getBoolean("spark.network.sasl.serverAlwaysEncrypt", false);
}

/**
 * Chunk size to use for shuffling files when Secure (SSL/TLS) Shuffle is enabled.
 * Read from {@code spark.network.ssl.maxEncryptedBlockSize} (default {@code 64k}), parsed
 * with {@code JavaUtils.byteStringAsBytes} and narrowed to an int via
 * {@code Ints.checkedCast}.
 */
public int sslShuffleChunkSize() {
  final String configured = conf.get("spark.network.ssl.maxEncryptedBlockSize", "64k");
  final long sizeInBytes = JavaUtils.byteStringAsBytes(configured);
  return Ints.checkedCast(sizeInBytes);
}

/**
* Flag indicating whether to share the pooled ByteBuf allocators between the different Netty
* channels. If enabled then only two pooled ByteBuf allocators are created: one where caching
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ public Object convertToNetty() throws IOException {
return underlying.convertToNetty();
}

/** Forwards the SSL-path conversion to the {@code underlying} buffer. */
@Override
public Object convertToNettyForSsl() throws IOException {
  final Object sslPayload = underlying.convertToNettyForSsl();
  return sslPayload;
}

@Override
public int hashCode() {
return underlying.hashCode();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,13 @@ private[spark] trait BlockData {
*/
def toNetty(): Object

/**
 * Returns a Netty-friendly wrapper for the block's data, for use when the data is written
 * out over SSL. This is the SSL counterpart of `toNetty()`.
 *
 * Please see `ManagedBuffer.convertToNettyForSsl()` for more details.
 */
def toNettyForSsl(): Object

/**
 * Returns the block's data as a `ChunkedByteBuffer`.
 *
 * @param allocator function from an `Int` to a `ByteBuffer`; presumably the requested size
 *                  in bytes of each buffer to allocate (confirm against implementations)
 */
def toChunkedByteBuffer(allocator: Int => ByteBuffer): ChunkedByteBuffer

/** Returns the block's data as a single `ByteBuffer`. */
def toByteBuffer(): ByteBuffer
Expand All @@ -103,6 +110,8 @@ private[spark] class ByteBufferBlockData(

/** Netty-friendly view of this block, obtained from the wrapped `buffer`. */
override def toNetty(): Object = {
  buffer.toNetty
}

/** SSL-path Netty view of this block, obtained from the wrapped `buffer`. */
override def toNettyForSsl(): AnyRef = {
  buffer.toNettyForSsl
}

/** Produces a copy of the wrapped `buffer`, passing `allocator` through to `buffer.copy`. */
override def toChunkedByteBuffer(allocator: Int => ByteBuffer): ChunkedByteBuffer = {
  val duplicate = buffer.copy(allocator)
  duplicate
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ private[storage] class BlockManagerManagedBuffer(

/** Delegates to `data.toNetty()`. */
override def convertToNetty(): Object = {
  data.toNetty()
}

/** Delegates to `data.toNettyForSsl()`. */
override def convertToNettyForSsl(): Object = {
  data.toNettyForSsl()
}

override def retain(): ManagedBuffer = {
refCount.incrementAndGet()
val locked = blockInfoManager.lockForReading(blockId, blocking = false)
Expand Down
13 changes: 13 additions & 0 deletions core/src/main/scala/org/apache/spark/storage/DiskStore.scala
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,14 @@ private class DiskBlockData(
*/
override def toNetty(): AnyRef = new DefaultFileRegion(file, 0, size)

/**
 * Returns a Netty-friendly wrapper for the block's data, for the SSL write path.
 *
 * The block is first converted with `toChunkedByteBuffer(ByteBuffer.allocate)`, and the
 * resulting buffer's `toNettyForSsl` is returned.
 *
 * Please see `ManagedBuffer.convertToNettyForSsl()` for more details.
 */
override def toNettyForSsl(): AnyRef = {
  val inMemory = toChunkedByteBuffer(ByteBuffer.allocate)
  inMemory.toNettyForSsl
}

override def toChunkedByteBuffer(allocator: (Int) => ByteBuffer): ChunkedByteBuffer = {
Utils.tryWithResource(open()) { channel =>
var remaining = blockSize
Expand Down Expand Up @@ -234,6 +242,9 @@ private[spark] class EncryptedBlockData(

/** Wraps the channel returned by `open()` in a `ReadableChannelFileRegion` of `blockSize`. */
override def toNetty(): Object = {
  val channel = open()
  new ReadableChannelFileRegion(channel, blockSize)
}

/**
 * SSL-path Netty wrapper: converts the block with `toChunkedByteBuffer(ByteBuffer.allocate)`
 * and returns that buffer's `toNettyForSsl`.
 */
override def toNettyForSsl(): AnyRef = {
  val inMemory = toChunkedByteBuffer(ByteBuffer.allocate)
  inMemory.toNettyForSsl
}

override def toChunkedByteBuffer(allocator: Int => ByteBuffer): ChunkedByteBuffer = {
val source = open()
try {
Expand Down Expand Up @@ -297,6 +308,8 @@ private[spark] class EncryptedManagedBuffer(

/** Delegates to `blockData.toNetty()`. */
override def convertToNetty(): AnyRef = {
  blockData.toNetty()
}

/** Delegates to `blockData.toNettyForSsl()`. */
override def convertToNettyForSsl(): AnyRef = {
  blockData.toNettyForSsl()
}

/** Delegates to `blockData.toInputStream()`. */
override def createInputStream(): InputStream = {
  blockData.toInputStream()
}

/** Returns this same instance; the visible body does nothing else. */
override def retain(): ManagedBuffer = {
  this
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import java.nio.channels.WritableByteChannel

import com.google.common.io.ByteStreams
import com.google.common.primitives.UnsignedBytes
import io.netty.handler.stream.ChunkedStream
import org.apache.commons.io.IOUtils

import org.apache.spark.SparkEnv
Expand Down Expand Up @@ -131,6 +132,14 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) extends Ex
new ChunkedByteBufferFileRegion(this, bufferWriteChunkSize)
}

/**
 * Exposes this buffer as a Netty `ChunkedStream` so the data can be provided in a manner
 * compatible with SSL encryption. The stream reads from `toInputStream()` and uses
 * `bufferWriteChunkSize` as its chunk size.
 */
def toNettyForSsl: ChunkedStream = {
  val source = toInputStream()
  new ChunkedStream(source, bufferWriteChunkSize)
}

/**
* Copy this buffer into a new byte array.
*
Expand Down Expand Up @@ -284,6 +293,17 @@ private[spark] class ChunkedByteBufferInputStream(
}
}

/**
 * Number of bytes remaining in the current chunk, or 0 when there is no current chunk or it
 * has nothing left. If the current chunk is used up and another chunk exists, this first
 * advances to that next chunk (one step only).
 */
override def available(): Int = {
  val exhausted = currentChunk != null && !currentChunk.hasRemaining
  if (exhausted && chunks.hasNext) {
    currentChunk = chunks.next()
  }
  if (currentChunk == null || !currentChunk.hasRemaining) 0 else currentChunk.remaining
}

override def read(): Int = {
if (currentChunk != null && !currentChunk.hasRemaining && chunks.hasNext) {
currentChunk = chunks.next()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ class BlockTransferServiceSuite extends SparkFunSuite with TimeLimits {
// Test stub: hands back this same buffer.
override def release(): ManagedBuffer = {
  this
}

// Test stub: always yields null.
override def convertToNetty(): AnyRef = {
  null
}

// Test stub: always yields null.
override def convertToNettyForSsl(): AnyRef = {
  null
}
}
listener.onBlockFetchSuccess("block-id-unused", badBuffer)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class RecordingManagedBuffer(underlyingBuffer: NioManagedBuffer) extends Managed
// Forwards to the wrapped buffer.
override def nioByteBuffer(): ByteBuffer = {
  underlyingBuffer.nioByteBuffer()
}
// Forwards to the wrapped buffer.
override def createInputStream(): InputStream = {
  underlyingBuffer.createInputStream()
}
// Forwards to the wrapped buffer.
override def convertToNetty(): AnyRef = {
  underlyingBuffer.convertToNetty()
}
// Forwards to the wrapped buffer.
override def convertToNettyForSsl(): AnyRef = {
  underlyingBuffer.convertToNettyForSsl()
}

override def retain(): ManagedBuffer = {
callsToRetain += 1
Expand Down