diff --git a/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4ChunkedEncodingIT.java b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4ChunkedEncodingIT.java index d220438fc3170..3ab99a4eb343d 100644 --- a/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4ChunkedEncodingIT.java +++ b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4ChunkedEncodingIT.java @@ -9,12 +9,26 @@ package org.elasticsearch.http.netty4; +import io.netty.bootstrap.Bootstrap; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.HttpClientCodec; +import io.netty.handler.codec.http.HttpContent; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpObject; +import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.HttpVersion; + import org.apache.lucene.util.BytesRef; import org.elasticsearch.ESNetty4IntegTestCase; import org.elasticsearch.action.ActionListener; import org.elasticsearch.client.Request; -import org.elasticsearch.client.Response; -import org.elasticsearch.client.ResponseListener; import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.node.DiscoveryNodes; @@ -22,6 +36,7 @@ import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.ReleasableBytesReference; +import org.elasticsearch.common.bytes.ZeroBytesReference; import org.elasticsearch.common.collect.Iterators; import org.elasticsearch.common.io.Streams; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; @@ -34,7 +49,11 @@ import org.elasticsearch.core.AbstractRefCounted; import org.elasticsearch.core.RefCounted; import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; import org.elasticsearch.features.NodeFeature; +import org.elasticsearch.http.HttpInfo; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.rest.BaseRestHandler; @@ -46,25 +65,33 @@ import org.elasticsearch.rest.RestResponse; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.TaskCancelledException; +import org.elasticsearch.transport.netty4.Netty4Utils; +import org.elasticsearch.transport.netty4.Netty4WriteThrottlingHandler; +import org.elasticsearch.transport.netty4.NettyAllocator; import java.io.IOException; import java.io.InputStreamReader; import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; +import java.util.Collections; import java.util.Iterator; import java.util.List; -import java.util.concurrent.CancellationException; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.function.Predicate; import java.util.function.Supplier; import static org.elasticsearch.rest.RestRequest.Method.GET; import static org.elasticsearch.rest.RestResponse.TEXT_CONTENT_TYPE; import static org.hamcrest.Matchers.containsString; -import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; public class Netty4ChunkedEncodingIT extends ESNetty4IntegTestCase { + private static final Logger logger = LogManager.getLogger(Netty4ChunkedEncodingIT.class); + @Override protected Collection> nodePlugins() { return CollectionUtils.concatLists(List.of(YieldsChunksPlugin.class), super.nodePlugins()); @@ -106,25 +133,85 @@ private static void getAndCheckBodyContents(String route, String expectedBody) t } public void testClientCancellation() { - try (var ignored = withResourceTracker()) { - final var cancellable = getRestClient().performRequestAsync( - new Request("GET", YieldsChunksPlugin.INFINITE_ROUTE), - new ResponseListener() { - @Override - public void onSuccess(Response response) { - fail("should not complete"); - } + final var releasables = new ArrayList(4); + + try { + releasables.add(withResourceTracker()); + + final var eventLoopGroup = new NioEventLoopGroup(1); + releasables.add(() -> eventLoopGroup.shutdownGracefully(0, 0, TimeUnit.SECONDS).awaitUninterruptibly()); + + final var gracefulClose = randomBoolean(); + final var chunkSizeBytes = between(1, Netty4WriteThrottlingHandler.MAX_BYTES_PER_WRITE * 2); // sometimes write in slices + final var closeAfterBytes = between(0, chunkSizeBytes * 5); + final var chunkDelayMillis = randomBoolean() ? 0 : between(10, 100); + + final var clientBootstrap = new Bootstrap().channel(NettyAllocator.getChannelType()) + .option(ChannelOption.ALLOCATOR, NettyAllocator.getAllocator()) + .group(eventLoopGroup) + .handler(new ChannelInitializer() { @Override - public void onFailure(Exception exception) { - assertThat(exception, instanceOf(CancellationException.class)); + protected void initChannel(SocketChannel ch) { + if (gracefulClose == false) { + ch.config().setOption(ChannelOption.SO_LINGER, 0); // RST on close + } + + ch.pipeline().addLast(new HttpClientCodec()).addLast(new SimpleChannelInboundHandler() { + + private long bytesReceived; + + @Override + protected void channelRead0(ChannelHandlerContext ctx, HttpObject msg) { + if (msg instanceof HttpContent hc) { + bytesReceived += hc.content().readableBytes(); + if (bytesReceived > closeAfterBytes) { + ctx.close(); + } + } else { + assertEquals(200, asInstanceOf(HttpResponse.class, msg).status().code()); + assertEquals(0L, bytesReceived); + } + } + }); } - } + }); + + final var channel = clientBootstrap.connect( + randomFrom( + clusterAdmin().prepareNodesInfo() + .get() + .getNodes() + .stream() + .flatMap(n -> Arrays.stream(n.getInfo(HttpInfo.class).address().boundAddresses())) + .toList() + ).address() + ).syncUninterruptibly().channel(); + releasables.add(() -> channel.close().syncUninterruptibly()); + + logger.info("--> using client channel [{}] with gracefulClose={}", channel, gracefulClose); + + final var request = new DefaultFullHttpRequest( + HttpVersion.HTTP_1_1, + HttpMethod.GET, + Strings.format( + "%s?%s=%d&%s=%d", + YieldsChunksPlugin.INFINITE_ROUTE, + YieldsChunksPlugin.INFINITE_ROUTE_SIZE_BYTES_PARAM, + chunkSizeBytes, + YieldsChunksPlugin.INFINITE_ROUTE_DELAY_MILLIS_PARAM, + chunkDelayMillis + ) ); - if (randomBoolean()) { - safeSleep(scaledRandomIntBetween(10, 500)); - } - cancellable.cancel(); + request.headers().set(HttpHeaderNames.HOST, "localhost"); + channel.writeAndFlush(request); + + logger.info("--> client waiting"); + safeAwait(l -> Netty4Utils.addListener(channel.closeFuture(), ignoredFuture -> l.onResponse(null))); + logger.info("--> client channel closed"); + } finally { + Collections.reverse(releasables); + Releasables.close(releasables); } } @@ -148,6 +235,8 @@ public static class YieldsChunksPlugin extends Plugin implements ActionPlugin { static final String CHUNKS_ROUTE = "/_test/yields_chunks"; static final String EMPTY_ROUTE = "/_test/yields_only_empty_chunks"; static final String INFINITE_ROUTE = "/_test/yields_infinite_chunks"; + static final String INFINITE_ROUTE_SIZE_BYTES_PARAM = "chunk_size_bytes"; + static final String INFINITE_ROUTE_DELAY_MILLIS_PARAM = "chunk_delay_millis"; private static Iterator emptyChunks() { return Iterators.forRange(0, between(0, 2), i -> BytesArray.EMPTY); @@ -225,9 +314,12 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) { + final var chunkSize = request.paramAsInt(INFINITE_ROUTE_SIZE_BYTES_PARAM, -1); + assertThat(chunkSize, greaterThanOrEqualTo(1)); + final var chunk = new ZeroBytesReference(chunkSize); + final var chunkDelayMillis = request.paramAsInt(INFINITE_ROUTE_DELAY_MILLIS_PARAM, -1); + assertThat(chunkDelayMillis, greaterThanOrEqualTo(0)); return channel -> sendChunksResponse(channel, new Iterator<>() { - private static final BytesReference CHUNK = new BytesArray("CHUNK\n"); - @Override public boolean hasNext() { return true; @@ -235,7 +327,11 @@ public boolean hasNext() { @Override public BytesReference next() { - return CHUNK; + logger.info("--> yielding chunk of size [{}]", chunkSize); + if (chunkDelayMillis > 0) { + safeSleep(chunkDelayMillis); + } + return chunk; } }); }