diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/FlushOperation.java b/libs/nio/src/main/java/org/elasticsearch/nio/FlushOperation.java index 7a1696483db06..de0318a941af6 100644 --- a/libs/nio/src/main/java/org/elasticsearch/nio/FlushOperation.java +++ b/libs/nio/src/main/java/org/elasticsearch/nio/FlushOperation.java @@ -25,6 +25,8 @@ public class FlushOperation { + private static final ByteBuffer[] EMPTY_ARRAY = new ByteBuffer[0]; + private final BiConsumer listener; private final ByteBuffer[] buffers; private final int[] offsets; @@ -61,19 +63,38 @@ public void incrementIndex(int delta) { } public ByteBuffer[] getBuffersToWrite() { + return getBuffersToWrite(length); + } + + public ByteBuffer[] getBuffersToWrite(int maxBytes) { final int index = Arrays.binarySearch(offsets, internalIndex); - int offsetIndex = index < 0 ? (-(index + 1)) - 1 : index; + final int offsetIndex = index < 0 ? (-(index + 1)) - 1 : index; + final int finalIndex = Arrays.binarySearch(offsets, Math.min(internalIndex + maxBytes, length)); + final int finalOffsetIndex = finalIndex < 0 ? (-(finalIndex + 1)) - 1 : finalIndex; - ByteBuffer[] postIndexBuffers = new ByteBuffer[buffers.length - offsetIndex]; + int nBuffers = (finalOffsetIndex - offsetIndex) + 1; + int firstBufferPosition = internalIndex - offsets[offsetIndex]; ByteBuffer firstBuffer = buffers[offsetIndex].duplicate(); - firstBuffer.position(internalIndex - offsets[offsetIndex]); + firstBuffer.position(firstBufferPosition); + if (nBuffers == 1 && firstBuffer.remaining() == 0) { + return EMPTY_ARRAY; + } + + ByteBuffer[] postIndexBuffers = new ByteBuffer[nBuffers]; postIndexBuffers[0] = firstBuffer; + int finalOffset = offsetIndex + nBuffers; + int nBytes = firstBuffer.remaining(); int j = 1; - for (int i = (offsetIndex + 1); i < buffers.length; ++i) { - postIndexBuffers[j++] = buffers[i].duplicate(); + for (int i = (offsetIndex + 1); i < finalOffset; ++i) { + ByteBuffer buffer = buffers[i].duplicate(); + nBytes += buffer.remaining(); + postIndexBuffers[j++] = buffer; } + int excessBytes = Math.max(0, nBytes - maxBytes); + ByteBuffer lastBuffer = postIndexBuffers[postIndexBuffers.length - 1]; + lastBuffer.limit(lastBuffer.limit() - excessBytes); return postIndexBuffers; } } diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/FlushReadyWrite.java b/libs/nio/src/main/java/org/elasticsearch/nio/FlushReadyWrite.java index 61c997603ff97..4855e0cbade9c 100644 --- a/libs/nio/src/main/java/org/elasticsearch/nio/FlushReadyWrite.java +++ b/libs/nio/src/main/java/org/elasticsearch/nio/FlushReadyWrite.java @@ -27,7 +27,7 @@ public class FlushReadyWrite extends FlushOperation implements WriteOperation { private final SocketChannelContext channelContext; private final ByteBuffer[] buffers; - FlushReadyWrite(SocketChannelContext channelContext, ByteBuffer[] buffers, BiConsumer listener) { + public FlushReadyWrite(SocketChannelContext channelContext, ByteBuffer[] buffers, BiConsumer listener) { super(buffers, listener); this.channelContext = channelContext; this.buffers = buffers; diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/InboundChannelBuffer.java b/libs/nio/src/main/java/org/elasticsearch/nio/InboundChannelBuffer.java index f7e6fbb768728..2dfd53d27e109 100644 --- a/libs/nio/src/main/java/org/elasticsearch/nio/InboundChannelBuffer.java +++ b/libs/nio/src/main/java/org/elasticsearch/nio/InboundChannelBuffer.java @@ -19,7 +19,6 @@ package org.elasticsearch.nio; -import org.elasticsearch.common.util.concurrent.AbstractRefCounted; import org.elasticsearch.nio.utils.ExceptionsHelper; import java.nio.ByteBuffer; @@ -140,11 +139,11 @@ public ByteBuffer[] sliceBuffersTo(long to) { ByteBuffer[] buffers = new ByteBuffer[pageCount]; Iterator pageIterator = pages.iterator(); - ByteBuffer firstBuffer = pageIterator.next().byteBuffer.duplicate(); + ByteBuffer firstBuffer = pageIterator.next().byteBuffer().duplicate(); firstBuffer.position(firstBuffer.position() + offset); buffers[0] = firstBuffer; for (int i = 1; i < buffers.length; i++) { - buffers[i] = pageIterator.next().byteBuffer.duplicate(); + buffers[i] = pageIterator.next().byteBuffer().duplicate(); } if (finalLimit != 0) { buffers[buffers.length - 1].limit(finalLimit); @@ -180,14 +179,14 @@ public Page[] sliceAndRetainPagesTo(long to) { Page[] pages = new Page[pageCount]; Iterator pageIterator = this.pages.iterator(); Page firstPage = pageIterator.next().duplicate(); - ByteBuffer firstBuffer = firstPage.byteBuffer; + ByteBuffer firstBuffer = firstPage.byteBuffer(); firstBuffer.position(firstBuffer.position() + offset); pages[0] = firstPage; for (int i = 1; i < pages.length; i++) { pages[i] = pageIterator.next().duplicate(); } if (finalLimit != 0) { - pages[pages.length - 1].byteBuffer.limit(finalLimit); + pages[pages.length - 1].byteBuffer().limit(finalLimit); } return pages; @@ -217,9 +216,9 @@ public ByteBuffer[] sliceBuffersFrom(long from) { ByteBuffer[] buffers = new ByteBuffer[pages.size() - pageIndex]; Iterator pageIterator = pages.descendingIterator(); for (int i = buffers.length - 1; i > 0; --i) { - buffers[i] = pageIterator.next().byteBuffer.duplicate(); + buffers[i] = pageIterator.next().byteBuffer().duplicate(); } - ByteBuffer firstPostIndexBuffer = pageIterator.next().byteBuffer.duplicate(); + ByteBuffer firstPostIndexBuffer = pageIterator.next().byteBuffer().duplicate(); firstPostIndexBuffer.position(firstPostIndexBuffer.position() + indexInPage); buffers[0] = firstPostIndexBuffer; @@ -268,53 +267,4 @@ private int pageIndex(long index) { private int indexInPage(long index) { return (int) (index & PAGE_MASK); } - - public static class Page implements AutoCloseable { - - private final ByteBuffer byteBuffer; - // This is reference counted as some implementations want to retain the byte pages by calling - // sliceAndRetainPagesTo. With reference counting we can increment the reference count, return the - // pages, and safely close them when this channel buffer is done with them. The reference count - // would be 1 at that point, meaning that the pages will remain until the implementation closes - // theirs. - private final RefCountedCloseable refCountedCloseable; - - public Page(ByteBuffer byteBuffer, Runnable closeable) { - this(byteBuffer, new RefCountedCloseable(closeable)); - } - - private Page(ByteBuffer byteBuffer, RefCountedCloseable refCountedCloseable) { - this.byteBuffer = byteBuffer; - this.refCountedCloseable = refCountedCloseable; - } - - private Page duplicate() { - refCountedCloseable.incRef(); - return new Page(byteBuffer.duplicate(), refCountedCloseable); - } - - public ByteBuffer getByteBuffer() { - return byteBuffer; - } - - @Override - public void close() { - refCountedCloseable.decRef(); - } - - private static class RefCountedCloseable extends AbstractRefCounted { - - private final Runnable closeable; - - private RefCountedCloseable(Runnable closeable) { - super("byte array page"); - this.closeable = closeable; - } - - @Override - protected void closeInternal() { - closeable.run(); - } - } - } } diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/Page.java b/libs/nio/src/main/java/org/elasticsearch/nio/Page.java new file mode 100644 index 0000000000000..b60c1c0127919 --- /dev/null +++ b/libs/nio/src/main/java/org/elasticsearch/nio/Page.java @@ -0,0 +1,89 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.nio; + +import org.elasticsearch.common.util.concurrent.AbstractRefCounted; + +import java.io.Closeable; +import java.nio.ByteBuffer; + +public class Page implements Closeable { + + private final ByteBuffer byteBuffer; + // This is reference counted as some implementations want to retain the byte pages by calling + // duplicate. With reference counting we can increment the reference count, return a new page, + // and safely close the pages independently. The closeable will not be called until each page is + // released. + private final RefCountedCloseable refCountedCloseable; + + public Page(ByteBuffer byteBuffer) { + this(byteBuffer, () -> {}); + } + + public Page(ByteBuffer byteBuffer, Runnable closeable) { + this(byteBuffer, new RefCountedCloseable(closeable)); + } + + private Page(ByteBuffer byteBuffer, RefCountedCloseable refCountedCloseable) { + this.byteBuffer = byteBuffer; + this.refCountedCloseable = refCountedCloseable; + } + + /** + * Duplicates this page and increments the reference count. The new page must be closed independently + * of the original page. + * + * @return the new page + */ + public Page duplicate() { + refCountedCloseable.incRef(); + return new Page(byteBuffer.duplicate(), refCountedCloseable); + } + + /** + * Returns the {@link ByteBuffer} for this page. Modifications to the limits, positions, etc of the + * buffer will also mutate this page. Call {@link ByteBuffer#duplicate()} to avoid mutating the page. + * + * @return the byte buffer + */ + public ByteBuffer byteBuffer() { + return byteBuffer; + } + + @Override + public void close() { + refCountedCloseable.decRef(); + } + + private static class RefCountedCloseable extends AbstractRefCounted { + + private final Runnable closeable; + + private RefCountedCloseable(Runnable closeable) { + super("byte array page"); + this.closeable = closeable; + } + + @Override + protected void closeInternal() { + closeable.run(); + } + } +} diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java b/libs/nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java index 816f4adc8cbb1..1444422f7a7f6 100644 --- a/libs/nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java +++ b/libs/nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java @@ -325,7 +325,7 @@ protected int flushToChannel(FlushOperation flushOperation) throws IOException { ioBuffer.clear(); ioBuffer.limit(Math.min(WRITE_LIMIT, ioBuffer.limit())); int j = 0; - ByteBuffer[] buffers = flushOperation.getBuffersToWrite(); + ByteBuffer[] buffers = flushOperation.getBuffersToWrite(WRITE_LIMIT); while (j < buffers.length && ioBuffer.remaining() > 0) { ByteBuffer buffer = buffers[j++]; copyBytes(buffer, ioBuffer); diff --git a/libs/nio/src/test/java/org/elasticsearch/nio/BytesChannelContextTests.java b/libs/nio/src/test/java/org/elasticsearch/nio/BytesChannelContextTests.java index 0591abdd69a97..c98e7dc8dfb29 100644 --- a/libs/nio/src/test/java/org/elasticsearch/nio/BytesChannelContextTests.java +++ b/libs/nio/src/test/java/org/elasticsearch/nio/BytesChannelContextTests.java @@ -31,6 +31,7 @@ import java.util.function.Consumer; import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyInt; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -168,7 +169,7 @@ public void testQueuedWriteIsFlushedInFlushCall() throws Exception { assertTrue(context.readyForFlush()); - when(flushOperation.getBuffersToWrite()).thenReturn(buffers); + when(flushOperation.getBuffersToWrite(anyInt())).thenReturn(buffers); when(flushOperation.isFullyFlushed()).thenReturn(false, true); when(flushOperation.getListener()).thenReturn(listener); context.flushChannel(); @@ -187,7 +188,7 @@ public void testPartialFlush() throws IOException { assertTrue(context.readyForFlush()); when(flushOperation.isFullyFlushed()).thenReturn(false); - when(flushOperation.getBuffersToWrite()).thenReturn(new ByteBuffer[] {ByteBuffer.allocate(3)}); + when(flushOperation.getBuffersToWrite(anyInt())).thenReturn(new ByteBuffer[] {ByteBuffer.allocate(3)}); context.flushChannel(); verify(listener, times(0)).accept(null, null); @@ -201,8 +202,8 @@ public void testMultipleWritesPartialFlushes() throws IOException { BiConsumer listener2 = mock(BiConsumer.class); FlushReadyWrite flushOperation1 = mock(FlushReadyWrite.class); FlushReadyWrite flushOperation2 = mock(FlushReadyWrite.class); - when(flushOperation1.getBuffersToWrite()).thenReturn(new ByteBuffer[] {ByteBuffer.allocate(3)}); - when(flushOperation2.getBuffersToWrite()).thenReturn(new ByteBuffer[] {ByteBuffer.allocate(3)}); + when(flushOperation1.getBuffersToWrite(anyInt())).thenReturn(new ByteBuffer[] {ByteBuffer.allocate(3)}); + when(flushOperation2.getBuffersToWrite(anyInt())).thenReturn(new ByteBuffer[] {ByteBuffer.allocate(3)}); when(flushOperation1.getListener()).thenReturn(listener); when(flushOperation2.getListener()).thenReturn(listener2); @@ -237,7 +238,7 @@ public void testWhenIOExceptionThrownListenerIsCalled() throws IOException { assertTrue(context.readyForFlush()); IOException exception = new IOException(); - when(flushOperation.getBuffersToWrite()).thenReturn(buffers); + when(flushOperation.getBuffersToWrite(anyInt())).thenReturn(buffers); when(rawChannel.write(any(ByteBuffer.class))).thenThrow(exception); when(flushOperation.getListener()).thenReturn(listener); expectThrows(IOException.class, () -> context.flushChannel()); @@ -252,7 +253,7 @@ public void testWriteIOExceptionMeansChannelReadyToClose() throws IOException { context.queueWriteOperation(flushOperation); IOException exception = new IOException(); - when(flushOperation.getBuffersToWrite()).thenReturn(buffers); + when(flushOperation.getBuffersToWrite(anyInt())).thenReturn(buffers); when(rawChannel.write(any(ByteBuffer.class))).thenThrow(exception); assertFalse(context.selectorShouldClose()); diff --git a/libs/nio/src/test/java/org/elasticsearch/nio/FlushOperationTests.java b/libs/nio/src/test/java/org/elasticsearch/nio/FlushOperationTests.java index 4f2a320ad583d..73dba34cc30f7 100644 --- a/libs/nio/src/test/java/org/elasticsearch/nio/FlushOperationTests.java +++ b/libs/nio/src/test/java/org/elasticsearch/nio/FlushOperationTests.java @@ -65,29 +65,45 @@ public void testMultipleFlushesWithCompositeBuffer() throws IOException { ByteBuffer[] byteBuffers = writeOp.getBuffersToWrite(); assertEquals(3, byteBuffers.length); assertEquals(5, byteBuffers[0].remaining()); + ByteBuffer[] byteBuffersWithLimit = writeOp.getBuffersToWrite(10); + assertEquals(2, byteBuffersWithLimit.length); + assertEquals(5, byteBuffersWithLimit[0].remaining()); + assertEquals(5, byteBuffersWithLimit[1].remaining()); writeOp.incrementIndex(5); assertFalse(writeOp.isFullyFlushed()); byteBuffers = writeOp.getBuffersToWrite(); assertEquals(2, byteBuffers.length); assertEquals(15, byteBuffers[0].remaining()); + assertEquals(3, byteBuffers[1].remaining()); + byteBuffersWithLimit = writeOp.getBuffersToWrite(10); + assertEquals(1, byteBuffersWithLimit.length); + assertEquals(10, byteBuffersWithLimit[0].remaining()); writeOp.incrementIndex(2); assertFalse(writeOp.isFullyFlushed()); byteBuffers = writeOp.getBuffersToWrite(); assertEquals(2, byteBuffers.length); assertEquals(13, byteBuffers[0].remaining()); + assertEquals(3, byteBuffers[1].remaining()); + byteBuffersWithLimit = writeOp.getBuffersToWrite(10); + assertEquals(1, byteBuffersWithLimit.length); + assertEquals(10, byteBuffersWithLimit[0].remaining()); writeOp.incrementIndex(15); assertFalse(writeOp.isFullyFlushed()); byteBuffers = writeOp.getBuffersToWrite(); assertEquals(1, byteBuffers.length); assertEquals(1, byteBuffers[0].remaining()); + byteBuffersWithLimit = writeOp.getBuffersToWrite(10); + assertEquals(1, byteBuffersWithLimit.length); + assertEquals(1, byteBuffersWithLimit[0].remaining()); writeOp.incrementIndex(1); assertTrue(writeOp.isFullyFlushed()); byteBuffers = writeOp.getBuffersToWrite(); - assertEquals(1, byteBuffers.length); - assertEquals(0, byteBuffers[0].remaining()); + assertEquals(0, byteBuffers.length); + byteBuffersWithLimit = writeOp.getBuffersToWrite(10); + assertEquals(0, byteBuffersWithLimit.length); } } diff --git a/libs/nio/src/test/java/org/elasticsearch/nio/InboundChannelBufferTests.java b/libs/nio/src/test/java/org/elasticsearch/nio/InboundChannelBufferTests.java index 8917bec39f17e..f558043095372 100644 --- a/libs/nio/src/test/java/org/elasticsearch/nio/InboundChannelBufferTests.java +++ b/libs/nio/src/test/java/org/elasticsearch/nio/InboundChannelBufferTests.java @@ -30,8 +30,8 @@ public class InboundChannelBufferTests extends ESTestCase { private static final int PAGE_SIZE = PageCacheRecycler.PAGE_SIZE_IN_BYTES; - private final Supplier defaultPageSupplier = () -> - new InboundChannelBuffer.Page(ByteBuffer.allocate(PageCacheRecycler.BYTE_PAGE_SIZE), () -> { + private final Supplier defaultPageSupplier = () -> + new Page(ByteBuffer.allocate(PageCacheRecycler.BYTE_PAGE_SIZE), () -> { }); public void testNewBufferNoPages() { @@ -126,10 +126,10 @@ public void testIncrementIndexWithOffset() { public void testReleaseClosesPages() { ConcurrentLinkedQueue queue = new ConcurrentLinkedQueue<>(); - Supplier supplier = () -> { + Supplier supplier = () -> { AtomicBoolean atomicBoolean = new AtomicBoolean(); queue.add(atomicBoolean); - return new InboundChannelBuffer.Page(ByteBuffer.allocate(PAGE_SIZE), () -> atomicBoolean.set(true)); + return new Page(ByteBuffer.allocate(PAGE_SIZE), () -> atomicBoolean.set(true)); }; InboundChannelBuffer channelBuffer = new InboundChannelBuffer(supplier); channelBuffer.ensureCapacity(PAGE_SIZE * 4); @@ -153,10 +153,10 @@ public void testReleaseClosesPages() { public void testClose() { ConcurrentLinkedQueue queue = new ConcurrentLinkedQueue<>(); - Supplier supplier = () -> { + Supplier supplier = () -> { AtomicBoolean atomicBoolean = new AtomicBoolean(); queue.add(atomicBoolean); - return new InboundChannelBuffer.Page(ByteBuffer.allocate(PAGE_SIZE), () -> atomicBoolean.set(true)); + return new Page(ByteBuffer.allocate(PAGE_SIZE), () -> atomicBoolean.set(true)); }; InboundChannelBuffer channelBuffer = new InboundChannelBuffer(supplier); channelBuffer.ensureCapacity(PAGE_SIZE * 4); @@ -178,10 +178,10 @@ public void testClose() { public void testCloseRetainedPages() { ConcurrentLinkedQueue queue = new ConcurrentLinkedQueue<>(); - Supplier supplier = () -> { + Supplier supplier = () -> { AtomicBoolean atomicBoolean = new AtomicBoolean(); queue.add(atomicBoolean); - return new InboundChannelBuffer.Page(ByteBuffer.allocate(PAGE_SIZE), () -> atomicBoolean.set(true)); + return new Page(ByteBuffer.allocate(PAGE_SIZE), () -> atomicBoolean.set(true)); }; InboundChannelBuffer channelBuffer = new InboundChannelBuffer(supplier); channelBuffer.ensureCapacity(PAGE_SIZE * 4); @@ -192,7 +192,7 @@ public void testCloseRetainedPages() { assertFalse(closedRef.get()); } - InboundChannelBuffer.Page[] pages = channelBuffer.sliceAndRetainPagesTo(PAGE_SIZE * 2); + Page[] pages = channelBuffer.sliceAndRetainPagesTo(PAGE_SIZE * 2); pages[1].close(); diff --git a/libs/nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java b/libs/nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java index 345c5197c76b8..baf7abac79d1b 100644 --- a/libs/nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java +++ b/libs/nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java @@ -285,7 +285,7 @@ public void testCloseClosesChannelBuffer() throws IOException { when(channel.getRawChannel()).thenReturn(realChannel); when(channel.isOpen()).thenReturn(true); Runnable closer = mock(Runnable.class); - Supplier pageSupplier = () -> new InboundChannelBuffer.Page(ByteBuffer.allocate(1 << 14), closer); + Supplier pageSupplier = () -> new Page(ByteBuffer.allocate(1 << 14), closer); InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier); buffer.ensureCapacity(1); TestSocketChannelContext context = new TestSocketChannelContext(channel, selector, exceptionHandler, readWriteHandler, buffer); diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NettyAdaptor.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NettyAdaptor.java index c221fdf1378d7..96db559e60333 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NettyAdaptor.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NettyAdaptor.java @@ -29,7 +29,7 @@ import io.netty.channel.embedded.EmbeddedChannel; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.nio.FlushOperation; -import org.elasticsearch.nio.InboundChannelBuffer; +import org.elasticsearch.nio.Page; import org.elasticsearch.nio.WriteOperation; import java.nio.ByteBuffer; @@ -97,7 +97,7 @@ public int read(ByteBuffer[] buffers) { return byteBuf.readerIndex() - initialReaderIndex; } - public int read(InboundChannelBuffer.Page[] pages) { + public int read(Page[] pages) { ByteBuf byteBuf = PagedByteBuf.byteBufFromPages(pages); int readableBytes = byteBuf.readableBytes(); nettyChannel.writeInbound(byteBuf); diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java index a5f274c7ccd34..57936ff70c628 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java @@ -43,6 +43,7 @@ import org.elasticsearch.nio.NioGroup; import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioSocketChannel; +import org.elasticsearch.nio.Page; import org.elasticsearch.nio.ServerChannelContext; import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.rest.RestUtils; @@ -205,9 +206,9 @@ private HttpChannelFactory() { @Override public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException { NioHttpChannel httpChannel = new NioHttpChannel(channel); - java.util.function.Supplier pageSupplier = () -> { + java.util.function.Supplier pageSupplier = () -> { Recycler.V bytes = pageCacheRecycler.bytePage(false); - return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close); + return new Page(ByteBuffer.wrap(bytes.v()), bytes::close); }; HttpReadWriteHandler httpReadWritePipeline = new HttpReadWriteHandler(httpChannel,NioHttpServerTransport.this, handlingSettings, corsConfig); diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/PagedByteBuf.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/PagedByteBuf.java index 40f3aeecfbc94..359926d43f9a7 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/PagedByteBuf.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/PagedByteBuf.java @@ -24,7 +24,7 @@ import io.netty.buffer.Unpooled; import io.netty.buffer.UnpooledByteBufAllocator; import io.netty.buffer.UnpooledHeapByteBuf; -import org.elasticsearch.nio.InboundChannelBuffer; +import org.elasticsearch.nio.Page; import java.nio.ByteBuffer; import java.util.ArrayList; @@ -39,7 +39,7 @@ private PagedByteBuf(byte[] array, Runnable releasable) { this.releasable = releasable; } - static ByteBuf byteBufFromPages(InboundChannelBuffer.Page[] pages) { + static ByteBuf byteBufFromPages(Page[] pages) { int componentCount = pages.length; if (componentCount == 0) { return Unpooled.EMPTY_BUFFER; @@ -48,15 +48,15 @@ static ByteBuf byteBufFromPages(InboundChannelBuffer.Page[] pages) { } else { int maxComponents = Math.max(16, componentCount); final List components = new ArrayList<>(componentCount); - for (InboundChannelBuffer.Page page : pages) { + for (Page page : pages) { components.add(byteBufFromPage(page)); } return new CompositeByteBuf(UnpooledByteBufAllocator.DEFAULT, false, maxComponents, components); } } - private static ByteBuf byteBufFromPage(InboundChannelBuffer.Page page) { - ByteBuffer buffer = page.getByteBuffer(); + private static ByteBuf byteBufFromPage(Page page) { + ByteBuffer buffer = page.byteBuffer(); assert buffer.isDirect() == false && buffer.hasArray() : "Must be a heap buffer with an array"; int offset = buffer.arrayOffset() + buffer.position(); PagedByteBuf newByteBuf = new PagedByteBuf(buffer.array(), page::close); diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java index 17ab3a5bf3d8c..30b4b4913128d 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java @@ -36,6 +36,7 @@ import org.elasticsearch.nio.NioGroup; import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioSocketChannel; +import org.elasticsearch.nio.Page; import org.elasticsearch.nio.ServerChannelContext; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TcpTransport; @@ -157,9 +158,9 @@ private TcpChannelFactoryImpl(ProfileSettings profileSettings, boolean isClient) @Override public NioTcpChannel createChannel(NioSelector selector, SocketChannel channel) { NioTcpChannel nioChannel = new NioTcpChannel(isClient == false, profileName, channel); - Supplier pageSupplier = () -> { + Supplier pageSupplier = () -> { Recycler.V bytes = pageCacheRecycler.bytePage(false); - return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close); + return new Page(ByteBuffer.wrap(bytes.v()), bytes::close); }; TcpReadWriteHandler readWriteHandler = new TcpReadWriteHandler(nioChannel, NioTransport.this); Consumer exceptionHandler = (e) -> onException(nioChannel, e); diff --git a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/PagedByteBufTests.java b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/PagedByteBufTests.java index 15bd18ecf6959..df4bf3274b3bc 100644 --- a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/PagedByteBufTests.java +++ b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/PagedByteBufTests.java @@ -20,7 +20,7 @@ package org.elasticsearch.http.nio; import io.netty.buffer.ByteBuf; -import org.elasticsearch.nio.InboundChannelBuffer; +import org.elasticsearch.nio.Page; import org.elasticsearch.test.ESTestCase; import java.nio.ByteBuffer; @@ -32,12 +32,12 @@ public class PagedByteBufTests extends ESTestCase { public void testReleasingPage() { AtomicInteger integer = new AtomicInteger(0); int pageCount = randomInt(10) + 1; - ArrayList pages = new ArrayList<>(); + ArrayList pages = new ArrayList<>(); for (int i = 0; i < pageCount; ++i) { - pages.add(new InboundChannelBuffer.Page(ByteBuffer.allocate(10), integer::incrementAndGet)); + pages.add(new Page(ByteBuffer.allocate(10), integer::incrementAndGet)); } - ByteBuf byteBuf = PagedByteBuf.byteBufFromPages(pages.toArray(new InboundChannelBuffer.Page[0])); + ByteBuf byteBuf = PagedByteBuf.byteBufFromPages(pages.toArray(new Page[0])); assertEquals(0, integer.get()); byteBuf.retain(); @@ -62,9 +62,9 @@ public void testBytesAreUsed() { bytes2[i - 10] = (byte) i; } - InboundChannelBuffer.Page[] pages = new InboundChannelBuffer.Page[2]; - pages[0] = new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes1), () -> {}); - pages[1] = new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes2), () -> {}); + Page[] pages = new Page[2]; + pages[0] = new Page(ByteBuffer.wrap(bytes1), () -> {}); + pages[1] = new Page(ByteBuffer.wrap(bytes2), () -> {}); ByteBuf byteBuf = PagedByteBuf.byteBufFromPages(pages); assertEquals(20, byteBuf.readableBytes()); @@ -73,13 +73,13 @@ public void testBytesAreUsed() { assertEquals((byte) i, byteBuf.getByte(i)); } - InboundChannelBuffer.Page[] pages2 = new InboundChannelBuffer.Page[2]; + Page[] pages2 = new Page[2]; ByteBuffer firstBuffer = ByteBuffer.wrap(bytes1); firstBuffer.position(2); ByteBuffer secondBuffer = ByteBuffer.wrap(bytes2); secondBuffer.limit(8); - pages2[0] = new InboundChannelBuffer.Page(firstBuffer, () -> {}); - pages2[1] = new InboundChannelBuffer.Page(secondBuffer, () -> {}); + pages2[0] = new Page(firstBuffer, () -> {}); + pages2[1] = new Page(secondBuffer, () -> {}); ByteBuf byteBuf2 = PagedByteBuf.byteBufFromPages(pages2); assertEquals(16, byteBuf2.readableBytes()); diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java index db9b1cfe74a71..66ff33213f605 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java @@ -41,6 +41,7 @@ import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioServerSocketChannel; import org.elasticsearch.nio.NioSocketChannel; +import org.elasticsearch.nio.Page; import org.elasticsearch.nio.ServerChannelContext; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.ConnectionProfile; @@ -191,9 +192,9 @@ private MockTcpChannelFactory(boolean isClient, ProfileSettings profileSettings, @Override public MockSocketChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException { MockSocketChannel nioChannel = new MockSocketChannel(isClient == false, profileName, channel); - Supplier pageSupplier = () -> { + Supplier pageSupplier = () -> { Recycler.V bytes = pageCacheRecycler.bytePage(false); - return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close); + return new Page(ByteBuffer.wrap(bytes.v()), bytes::close); }; MockTcpReadWriteHandler readWriteHandler = new MockTcpReadWriteHandler(nioChannel, MockNioTransport.this); BytesChannelContext context = new BytesChannelContext(nioChannel, selector, (e) -> exceptionCaught(nioChannel, e), diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java index b5d5db2166c1f..2c00dd7092950 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java @@ -10,6 +10,7 @@ import org.elasticsearch.nio.FlushOperation; import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.NioSocketChannel; +import org.elasticsearch.nio.Page; import org.elasticsearch.nio.ReadWriteHandler; import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.nio.NioSelector; @@ -17,6 +18,8 @@ import javax.net.ssl.SSLEngine; import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.ClosedChannelException; import java.util.concurrent.TimeUnit; import java.util.function.BiConsumer; import java.util.function.Consumer; @@ -34,6 +37,8 @@ public final class SSLChannelContext extends SocketChannelContext { private static final Runnable DEFAULT_TIMEOUT_CANCELLER = () -> {}; private final SSLDriver sslDriver; + private final SSLOutboundBuffer outboundBuffer; + private FlushOperation encryptedFlush; private Runnable closeTimeoutCanceller = DEFAULT_TIMEOUT_CANCELLER; SSLChannelContext(NioSocketChannel channel, NioSelector selector, Consumer exceptionHandler, SSLDriver sslDriver, @@ -46,6 +51,8 @@ public final class SSLChannelContext extends SocketChannelContext { Predicate allowChannelPredicate) { super(channel, selector, exceptionHandler, readWriteHandler, channelBuffer, allowChannelPredicate); this.sslDriver = sslDriver; + // TODO: When the bytes are actually recycled, we need to test that they are released on context close + this.outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n))); } @Override @@ -72,34 +79,32 @@ public void flushChannel() throws IOException { return; } // If there is currently data in the outbound write buffer, flush the buffer. - if (sslDriver.hasFlushPending()) { + if (pendingChannelFlush()) { // If the data is not completely flushed, exit. We cannot produce new write data until the // existing data has been fully flushed. - flushToChannel(sslDriver.getNetworkWriteBuffer()); - if (sslDriver.hasFlushPending()) { + flushEncryptedOperation(); + if (pendingChannelFlush()) { return; } } // If the driver is ready for application writes, we can attempt to proceed with any queued writes. if (sslDriver.readyForApplicationWrites()) { - FlushOperation currentFlush; - while (sslDriver.hasFlushPending() == false && (currentFlush = getPendingFlush()) != null) { - // If the current operation has been fully consumed (encrypted) we now know that it has been - // sent (as we only get to this point if the write buffer has been fully flushed). - if (currentFlush.isFullyFlushed()) { + FlushOperation unencryptedFlush; + while (pendingChannelFlush() == false && (unencryptedFlush = getPendingFlush()) != null) { + if (unencryptedFlush.isFullyFlushed()) { currentFlushOperationComplete(); } else { try { // Attempt to encrypt application write data. The encrypted data ends up in the // outbound write buffer. - int bytesEncrypted = sslDriver.applicationWrite(currentFlush.getBuffersToWrite()); - if (bytesEncrypted == 0) { + sslDriver.write(unencryptedFlush, outboundBuffer); + if (outboundBuffer.hasEncryptedBytesToFlush() == false) { break; } - currentFlush.incrementIndex(bytesEncrypted); + encryptedFlush = outboundBuffer.buildNetworkFlushOperation(); // Flush the write buffer to the channel - flushToChannel(sslDriver.getNetworkWriteBuffer()); + flushEncryptedOperation(); } catch (IOException e) { currentFlushOperationFailed(e); throw e; @@ -109,23 +114,38 @@ public void flushChannel() throws IOException { } else { // We are not ready for application writes, check if the driver has non-application writes. We // only want to continue producing new writes if the outbound write buffer is fully flushed. - while (sslDriver.hasFlushPending() == false && sslDriver.needsNonApplicationWrite()) { - sslDriver.nonApplicationWrite(); + while (pendingChannelFlush() == false && sslDriver.needsNonApplicationWrite()) { + sslDriver.nonApplicationWrite(outboundBuffer); // If non-application writes were produced, flush the outbound write buffer. - if (sslDriver.hasFlushPending()) { - flushToChannel(sslDriver.getNetworkWriteBuffer()); + if (outboundBuffer.hasEncryptedBytesToFlush()) { + encryptedFlush = outboundBuffer.buildNetworkFlushOperation(); + flushEncryptedOperation(); } } } } + private void flushEncryptedOperation() throws IOException { + try { + flushToChannel(encryptedFlush); + if (encryptedFlush.isFullyFlushed()) { + getSelector().executeListener(encryptedFlush.getListener(), null); + encryptedFlush = null; + } + } catch (IOException e) { + getSelector().executeFailedListener(encryptedFlush.getListener(), e); + encryptedFlush = null; + throw e; + } + } + @Override public boolean readyForFlush() { getSelector().assertOnSelectorThread(); if (sslDriver.readyForApplicationWrites()) { - return sslDriver.hasFlushPending() || super.readyForFlush(); + return pendingChannelFlush() || super.readyForFlush(); } else { - return sslDriver.hasFlushPending() || sslDriver.needsNonApplicationWrite(); + return pendingChannelFlush() || sslDriver.needsNonApplicationWrite(); } } @@ -149,7 +169,7 @@ public int read() throws IOException { @Override public boolean selectorShouldClose() { - return closeNow() || sslDriver.isClosed(); + return closeNow() || (sslDriver.isClosed() && pendingChannelFlush() == false); } @Override @@ -170,7 +190,10 @@ public void closeFromSelector() throws IOException { getSelector().assertOnSelectorThread(); if (channel.isOpen()) { closeTimeoutCanceller.run(); - IOUtils.close(super::closeFromSelector, sslDriver::close); + if (encryptedFlush != null) { + getSelector().executeFailedListener(encryptedFlush.getListener(), new ClosedChannelException()); + } + IOUtils.close(super::closeFromSelector, outboundBuffer::close, sslDriver::close); } } @@ -184,9 +207,14 @@ private void channelCloseTimeout() { getSelector().queueChannelClose(channel); } + private boolean pendingChannelFlush() { + return encryptedFlush != null; + } + private static class CloseNotifyOperation implements WriteOperation { - private static final BiConsumer LISTENER = (v, t) -> {}; + private static final BiConsumer LISTENER = (v, t) -> { + }; private static final Object WRITE_OBJECT = new Object(); private final SocketChannelContext channelContext; diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLDriver.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLDriver.java index 93978bcc6a359..4dbf1d1f03fdf 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLDriver.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLDriver.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.security.transport.nio; +import org.elasticsearch.nio.FlushOperation; import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.utils.ExceptionsHelper; @@ -29,19 +30,17 @@ * the buffer passed as an argument. Otherwise, it will be consumed internally and advance the SSL/TLS close * or handshake process. * - * Producing writes for a channel is more complicated. If there is existing data in the outbound write buffer - * as indicated by {@link #hasFlushPending()}, that data must be written to the channel before more outbound - * data can be produced. If no flushes are pending, {@link #needsNonApplicationWrite()} can be called to - * determine if this driver needs to produce more data to advance the handshake or close process. If that - * method returns true, {@link #nonApplicationWrite()} should be called (and the data produced then flushed - * to the channel) until no further non-application writes are needed. + * Producing writes for a channel is more complicated. The method {@link #needsNonApplicationWrite()} can be + * called to determine if this driver needs to produce more data to advance the handshake or close process. + * If that method returns true, {@link #nonApplicationWrite(SSLOutboundBuffer)} should be called (and the + * data produced then flushed to the channel) until no further non-application writes are needed. * * If no non-application writes are needed, {@link #readyForApplicationWrites()} can be called to determine * if the driver is ready to consume application data. (Note: It is possible that * {@link #readyForApplicationWrites()} and {@link #needsNonApplicationWrite()} can both return false if the * driver is waiting on non-application data from the peer.) If the driver indicates it is ready for - * application writes, {@link #applicationWrite(ByteBuffer[])} can be called. This method will encrypt - * application data and place it in the write buffer for flushing to a channel. + * application writes, {@link #write(FlushOperation, SSLOutboundBuffer)} can be called. This method will + * encrypt flush operation application data and place it in the outbound buffer for flushing to a channel. * * If you are ready to close the channel {@link #initiateClose()} should be called. After that is called, the * driver will start producing non-application writes related to notifying the peer connection that this @@ -50,23 +49,23 @@ */ public class SSLDriver implements AutoCloseable { - private static final ByteBuffer[] EMPTY_BUFFER_ARRAY = new ByteBuffer[0]; + private static final ByteBuffer[] EMPTY_BUFFERS = {ByteBuffer.allocate(0)}; + private static final FlushOperation EMPTY_FLUSH_OPERATION = new FlushOperation(EMPTY_BUFFERS, (r, t) -> {}); private final SSLEngine engine; private final boolean isClientMode; // This should only be accessed by the network thread associated with this channel, so nothing needs to // be volatile. private Mode currentMode = new HandshakeMode(); - private ByteBuffer networkWriteBuffer; private ByteBuffer networkReadBuffer; + private int packetSize; public SSLDriver(SSLEngine engine, boolean isClientMode) { this.engine = engine; this.isClientMode = isClientMode; SSLSession session = engine.getSession(); - this.networkReadBuffer = ByteBuffer.allocate(session.getPacketBufferSize()); - this.networkWriteBuffer = ByteBuffer.allocate(session.getPacketBufferSize()); - this.networkWriteBuffer.position(this.networkWriteBuffer.limit()); + packetSize = session.getPacketBufferSize(); + this.networkReadBuffer = ByteBuffer.allocate(packetSize); } public void init() throws SSLException { @@ -100,18 +99,10 @@ public SSLEngine getSSLEngine() { return engine; } - public boolean hasFlushPending() { - return networkWriteBuffer.hasRemaining(); - } - public boolean isHandshaking() { return currentMode.isHandshake(); } - public ByteBuffer getNetworkWriteBuffer() { - return networkWriteBuffer; - } - public ByteBuffer getNetworkReadBuffer() { return networkReadBuffer; } @@ -134,15 +125,14 @@ public boolean needsNonApplicationWrite() { return currentMode.needsNonApplicationWrite(); } - public int applicationWrite(ByteBuffer[] buffers) throws SSLException { - assert readyForApplicationWrites() : "Should not be called if driver is not ready for application writes"; - return currentMode.write(buffers); + public int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuffer) throws SSLException { + return currentMode.write(applicationBytes, outboundBuffer); } - public void nonApplicationWrite() throws SSLException { + public void nonApplicationWrite(SSLOutboundBuffer outboundBuffer) throws SSLException { assert currentMode.isApplication() == false : "Should not be called if driver is in application mode"; if (currentMode.isApplication() == false) { - currentMode.write(EMPTY_BUFFER_ARRAY); + currentMode.write(EMPTY_FLUSH_OPERATION, outboundBuffer); } else { throw new AssertionError("Attempted to non-application write from invalid mode: " + currentMode.modeName()); } @@ -205,45 +195,36 @@ private SSLEngineResult unwrap(InboundChannelBuffer buffer) throws SSLException } } - private SSLEngineResult wrap(ByteBuffer[] buffers) throws SSLException { - assert hasFlushPending() == false : "Should never called with pending writes"; + private SSLEngineResult wrap(SSLOutboundBuffer outboundBuffer) throws SSLException { + return wrap(outboundBuffer, EMPTY_FLUSH_OPERATION); + } - networkWriteBuffer.clear(); + private SSLEngineResult wrap(SSLOutboundBuffer outboundBuffer, FlushOperation applicationBytes) throws SSLException { + ByteBuffer[] buffers = applicationBytes.getBuffersToWrite(engine.getSession().getApplicationBufferSize()); while (true) { SSLEngineResult result; + ByteBuffer networkBuffer = outboundBuffer.nextWriteBuffer(packetSize); try { - if (buffers.length == 1) { - result = engine.wrap(buffers[0], networkWriteBuffer); - } else { - result = engine.wrap(buffers, networkWriteBuffer); - } + result = engine.wrap(buffers, networkBuffer); } catch (SSLException e) { - networkWriteBuffer.position(networkWriteBuffer.limit()); + outboundBuffer.incrementEncryptedBytes(0); throw e; } + outboundBuffer.incrementEncryptedBytes(result.bytesProduced()); + applicationBytes.incrementIndex(result.bytesConsumed()); switch (result.getStatus()) { case OK: - networkWriteBuffer.flip(); return result; case BUFFER_UNDERFLOW: throw new IllegalStateException("Should not receive BUFFER_UNDERFLOW on WRAP"); case BUFFER_OVERFLOW: - // There is not enough space in the network buffer for an entire SSL packet. Expand the - // buffer if it's smaller than the current session packet size. Otherwise return and wait - // for existing data to be flushed. - int currentCapacity = networkWriteBuffer.capacity(); - ensureNetworkWriteBufferSize(); - if (currentCapacity == networkWriteBuffer.capacity()) { - return result; - } + packetSize = engine.getSession().getPacketBufferSize(); + // There is not enough space in the network buffer for an entire SSL packet. We will + // allocate a buffer with the correct packet size the next time through the loop. break; case CLOSED: - if (result.bytesProduced() > 0) { - networkWriteBuffer.flip(); - } else { - assert false : "WRAP during close processing should produce close message."; - } + assert result.bytesProduced() > 0 : "WRAP during close processing should produce close message."; return result; default: throw new IllegalStateException("Unexpected WRAP result: " + result.getStatus()); @@ -265,23 +246,12 @@ private void ensureApplicationBufferSize(InboundChannelBuffer applicationBuffer) } } - private void ensureNetworkWriteBufferSize() { - networkWriteBuffer = ensureNetBufferSize(networkWriteBuffer); - } - private void ensureNetworkReadBufferSize() { - networkReadBuffer = ensureNetBufferSize(networkReadBuffer); - } - - private ByteBuffer ensureNetBufferSize(ByteBuffer current) { - int networkPacketSize = engine.getSession().getPacketBufferSize(); - if (current.capacity() < networkPacketSize) { - ByteBuffer newBuffer = ByteBuffer.allocate(networkPacketSize); - current.flip(); - newBuffer.put(current); - return newBuffer; - } else { - return current; + packetSize = engine.getSession().getPacketBufferSize(); + if (networkReadBuffer.capacity() < packetSize) { + ByteBuffer newBuffer = ByteBuffer.allocate(packetSize); + networkReadBuffer.flip(); + newBuffer.put(networkReadBuffer); } } @@ -306,7 +276,7 @@ private interface Mode { void read(InboundChannelBuffer buffer) throws SSLException; - int write(ByteBuffer[] buffers) throws SSLException; + int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuffer) throws SSLException; boolean needsNonApplicationWrite(); @@ -329,7 +299,7 @@ private void startHandshake() throws SSLException { if (handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_WRAP) { try { - handshake(); + handshake(null); } catch (SSLException e) { closingInternal(); throw e; @@ -337,7 +307,7 @@ private void startHandshake() throws SSLException { } } - private void handshake() throws SSLException { + private void handshake(SSLOutboundBuffer outboundBuffer) throws SSLException { boolean continueHandshaking = true; while (continueHandshaking) { switch (handshakeStatus) { @@ -346,11 +316,13 @@ private void handshake() throws SSLException { continueHandshaking = false; break; case NEED_WRAP: - if (hasFlushPending() == false) { - handshakeStatus = wrap(EMPTY_BUFFER_ARRAY).getHandshakeStatus(); - } - // If we need NEED_TASK we should run the tasks immediately - if (handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_TASK) { + if (outboundBuffer != null) { + handshakeStatus = wrap(outboundBuffer).getHandshakeStatus(); + // If we need NEED_TASK we should run the tasks immediately + if (handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_TASK) { + continueHandshaking = false; + } + } else { continueHandshaking = false; } break; @@ -379,7 +351,7 @@ public void read(InboundChannelBuffer buffer) throws SSLException { try { SSLEngineResult result = unwrap(buffer); handshakeStatus = result.getHandshakeStatus(); - handshake(); + handshake(null); // If we are done handshaking we should exit the handshake read continueUnwrap = result.bytesConsumed() > 0 && currentMode.isHandshake(); } catch (SSLException e) { @@ -390,9 +362,9 @@ public void read(InboundChannelBuffer buffer) throws SSLException { } @Override - public int write(ByteBuffer[] buffers) throws SSLException { + public int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuffer) throws SSLException { try { - handshake(); + handshake(outboundBuffer); } catch (SSLException e) { closingInternal(); throw e; @@ -445,8 +417,7 @@ private void maybeFinishHandshake() { String message = "Expected to be in handshaking/closed mode. Instead in application mode."; throw new AssertionError(message); } - } else if (hasFlushPending() == false) { - // We only acknowledge that we are done handshaking if there are no bytes that need to be written + } else { if (currentMode.isHandshake()) { currentMode = new ApplicationMode(); } else { @@ -473,10 +444,17 @@ public void read(InboundChannelBuffer buffer) throws SSLException { } @Override - public int write(ByteBuffer[] buffers) throws SSLException { - SSLEngineResult result = wrap(buffers); - maybeRenegotiation(result.getHandshakeStatus()); - return result.bytesConsumed(); + public int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuffer) throws SSLException { + boolean continueWrap = true; + int totalBytesProduced = 0; + while (continueWrap && applicationBytes.isFullyFlushed() == false) { + SSLEngineResult result = wrap(outboundBuffer, applicationBytes); + int bytesProduced = result.bytesProduced(); + totalBytesProduced += bytesProduced; + boolean renegotiationRequested = maybeRenegotiation(result.getHandshakeStatus()); + continueWrap = bytesProduced > 0 && renegotiationRequested == false; + } + return totalBytesProduced; } private boolean maybeRenegotiation(SSLEngineResult.HandshakeStatus newStatus) throws SSLException { @@ -560,18 +538,19 @@ public void read(InboundChannelBuffer buffer) throws SSLException { } @Override - public int write(ByteBuffer[] buffers) throws SSLException { - if (hasFlushPending() == false && engine.isOutboundDone()) { - needToSendClose = false; - // Close inbound if it is still open and we have decided not to wait for response. - if (needToReceiveClose == false && engine.isInboundDone() == false) { - closeInboundAndSwallowPeerDidNotCloseException(); + public int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuffer) throws SSLException { + int bytesProduced = 0; + if (engine.isOutboundDone() == false) { + bytesProduced += wrap(outboundBuffer).bytesProduced(); + if (engine.isOutboundDone()) { + needToSendClose = false; + // Close inbound if it is still open and we have decided not to wait for response. + if (needToReceiveClose == false && engine.isInboundDone() == false) { + closeInboundAndSwallowPeerDidNotCloseException(); + } } - } else { - wrap(EMPTY_BUFFER_ARRAY); - assert hasFlushPending() : "Should have produced close message"; } - return 0; + return bytesProduced; } @Override diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLOutboundBuffer.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLOutboundBuffer.java new file mode 100644 index 0000000000000..2cd28f7d7dc32 --- /dev/null +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLOutboundBuffer.java @@ -0,0 +1,68 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.security.transport.nio; + +import org.elasticsearch.core.internal.io.IOUtils; +import org.elasticsearch.nio.FlushOperation; +import org.elasticsearch.nio.Page; + +import java.nio.ByteBuffer; +import java.util.ArrayDeque; +import java.util.function.IntFunction; + +public class SSLOutboundBuffer implements AutoCloseable { + + private final ArrayDeque pages; + private final IntFunction pageSupplier; + + private Page currentPage; + + SSLOutboundBuffer(IntFunction pageSupplier) { + this.pages = new ArrayDeque<>(); + this.pageSupplier = pageSupplier; + } + + void incrementEncryptedBytes(int encryptedBytesProduced) { + if (encryptedBytesProduced != 0) { + currentPage.byteBuffer().limit(encryptedBytesProduced); + pages.addLast(currentPage); + } + currentPage = null; + } + + ByteBuffer nextWriteBuffer(int networkBufferSize) { + if (currentPage != null) { + // If there is an existing page, close it as it wasn't large enough to accommodate the SSLEngine. + currentPage.close(); + } + + Page newPage = pageSupplier.apply(networkBufferSize); + currentPage = newPage; + return newPage.byteBuffer().duplicate(); + } + + FlushOperation buildNetworkFlushOperation() { + int pageCount = pages.size(); + ByteBuffer[] byteBuffers = new ByteBuffer[pageCount]; + Page[] pagesToClose = new Page[pageCount]; + for (int i = 0; i < pageCount; ++i) { + Page page = pages.removeFirst(); + pagesToClose[i] = page; + byteBuffers[i] = page.byteBuffer(); + } + + return new FlushOperation(byteBuffers, (r, e) -> IOUtils.closeWhileHandlingException(pagesToClose)); + } + + boolean hasEncryptedBytesToFlush() { + return pages.isEmpty() == false; + } + + @Override + public void close() { + IOUtils.closeWhileHandlingException(pages); + } +} diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioHttpServerTransport.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioHttpServerTransport.java index 9e0da2518835d..8ecba16fa460d 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioHttpServerTransport.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioHttpServerTransport.java @@ -22,6 +22,7 @@ import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioSocketChannel; +import org.elasticsearch.nio.Page; import org.elasticsearch.nio.ServerChannelContext; import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.threadpool.ThreadPool; @@ -92,9 +93,9 @@ private SecurityHttpChannelFactory() { @Override public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException { NioHttpChannel httpChannel = new NioHttpChannel(channel); - Supplier pageSupplier = () -> { + Supplier pageSupplier = () -> { Recycler.V bytes = pageCacheRecycler.bytePage(false); - return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close); + return new Page(ByteBuffer.wrap(bytes.v()), bytes::close); }; HttpReadWriteHandler httpHandler = new HttpReadWriteHandler(httpChannel,SecurityNioHttpServerTransport.this, handlingSettings, corsConfig); diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java index dbffeaec58e50..903fec52e9e9b 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java @@ -21,6 +21,7 @@ import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioSocketChannel; +import org.elasticsearch.nio.Page; import org.elasticsearch.nio.ServerChannelContext; import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.threadpool.ThreadPool; @@ -155,9 +156,9 @@ private SecurityTcpChannelFactory(RawChannelFactory rawChannelFactory, String pr @Override public NioTcpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException { NioTcpChannel nioChannel = new NioTcpChannel(isClient == false, profileName, channel); - Supplier pageSupplier = () -> { + Supplier pageSupplier = () -> { Recycler.V bytes = pageCacheRecycler.bytePage(false); - return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close); + return new Page(ByteBuffer.wrap(bytes.v()), bytes::close); }; TcpReadWriteHandler readWriteHandler = new TcpReadWriteHandler(nioChannel, SecurityNioTransport.this); InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier); diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java index 0870124022850..893af2140b9b0 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java @@ -8,6 +8,7 @@ import org.elasticsearch.common.CheckedFunction; import org.elasticsearch.common.util.PageCacheRecycler; import org.elasticsearch.nio.BytesWriteHandler; +import org.elasticsearch.nio.FlushOperation; import org.elasticsearch.nio.FlushReadyWrite; import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.NioSelector; @@ -28,6 +29,7 @@ import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyLong; +import static org.mockito.Matchers.eq; import static org.mockito.Matchers.same; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; @@ -49,7 +51,6 @@ public class SSLChannelContextTests extends ESTestCase { private Consumer exceptionHandler; private SSLDriver sslDriver; private ByteBuffer readBuffer = ByteBuffer.allocate(1 << 14); - private ByteBuffer writeBuffer = ByteBuffer.allocate(1 << 14); private int messageLength; @Before @@ -73,7 +74,6 @@ public void init() { when(selector.isOnCurrentThread()).thenReturn(true); when(selector.getTaskScheduler()).thenReturn(nioTimer); when(sslDriver.getNetworkReadBuffer()).thenReturn(readBuffer); - when(sslDriver.getNetworkWriteBuffer()).thenReturn(writeBuffer); ByteBuffer buffer = ByteBuffer.allocate(1 << 14); when(selector.getIoBuffer()).thenAnswer(invocationOnMock -> { buffer.clear(); @@ -85,7 +85,7 @@ public void testSuccessfulRead() throws IOException { byte[] bytes = createMessage(messageLength); when(rawChannel.read(any(ByteBuffer.class))).thenReturn(bytes.length); - doAnswer(getAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer); + doAnswer(getReadAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer); when(readConsumer.apply(channelBuffer)).thenReturn(messageLength, 0); @@ -100,7 +100,7 @@ public void testMultipleReadsConsumed() throws IOException { byte[] bytes = createMessage(messageLength * 2); when(rawChannel.read(any(ByteBuffer.class))).thenReturn(bytes.length); - doAnswer(getAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer); + doAnswer(getReadAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer); when(readConsumer.apply(channelBuffer)).thenReturn(messageLength, messageLength, 0); @@ -115,7 +115,7 @@ public void testPartialRead() throws IOException { byte[] bytes = createMessage(messageLength); when(rawChannel.read(any(ByteBuffer.class))).thenReturn(bytes.length); - doAnswer(getAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer); + doAnswer(getReadAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer); when(readConsumer.apply(channelBuffer)).thenReturn(0); @@ -173,7 +173,6 @@ public void testSSLDriverClosedOnClose() throws IOException { public void testQueuedWritesAreIgnoredWhenNotReadyForAppWrites() { when(sslDriver.readyForApplicationWrites()).thenReturn(false); - when(sslDriver.hasFlushPending()).thenReturn(false); when(sslDriver.needsNonApplicationWrite()).thenReturn(false); context.queueWriteOperation(mock(FlushReadyWrite.class)); @@ -181,25 +180,25 @@ public void testQueuedWritesAreIgnoredWhenNotReadyForAppWrites() { assertFalse(context.readyForFlush()); } - public void testPendingFlushMeansWriteInterested() { - when(sslDriver.readyForApplicationWrites()).thenReturn(randomBoolean()); - when(sslDriver.hasFlushPending()).thenReturn(true); - when(sslDriver.needsNonApplicationWrite()).thenReturn(false); + public void testPendingEncryptedFlushMeansWriteInterested() throws Exception { + when(sslDriver.readyForApplicationWrites()).thenReturn(false); + when(sslDriver.needsNonApplicationWrite()).thenReturn(true, false); + doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite(any(SSLOutboundBuffer.class)); + // Call will put bytes in buffer to flush + context.flushChannel(); assertTrue(context.readyForFlush()); } public void testNeedsNonAppWritesMeansWriteInterested() { when(sslDriver.readyForApplicationWrites()).thenReturn(false); - when(sslDriver.hasFlushPending()).thenReturn(false); when(sslDriver.needsNonApplicationWrite()).thenReturn(true); assertTrue(context.readyForFlush()); } - public void testNotWritesInterestInAppMode() { + public void testNoNonAppWriteInterestInAppMode() { when(sslDriver.readyForApplicationWrites()).thenReturn(true); - when(sslDriver.hasFlushPending()).thenReturn(false); assertFalse(context.readyForFlush()); @@ -207,66 +206,68 @@ public void testNotWritesInterestInAppMode() { } public void testFirstFlushMustFinishForWriteToContinue() throws Exception { - when(sslDriver.hasFlushPending()).thenReturn(true, true); when(sslDriver.readyForApplicationWrites()).thenReturn(false); + when(sslDriver.needsNonApplicationWrite()).thenReturn(true); + doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite(any(SSLOutboundBuffer.class)); + // First call will put bytes in buffer to flush + context.flushChannel(); + assertTrue(context.readyForFlush()); + // Second call will will not continue generating non-app bytes because they still need to be flushed context.flushChannel(); + assertTrue(context.readyForFlush()); - verify(sslDriver, times(0)).nonApplicationWrite(); + verify(sslDriver, times(1)).nonApplicationWrite(any(SSLOutboundBuffer.class)); } public void testNonAppWrites() throws Exception { - when(sslDriver.hasFlushPending()).thenReturn(false, false, true, false, true); when(sslDriver.needsNonApplicationWrite()).thenReturn(true, true, false); when(sslDriver.readyForApplicationWrites()).thenReturn(false); + doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite(any(SSLOutboundBuffer.class)); + when(rawChannel.write(same(selector.getIoBuffer()))).thenReturn(1); context.flushChannel(); - verify(sslDriver, times(2)).nonApplicationWrite(); + verify(sslDriver, times(2)).nonApplicationWrite(any(SSLOutboundBuffer.class)); verify(rawChannel, times(2)).write(same(selector.getIoBuffer())); } public void testNonAppWritesStopIfBufferNotFullyFlushed() throws Exception { - when(sslDriver.hasFlushPending()).thenReturn(false, false, true, true); - when(sslDriver.needsNonApplicationWrite()).thenReturn(true, true, true, true); + when(sslDriver.needsNonApplicationWrite()).thenReturn(true); when(sslDriver.readyForApplicationWrites()).thenReturn(false); + doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite(any(SSLOutboundBuffer.class)); + when(rawChannel.write(same(selector.getIoBuffer()))).thenReturn(0); context.flushChannel(); - verify(sslDriver, times(1)).nonApplicationWrite(); + verify(sslDriver, times(1)).nonApplicationWrite(any(SSLOutboundBuffer.class)); verify(rawChannel, times(1)).write(same(selector.getIoBuffer())); } public void testQueuedWriteIsFlushedInFlushCall() throws Exception { ByteBuffer[] buffers = {ByteBuffer.allocate(10)}; - FlushReadyWrite flushOperation = mock(FlushReadyWrite.class); + FlushReadyWrite flushOperation = new FlushReadyWrite(context, buffers, listener); context.queueWriteOperation(flushOperation); - when(flushOperation.getBuffersToWrite()).thenReturn(buffers); - when(flushOperation.getListener()).thenReturn(listener); - when(sslDriver.hasFlushPending()).thenReturn(false, false, false, false); when(sslDriver.readyForApplicationWrites()).thenReturn(true); - when(sslDriver.applicationWrite(buffers)).thenReturn(10); - when(flushOperation.isFullyFlushed()).thenReturn(false,true); + doAnswer(getWriteAnswer(10, true)).when(sslDriver).write(eq(flushOperation), any(SSLOutboundBuffer.class)); + + when(rawChannel.write(same(selector.getIoBuffer()))).thenReturn(10); context.flushChannel(); - verify(flushOperation).incrementIndex(10); verify(rawChannel, times(1)).write(same(selector.getIoBuffer())); verify(selector).executeListener(listener, null); assertFalse(context.readyForFlush()); } public void testPartialFlush() throws IOException { - ByteBuffer[] buffers = {ByteBuffer.allocate(10)}; - FlushReadyWrite flushOperation = mock(FlushReadyWrite.class); + ByteBuffer[] buffers = {ByteBuffer.allocate(5)}; + FlushReadyWrite flushOperation = new FlushReadyWrite(context, buffers, listener); context.queueWriteOperation(flushOperation); - when(flushOperation.getBuffersToWrite()).thenReturn(buffers); - when(flushOperation.getListener()).thenReturn(listener); - when(sslDriver.hasFlushPending()).thenReturn(false, false, true); when(sslDriver.readyForApplicationWrites()).thenReturn(true); - when(sslDriver.applicationWrite(buffers)).thenReturn(5); - when(flushOperation.isFullyFlushed()).thenReturn(false, false); + doAnswer(getWriteAnswer(5, true)).when(sslDriver).write(eq(flushOperation), any(SSLOutboundBuffer.class)); + when(rawChannel.write(same(selector.getIoBuffer()))).thenReturn(4); context.flushChannel(); verify(rawChannel, times(1)).write(same(selector.getIoBuffer())); @@ -279,24 +280,16 @@ public void testMultipleWritesPartialFlushes() throws IOException { BiConsumer listener2 = mock(BiConsumer.class); ByteBuffer[] buffers1 = {ByteBuffer.allocate(10)}; ByteBuffer[] buffers2 = {ByteBuffer.allocate(5)}; - FlushReadyWrite flushOperation1 = mock(FlushReadyWrite.class); - FlushReadyWrite flushOperation2 = mock(FlushReadyWrite.class); - when(flushOperation1.getBuffersToWrite()).thenReturn(buffers1); - when(flushOperation2.getBuffersToWrite()).thenReturn(buffers2); - when(flushOperation1.getListener()).thenReturn(listener); - when(flushOperation2.getListener()).thenReturn(listener2); + FlushReadyWrite flushOperation1 = new FlushReadyWrite(context, buffers1, listener); + FlushReadyWrite flushOperation2 = new FlushReadyWrite(context, buffers2, listener2); context.queueWriteOperation(flushOperation1); context.queueWriteOperation(flushOperation2); - when(sslDriver.hasFlushPending()).thenReturn(false, false, false, false, false, true); when(sslDriver.readyForApplicationWrites()).thenReturn(true); - when(sslDriver.applicationWrite(buffers1)).thenReturn(5, 5); - when(sslDriver.applicationWrite(buffers2)).thenReturn(3); - when(flushOperation1.isFullyFlushed()).thenReturn(false, false, true); - when(flushOperation2.isFullyFlushed()).thenReturn(false); + doAnswer(getWriteAnswer(5, true)).when(sslDriver).write(any(FlushOperation.class), any(SSLOutboundBuffer.class)); + when(rawChannel.write(same(selector.getIoBuffer()))).thenReturn(5, 5, 2); context.flushChannel(); - verify(flushOperation1, times(2)).incrementIndex(5); verify(rawChannel, times(3)).write(same(selector.getIoBuffer())); verify(selector).executeListener(listener, null); verify(selector, times(0)).executeListener(listener2, null); @@ -304,29 +297,27 @@ public void testMultipleWritesPartialFlushes() throws IOException { } public void testWhenIOExceptionThrownListenerIsCalled() throws IOException { - ByteBuffer[] buffers = {ByteBuffer.allocate(10)}; - FlushReadyWrite flushOperation = mock(FlushReadyWrite.class); + ByteBuffer[] buffers = {ByteBuffer.allocate(5)}; + FlushReadyWrite flushOperation = new FlushReadyWrite(context, buffers, listener); context.queueWriteOperation(flushOperation); IOException exception = new IOException(); - when(flushOperation.getBuffersToWrite()).thenReturn(buffers); - when(flushOperation.getListener()).thenReturn(listener); - when(sslDriver.hasFlushPending()).thenReturn(false, false); when(sslDriver.readyForApplicationWrites()).thenReturn(true); - when(sslDriver.applicationWrite(buffers)).thenReturn(5); + doAnswer(getWriteAnswer(5, true)).when(sslDriver).write(eq(flushOperation), any(SSLOutboundBuffer.class)); when(rawChannel.write(any(ByteBuffer.class))).thenThrow(exception); - when(flushOperation.isFullyFlushed()).thenReturn(false); expectThrows(IOException.class, () -> context.flushChannel()); - verify(flushOperation).incrementIndex(5); verify(selector).executeFailedListener(listener, exception); assertFalse(context.readyForFlush()); } public void testWriteIOExceptionMeansChannelReadyToClose() throws Exception { - when(sslDriver.hasFlushPending()).thenReturn(true); - when(sslDriver.needsNonApplicationWrite()).thenReturn(true); when(sslDriver.readyForApplicationWrites()).thenReturn(false); + when(sslDriver.needsNonApplicationWrite()).thenReturn(true); + doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite(any(SSLOutboundBuffer.class)); + + context.flushChannel(); + when(rawChannel.write(any(ByteBuffer.class))).thenThrow(new IOException()); assertFalse(context.selectorShouldClose()); @@ -413,7 +404,27 @@ public void testRegisterInitiatesDriver() throws IOException { } } - private Answer getAnswerForBytes(byte[] bytes) { + private Answer getWriteAnswer(int bytesToEncrypt, boolean isApp) { + return invocationOnMock -> { + SSLOutboundBuffer outboundBuffer; + if (isApp) { + outboundBuffer = (SSLOutboundBuffer) invocationOnMock.getArguments()[1]; + } else { + outboundBuffer = (SSLOutboundBuffer) invocationOnMock.getArguments()[0]; + } + ByteBuffer byteBuffer = outboundBuffer.nextWriteBuffer(bytesToEncrypt + 1); + for (int i = 0; i < bytesToEncrypt; ++i) { + byteBuffer.put((byte) i); + } + outboundBuffer.incrementEncryptedBytes(bytesToEncrypt); + if (isApp) { + ((FlushOperation) invocationOnMock.getArguments()[0]).incrementIndex(bytesToEncrypt); + } + return bytesToEncrypt; + }; + } + + private Answer getReadAnswerForBytes(byte[] bytes) { return invocationOnMock -> { InboundChannelBuffer buffer = (InboundChannelBuffer) invocationOnMock.getArguments()[0]; buffer.ensureCapacity(buffer.getIndex() + bytes.length); diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLDriverTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLDriverTests.java index b1d39ddc6ac9f..4b86d3223b061 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLDriverTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLDriverTests.java @@ -6,7 +6,9 @@ package org.elasticsearch.xpack.security.transport.nio; import org.elasticsearch.bootstrap.JavaVersion; +import org.elasticsearch.nio.FlushOperation; import org.elasticsearch.nio.InboundChannelBuffer; +import org.elasticsearch.nio.Page; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ssl.CertParsingUtils; import org.elasticsearch.xpack.core.ssl.PemUtils; @@ -28,8 +30,7 @@ public class SSLDriverTests extends ESTestCase { - private final Supplier pageSupplier = - () -> new InboundChannelBuffer.Page(ByteBuffer.allocate(1 << 14), () -> {}); + private final Supplier pageSupplier = () -> new Page(ByteBuffer.allocate(1 << 14), () -> {}); private InboundChannelBuffer serverBuffer = new InboundChannelBuffer(pageSupplier); private InboundChannelBuffer clientBuffer = new InboundChannelBuffer(pageSupplier); private InboundChannelBuffer genericBuffer = new InboundChannelBuffer(pageSupplier); @@ -141,10 +142,6 @@ public void testHandshakeFailureBecauseProtocolMismatch() throws Exception { boolean expectedMessage = oldExpected.equals(sslException.getMessage()) || jdk11Expected.equals(sslException.getMessage()); assertTrue("Unexpected exception message: " + sslException.getMessage(), expectedMessage); - // In JDK11 we need an non-application write - if (serverDriver.needsNonApplicationWrite()) { - serverDriver.nonApplicationWrite(); - } // Prior to JDK11 we still need to send a close alert if (serverDriver.isClosed() == false) { failedCloseAlert(serverDriver, clientDriver, Arrays.asList("Received fatal alert: protocol_version", @@ -166,10 +163,7 @@ public void testHandshakeFailureBecauseNoCiphers() throws Exception { SSLDriver serverDriver = getDriver(serverEngine, false); expectThrows(SSLException.class, () -> handshake(clientDriver, serverDriver)); - // In JDK11 we need an non-application write - if (serverDriver.needsNonApplicationWrite()) { - serverDriver.nonApplicationWrite(); - } + // Prior to JDK11 we still need to send a close alert if (serverDriver.isClosed() == false) { List messages = Arrays.asList("Received fatal alert: handshake_failure", @@ -192,8 +186,6 @@ public void testCloseDuringHandshakeJDK11() throws Exception { sendHandshakeMessages(clientDriver, serverDriver); sendHandshakeMessages(serverDriver, clientDriver); - sendData(clientDriver, serverDriver); - assertTrue(clientDriver.isHandshaking()); assertTrue(serverDriver.isHandshaking()); @@ -227,8 +219,6 @@ public void testCloseDuringHandshakePreJDK11() throws Exception { sendHandshakeMessages(clientDriver, serverDriver); sendHandshakeMessages(serverDriver, clientDriver); - sendData(clientDriver, serverDriver); - assertTrue(clientDriver.isHandshaking()); assertTrue(serverDriver.isHandshaking()); @@ -306,12 +296,12 @@ private void normalClose(SSLDriver sendDriver, SSLDriver receiveDriver) throws I } private void sendNonApplicationWrites(SSLDriver sendDriver, SSLDriver receiveDriver) throws SSLException { - while (sendDriver.needsNonApplicationWrite() || sendDriver.hasFlushPending()) { - if (sendDriver.hasFlushPending() == false) { - sendDriver.nonApplicationWrite(); - } - if (sendDriver.hasFlushPending()) { - sendData(sendDriver, receiveDriver, true); + SSLOutboundBuffer outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n))); + while (sendDriver.needsNonApplicationWrite() || outboundBuffer.hasEncryptedBytesToFlush()) { + if (outboundBuffer.hasEncryptedBytesToFlush()) { + sendData(outboundBuffer.buildNetworkFlushOperation(), receiveDriver); + } else { + sendDriver.nonApplicationWrite(outboundBuffer); } } } @@ -326,7 +316,7 @@ private void handshake(SSLDriver clientDriver, SSLDriver serverDriver, boolean i serverDriver.init(); } - assertTrue(clientDriver.needsNonApplicationWrite() || clientDriver.hasFlushPending()); + assertTrue(clientDriver.needsNonApplicationWrite()); assertFalse(serverDriver.needsNonApplicationWrite()); sendHandshakeMessages(clientDriver, serverDriver); @@ -350,58 +340,51 @@ private void handshake(SSLDriver clientDriver, SSLDriver serverDriver, boolean i } private void sendHandshakeMessages(SSLDriver sendDriver, SSLDriver receiveDriver) throws IOException { - assertTrue(sendDriver.needsNonApplicationWrite() || sendDriver.hasFlushPending()); + assertTrue(sendDriver.needsNonApplicationWrite()); - while (sendDriver.needsNonApplicationWrite() || sendDriver.hasFlushPending()) { - if (sendDriver.hasFlushPending() == false) { - sendDriver.nonApplicationWrite(); - } - if (sendDriver.isHandshaking()) { - assertTrue(sendDriver.hasFlushPending()); - sendData(sendDriver, receiveDriver); - assertFalse(sendDriver.hasFlushPending()); + SSLOutboundBuffer outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n))); + + while (sendDriver.needsNonApplicationWrite() || outboundBuffer.hasEncryptedBytesToFlush()) { + if (outboundBuffer.hasEncryptedBytesToFlush()) { + sendData(outboundBuffer.buildNetworkFlushOperation(), receiveDriver); receiveDriver.read(genericBuffer); + } else { + sendDriver.nonApplicationWrite(outboundBuffer); } } if (receiveDriver.isHandshaking()) { - assertTrue(receiveDriver.needsNonApplicationWrite() || receiveDriver.hasFlushPending()); + assertTrue(receiveDriver.needsNonApplicationWrite()); } } private void sendAppData(SSLDriver sendDriver, SSLDriver receiveDriver, ByteBuffer[] message) throws IOException { - assertFalse(sendDriver.needsNonApplicationWrite()); int bytesToEncrypt = Arrays.stream(message).mapToInt(Buffer::remaining).sum(); + SSLOutboundBuffer outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n))); + FlushOperation flushOperation = new FlushOperation(message, (r, l) -> {}); int bytesEncrypted = 0; while (bytesToEncrypt > bytesEncrypted) { - bytesEncrypted += sendDriver.applicationWrite(message); - sendData(sendDriver, receiveDriver); + bytesEncrypted += sendDriver.write(flushOperation, outboundBuffer); + sendData(outboundBuffer.buildNetworkFlushOperation(), receiveDriver); } } - private void sendData(SSLDriver sendDriver, SSLDriver receiveDriver) { - sendData(sendDriver, receiveDriver, randomBoolean()); - } - - private void sendData(SSLDriver sendDriver, SSLDriver receiveDriver, boolean partial) { - ByteBuffer writeBuffer = sendDriver.getNetworkWriteBuffer(); + private void sendData(FlushOperation flushOperation, SSLDriver receiveDriver) { ByteBuffer readBuffer = receiveDriver.getNetworkReadBuffer(); - if (partial) { - int initialLimit = writeBuffer.limit(); - int bytesToWrite = writeBuffer.remaining() / (randomInt(2) + 2); - writeBuffer.limit(writeBuffer.position() + bytesToWrite); - readBuffer.put(writeBuffer); - writeBuffer.limit(initialLimit); - assertTrue(sendDriver.hasFlushPending()); - readBuffer.put(writeBuffer); - assertFalse(sendDriver.hasFlushPending()); + ByteBuffer[] writeBuffers = flushOperation.getBuffersToWrite(); + int bytesToEncrypt = Arrays.stream(writeBuffers).mapToInt(Buffer::remaining).sum(); + assert bytesToEncrypt < readBuffer.capacity() : "Flush operation must be less that read buffer"; + assert writeBuffers.length > 0 : "No write buffers"; - } else { + for (ByteBuffer writeBuffer : writeBuffers) { + int written = writeBuffer.remaining(); readBuffer.put(writeBuffer); - assertFalse(sendDriver.hasFlushPending()); + flushOperation.incrementIndex(written); } + + assertTrue(flushOperation.isFullyFlushed()); } private SSLDriver getDriver(SSLEngine engine, boolean isClient) {