diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/EncryptedMessageWithHeader.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/EncryptedMessageWithHeader.java new file mode 100644 index 000000000000..7e7ba85ebf66 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/EncryptedMessageWithHeader.java @@ -0,0 +1,148 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.network.protocol; + +import java.io.EOFException; +import java.io.InputStream; +import javax.annotation.Nullable; + +import com.google.common.base.Preconditions; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.stream.ChunkedStream; +import io.netty.handler.stream.ChunkedInput; + +import org.apache.spark.network.buffer.ManagedBuffer; + +/** + * A wrapper message that holds two separate pieces (a header and a body). + * + * The header must be a ByteBuf, while the body can be any InputStream or ChunkedStream + */ +public class EncryptedMessageWithHeader implements ChunkedInput { + + @Nullable private final ManagedBuffer managedBuffer; + private final ByteBuf header; + private final int headerLength; + private final Object body; + private final long bodyLength; + private long totalBytesTransferred; + + /** + * Construct a new EncryptedMessageWithHeader. + * + * @param managedBuffer the {@link ManagedBuffer} that the message body came from. This needs to + * be passed in so that the buffer can be freed when this message is + * deallocated. Ownership of the caller's reference to this buffer is + * transferred to this class, so if the caller wants to continue to use the + * ManagedBuffer in other messages then they will need to call retain() on + * it before passing it to this constructor. + * @param header the message header. + * @param body the message body. + * @param bodyLength the length of the message body, in bytes. + */ + + public EncryptedMessageWithHeader( + @Nullable ManagedBuffer managedBuffer, ByteBuf header, Object body, long bodyLength) { + Preconditions.checkArgument(body instanceof InputStream || body instanceof ChunkedStream, + "Body must be an InputStream or a ChunkedStream."); + this.managedBuffer = managedBuffer; + this.header = header; + this.headerLength = header.readableBytes(); + this.body = body; + this.bodyLength = bodyLength; + this.totalBytesTransferred = 0; + } + + @Override + public ByteBuf readChunk(ChannelHandlerContext ctx) throws Exception { + return readChunk(ctx.alloc()); + } + + @Override + public ByteBuf readChunk(ByteBufAllocator allocator) throws Exception { + if (isEndOfInput()) { + return null; + } + + if (totalBytesTransferred < headerLength) { + totalBytesTransferred += headerLength; + return header.retain(); + } else if (body instanceof InputStream) { + InputStream stream = (InputStream) body; + int available = stream.available(); + if (available <= 0) { + available = (int) (length() - totalBytesTransferred); + } else { + available = (int) Math.min(available, length() - totalBytesTransferred); + } + ByteBuf buffer = allocator.buffer(available); + int toRead = Math.min(available, buffer.writableBytes()); + int read = buffer.writeBytes(stream, toRead); + if (read >= 0) { + totalBytesTransferred += read; + return buffer; + } else { + throw new EOFException("Unable to read bytes from InputStream"); + } + } else if (body instanceof ChunkedStream) { + ChunkedStream stream = (ChunkedStream) body; + long old = stream.transferredBytes(); + ByteBuf buffer = stream.readChunk(allocator); + long read = stream.transferredBytes() - old; + if (read >= 0) { + totalBytesTransferred += read; + assert(totalBytesTransferred <= length()); + return buffer; + } else { + throw new EOFException("Unable to read bytes from ChunkedStream"); + } + } else { + return null; + } + } + + @Override + public long length() { + return headerLength + bodyLength; + } + + @Override + public long progress() { + return totalBytesTransferred; + } + + @Override + public boolean isEndOfInput() throws Exception { + return (headerLength + bodyLength) == totalBytesTransferred; + } + + @Override + public void close() throws Exception { + header.release(); + if (managedBuffer != null) { + managedBuffer.release(); + } + if (body instanceof InputStream) { + ((InputStream) body).close(); + } else if (body instanceof ChunkedStream) { + ((ChunkedStream) body).close(); + } + } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/SslMessageEncoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/SslMessageEncoder.java new file mode 100644 index 000000000000..f43d0789ee67 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/SslMessageEncoder.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.network.protocol; + +import java.io.InputStream; +import java.util.List; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.MessageToMessageEncoder; +import io.netty.handler.stream.ChunkedStream; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Encoder used by the server side to encode secure (SSL) server-to-client responses. + * This encoder is stateless so it is safe to be shared by multiple threads. + */ +@ChannelHandler.Sharable +public final class SslMessageEncoder extends MessageToMessageEncoder { + + private final Logger logger = LoggerFactory.getLogger(SslMessageEncoder.class); + + private SslMessageEncoder() {} + + public static final SslMessageEncoder INSTANCE = new SslMessageEncoder(); + + /** + * Encodes a Message by invoking its encode() method. For non-data messages, we will add one + * ByteBuf to 'out' containing the total frame length, the message type, and the message itself. + * In the case of a ChunkFetchSuccess, we will also add the ManagedBuffer corresponding to the + * data to 'out'. + */ + @Override + public void encode(ChannelHandlerContext ctx, Message in, List out) throws Exception { + Object body = null; + long bodyLength = 0; + boolean isBodyInFrame = false; + + // If the message has a body, take it out... + // For SSL, zero-copy transfer will not work, so we will check if + // the body is an InputStream, and if so, use an EncryptedMessageWithHeader + // to wrap the header+body appropriately (for thread safety). + if (in.body() != null) { + try { + bodyLength = in.body().size(); + body = in.body().convertToNettyForSsl(); + isBodyInFrame = in.isBodyInFrame(); + } catch (Exception e) { + in.body().release(); + if (in instanceof AbstractResponseMessage) { + AbstractResponseMessage resp = (AbstractResponseMessage) in; + // Re-encode this message as a failure response. + String error = e.getMessage() != null ? e.getMessage() : "null"; + logger.error(String.format("Error processing %s for client %s", + in, ctx.channel().remoteAddress()), e); + encode(ctx, resp.createFailureResponse(error), out); + } else { + throw e; + } + return; + } + } + + Message.Type msgType = in.type(); + // All messages have the frame length, message type, and message itself. The frame length + // may optionally include the length of the body data, depending on what message is being + // sent. + int headerLength = 8 + msgType.encodedLength() + in.encodedLength(); + long frameLength = headerLength + (isBodyInFrame ? bodyLength : 0); + ByteBuf header = ctx.alloc().buffer(headerLength); + header.writeLong(frameLength); + msgType.encode(header); + in.encode(header); + assert header.writableBytes() == 0; + + if (body != null && bodyLength > 0) { + if (body instanceof ByteBuf) { + out.add(Unpooled.wrappedBuffer(header, (ByteBuf) body)); + } else if (body instanceof InputStream || body instanceof ChunkedStream) { + // For now, assume the InputStream is doing proper chunking. + out.add(new EncryptedMessageWithHeader(in.body(), header, body, bodyLength)); + } else { + throw new IllegalArgumentException( + "Body must be a ByteBuf, ChunkedStream or an InputStream"); + } + } else { + out.add(header); + } + } +} diff --git a/common/network-common/src/test/java/org/apache/spark/network/protocol/EncryptedMessageWithHeaderSuite.java b/common/network-common/src/test/java/org/apache/spark/network/protocol/EncryptedMessageWithHeaderSuite.java new file mode 100644 index 000000000000..7478fa1db711 --- /dev/null +++ b/common/network-common/src/test/java/org/apache/spark/network/protocol/EncryptedMessageWithHeaderSuite.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.util.Random; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.buffer.ByteBufAllocator; +import io.netty.handler.stream.ChunkedStream; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; + +public class EncryptedMessageWithHeaderSuite { + + // Tests the case where the body is an input stream and that we manage the refcounts of the + // buffer properly + @Test + public void testInputStreamBodyFromManagedBuffer() throws Exception { + byte[] randomData = new byte[128]; + new Random().nextBytes(randomData); + ByteBuf sourceBuffer = Unpooled.copiedBuffer(randomData); + InputStream body = new ByteArrayInputStream(sourceBuffer.array()); + ByteBuf header = Unpooled.copyLong(42); + + long expectedHeaderValue = header.getLong(header.readerIndex()); + assertEquals(1, header.refCnt()); + assertEquals(1, sourceBuffer.refCnt()); + ManagedBuffer managedBuf = new NettyManagedBuffer(sourceBuffer); + + EncryptedMessageWithHeader msg = new EncryptedMessageWithHeader( + managedBuf, header, body, managedBuf.size()); + ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; + + // First read should just read the header + ByteBuf headerResult = msg.readChunk(allocator); + assertEquals(header.capacity(), headerResult.readableBytes()); + assertEquals(expectedHeaderValue, headerResult.readLong()); + assertEquals(header.capacity(), msg.progress()); + assertFalse(msg.isEndOfInput()); + + // Second read should read the body + ByteBuf bodyResult = msg.readChunk(allocator); + assertEquals(randomData.length + header.capacity(), msg.progress()); + assertTrue(msg.isEndOfInput()); + + // Validate we read it all + assertEquals(bodyResult.readableBytes(), randomData.length); + for (int i = 0; i < randomData.length; i++) { + assertEquals(bodyResult.readByte(), randomData[i]); + } + + // Closing the message should release the source buffer + msg.close(); + assertEquals(0, sourceBuffer.refCnt()); + + // The header still has a reference we got + assertEquals(1, header.refCnt()); + headerResult.release(); + assertEquals(0, header.refCnt()); + } + + // Tests the case where the body is a chunked stream and that we are fine when there is no + // input managed buffer + @Test + public void testChunkedStream() throws Exception { + int bodyLength = 129; + int chunkSize = 64; + byte[] randomData = new byte[bodyLength]; + new Random().nextBytes(randomData); + InputStream inputStream = new ByteArrayInputStream(randomData); + ChunkedStream body = new ChunkedStream(inputStream, chunkSize); + ByteBuf header = Unpooled.copyLong(42); + + long expectedHeaderValue = header.getLong(header.readerIndex()); + assertEquals(1, header.refCnt()); + + EncryptedMessageWithHeader msg = new EncryptedMessageWithHeader(null, header, body, bodyLength); + ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; + + // First read should just read the header + ByteBuf headerResult = msg.readChunk(allocator); + assertEquals(header.capacity(), headerResult.readableBytes()); + assertEquals(expectedHeaderValue, headerResult.readLong()); + assertEquals(header.capacity(), msg.progress()); + assertFalse(msg.isEndOfInput()); + + // Next 2 reads should read full buffers + int readIndex = 0; + for (int i = 1; i <= 2; i++) { + ByteBuf bodyResult = msg.readChunk(allocator); + assertEquals(header.capacity() + (i*chunkSize), msg.progress()); + assertFalse(msg.isEndOfInput()); + + // Validate we read data correctly + assertEquals(bodyResult.readableBytes(), chunkSize); + assert(bodyResult.readableBytes() < (randomData.length - readIndex)); + while (bodyResult.readableBytes() > 0) { + assertEquals(bodyResult.readByte(), randomData[readIndex++]); + } + } + + // Last read should be partial + ByteBuf bodyResult = msg.readChunk(allocator); + assertEquals(header.capacity() + bodyLength, msg.progress()); + assertTrue(msg.isEndOfInput()); + + // Validate we read the byte properly + assertEquals(bodyResult.readableBytes(), 1); + assertEquals(bodyResult.readByte(), randomData[readIndex]); + + // Closing the message should close the input stream + msg.close(); + assertTrue(body.isEndOfInput()); + + // The header still has a reference we got + assertEquals(1, header.refCnt()); + headerResult.release(); + assertEquals(0, header.refCnt()); + } + + @Test + public void testByteBufIsNotSupported() throws Exception { + // Validate that ByteBufs are not supported. This test can be updated + // when we add support for them + ByteBuf header = Unpooled.copyLong(42); + assertThrows(IllegalArgumentException.class, () -> { + EncryptedMessageWithHeader msg = new EncryptedMessageWithHeader( + null, header, header, 4); + }); + } +}