diff --git a/modules/reindex/src/main/java/org/elasticsearch/reindex/BulkByPaginatedSearchParallelizationHelper.java b/modules/reindex/src/main/java/org/elasticsearch/reindex/BulkByPaginatedSearchParallelizationHelper.java index 62515cd334dc2..5e5866ed8ace0 100644 --- a/modules/reindex/src/main/java/org/elasticsearch/reindex/BulkByPaginatedSearchParallelizationHelper.java +++ b/modules/reindex/src/main/java/org/elasticsearch/reindex/BulkByPaginatedSearchParallelizationHelper.java @@ -9,6 +9,7 @@ package org.elasticsearch.reindex; +import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionType; import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsRequest; @@ -17,6 +18,7 @@ import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.core.Nullable; import org.elasticsearch.index.Index; import org.elasticsearch.index.mapper.IdFieldMapper; import org.elasticsearch.index.reindex.AbstractBulkByScrollRequest; @@ -35,6 +37,7 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.function.Consumer; import java.util.stream.Collectors; import static org.elasticsearch.index.reindex.AbstractBulkByScrollRequest.AUTO_SLICES; @@ -73,7 +76,9 @@ static > void startSlicedAc task, request, client, - listener.delegateFailure((l, v) -> executeSlicedAction(task, request, action, l, client, node, workerAction)) + listener.delegateFailure( + (l, v) -> executeSlicedAction(task, request, action, l, client, node, null, version -> workerAction.run()) + ) ); } @@ -81,11 +86,14 @@ static > void startSlicedAc * Takes an action and a {@link BulkByScrollTask} and runs it with regard to whether this task is a * leader or worker. * - * If this task is a worker, the worker action in the given {@link Runnable} will be started on the local - * node. 
If the task is a leader (i.e. the number of slices is more than 1), then a subrequest will be - * created for each slice and sent. + * If this task is a worker, the worker action is invoked with the given {@code remoteVersion} (may be null + * for local reindex). If the task is a leader (i.e. the number of slices is more than 1), then a subrequest + * will be created for each slice and sent. * * This method can only be called after the task state is initialized {@link #initTaskState}. + * + * @param remoteVersion the version of the remote cluster when reindexing from remote, or null for local reindex + * @param workerAction invoked when this task is a worker, with the remote version (or null) */ static > void executeSlicedAction( BulkByScrollTask task, @@ -94,12 +102,13 @@ static > void executeSliced ActionListener listener, Client client, DiscoveryNode node, - Runnable workerAction + @Nullable Version remoteVersion, + Consumer workerAction ) { if (task.isLeader()) { sendSubRequests(client, action, node.getId(), task, request, listener); } else if (task.isWorker()) { - workerAction.run(); + workerAction.accept(remoteVersion); } else { throw new AssertionError("Task should have been initialized at this point."); } diff --git a/modules/reindex/src/main/java/org/elasticsearch/reindex/Reindexer.java b/modules/reindex/src/main/java/org/elasticsearch/reindex/Reindexer.java index 3a16870ef6f1d..b047963afcfed 100644 --- a/modules/reindex/src/main/java/org/elasticsearch/reindex/Reindexer.java +++ b/modules/reindex/src/main/java/org/elasticsearch/reindex/Reindexer.java @@ -19,6 +19,7 @@ import org.apache.http.message.BasicHeader; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListenerResponseHandler; import org.elasticsearch.action.DocWriteRequest; @@ -53,11 +54,13 @@ import 
org.elasticsearch.index.reindex.PaginatedHitSource; import org.elasticsearch.index.reindex.ReindexAction; import org.elasticsearch.index.reindex.ReindexRequest; +import org.elasticsearch.index.reindex.RejectAwareActionListener; import org.elasticsearch.index.reindex.RemoteInfo; import org.elasticsearch.index.reindex.ResumeBulkByScrollRequest; import org.elasticsearch.index.reindex.ResumeBulkByScrollResponse; import org.elasticsearch.index.reindex.ResumeReindexAction; import org.elasticsearch.index.reindex.WorkerBulkByScrollTaskState; +import org.elasticsearch.reindex.remote.RemoteReindexingUtils; import org.elasticsearch.reindex.remote.RemoteScrollablePaginatedHitSource; import org.elasticsearch.script.CtxMap; import org.elasticsearch.script.ReindexMetadata; @@ -83,12 +86,15 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiFunction; +import java.util.function.Consumer; import java.util.function.LongSupplier; import java.util.function.Supplier; import static java.util.Collections.emptyList; import static java.util.Collections.synchronizedList; +import static org.elasticsearch.common.BackoffPolicy.exponentialBackoff; import static org.elasticsearch.index.VersionType.INTERNAL; +import static org.elasticsearch.reindex.ReindexPlugin.REINDEX_PIT_SEARCH_ENABLED; public class Reindexer { @@ -143,33 +149,113 @@ public void execute(BulkByScrollTask task, ReindexRequest request, Client bulkCl // for update-by-query and delete-by-query final ActionListener listenerWithRelocations = listenerWithRelocations(task, request, listener); - BulkByPaginatedSearchParallelizationHelper.executeSlicedAction( - task, - request, - ReindexAction.INSTANCE, - listenerWithRelocations, - client, - clusterService.localNode(), - () -> { - ParentTaskAssigningClient assigningClient = new ParentTaskAssigningClient(client, clusterService.localNode(), task); - ParentTaskAssigningClient assigningBulkClient = new 
ParentTaskAssigningClient(bulkClient, clusterService.localNode(), task); - AsyncIndexBySearchAction searchAction = new AsyncIndexBySearchAction( - task, - logger, - assigningClient, - assigningBulkClient, - threadPool, - scriptService, - projectResolver.getProjectState(clusterService.state()), - reindexSslConfig, - request, - workerListenerWithRelocationAndMetrics(listenerWithRelocations, startTime, request.getRemoteInfo() != null) + Consumer workerAction = remoteVersion -> { + ParentTaskAssigningClient assigningClient = new ParentTaskAssigningClient(client, clusterService.localNode(), task); + ParentTaskAssigningClient assigningBulkClient = new ParentTaskAssigningClient(bulkClient, clusterService.localNode(), task); + AsyncIndexBySearchAction searchAction = new AsyncIndexBySearchAction( + task, + logger, + assigningClient, + assigningBulkClient, + threadPool, + scriptService, + projectResolver.getProjectState(clusterService.state()), + reindexSslConfig, + request, + workerListenerWithRelocationAndMetrics(listenerWithRelocations, startTime, request.getRemoteInfo() != null), + remoteVersion + ); + searchAction.start(); + }; + + /** + * If this is a request to reindex from remote, then we need to determine the remote version prior to execution + * NB {@link ReindexRequest} forbids remote requests and slices > 1, so we're guaranteed to be running on the only slice + */ + if (REINDEX_PIT_SEARCH_ENABLED && request.getRemoteInfo() != null) { + lookupRemoteVersionAndExecute(task, request, listenerWithRelocations, workerAction); + } else { + BulkByPaginatedSearchParallelizationHelper.executeSlicedAction( + task, + request, + ReindexAction.INSTANCE, + listenerWithRelocations, + client, + clusterService.localNode(), + null, + workerAction + ); + } + } + + /** + * Looks up the remote cluster version when reindexing from a remote source, then runs the sliced action with that version. 
+ * The RestClient used for the lookup is closed after the callback; closing must happen on a thread other than the + * RestClient's own thread pool to avoid shutdown failures. + */ + private void lookupRemoteVersionAndExecute( + BulkByScrollTask task, + ReindexRequest request, + ActionListener listenerWithRelocations, + Consumer workerAction + ) { + RemoteInfo remoteInfo = request.getRemoteInfo(); + assert reindexSslConfig != null : "Reindex ssl config must be set"; + RestClient restClient = buildRestClient(remoteInfo, reindexSslConfig, task.getId(), synchronizedList(new ArrayList<>())); + RejectAwareActionListener rejectAwareListener = new RejectAwareActionListener<>() { + @Override + public void onResponse(Version version) { + closeRestClientAndRun( + restClient, + () -> BulkByPaginatedSearchParallelizationHelper.executeSlicedAction( + task, + request, + ReindexAction.INSTANCE, + listenerWithRelocations, + client, + clusterService.localNode(), + version, + workerAction + ) ); - searchAction.start(); } + + @Override + public void onFailure(Exception e) { + closeRestClientAndRun(restClient, () -> listenerWithRelocations.onFailure(e)); + } + + @Override + public void onRejection(Exception e) { + closeRestClientAndRun(restClient, () -> listenerWithRelocations.onFailure(e)); + } + }; + RemoteReindexingUtils.lookupRemoteVersionWithRetries( + logger, + exponentialBackoff(request.getRetryBackoffInitialTime(), request.getMaxRetries()), + threadPool, + restClient, + // TODO - Do we want to pass in a countRetry runnable here to count the number of times we retry? + // https://github.com/elastic/elasticsearch-team/issues/2382 + rejectAwareListener ); } + /** + * Closes the RestClient on the generic thread pool (to avoid closing from the client's own thread), then runs the given action. 
+ */ + private void closeRestClientAndRun(RestClient restClient, Runnable onCompletion) { + threadPool.generic().submit(() -> { + try { + restClient.close(); + } catch (IOException e) { + logger.warn("Failed to close RestClient after version lookup", e); + } finally { + onCompletion.run(); + } + }); + } + /** Wraps the listener with metrics tracking and relocation handling (if applicable). Visible for testing. */ ActionListener workerListenerWithRelocationAndMetrics( ActionListener potentiallyWrappedRelocationListener, @@ -413,6 +499,11 @@ static class AsyncIndexBySearchAction extends AbstractAsyncBulkByScrollAction createdThreads = emptyList(); + /** + * Version of the remote cluster when reindexing from remote, or null when reindexing locally. + */ + private final Version remoteVersion; + AsyncIndexBySearchAction( BulkByScrollTask task, Logger logger, @@ -423,7 +514,8 @@ static class AsyncIndexBySearchAction extends AbstractAsyncBulkByScrollAction listener + ActionListener listener, + @Nullable Version remoteVersion ) { super( task, @@ -444,6 +536,7 @@ static class AsyncIndexBySearchAction extends AbstractAsyncBulkByScrollAction listener, ThreadPool threadPool, RestClient client) { + execute(new Request("GET", "/"), MAIN_ACTION_PARSER, listener, threadPool, client); + } + + /** + * Looks up the remote cluster version with retries on rejection (e.g. 429 Too Many Requests). + * Matches the retry behavior used by {@link RemoteScrollablePaginatedHitSource} when it looks up the version. 
+ * + * @param logger logger for retry messages + * @param backoffPolicy policy for delay between retries + * @param threadPool thread pool for scheduling retries + * @param client REST client for the remote cluster + * @param delegate receives the version on success or failure after all retries exhausted + */ + public static void lookupRemoteVersionWithRetries( + Logger logger, + BackoffPolicy backoffPolicy, + ThreadPool threadPool, + RestClient client, + RejectAwareActionListener delegate + ) { + RetryListener retryListener = new RetryListener<>(logger, threadPool, backoffPolicy, listener -> { + lookupRemoteVersion(listener, threadPool, client); + }, delegate); + lookupRemoteVersion(retryListener, threadPool, client); + } + + /** + * Performs an async HTTP request to the remote cluster, parses the response, and notifies the listener. + * Preserves thread context across the async callback. On 429 (Too Many Requests), invokes + * {@link RejectAwareActionListener#onRejection} so callers can retry; other failures invoke + * {@link RejectAwareActionListener#onFailure}. 
+ * + * @param type of the parsed response + * @param request HTTP request to perform + * @param parser function to parse the response body into type T + * @param listener receives the parsed result, or failure/rejection + * @param threadPool thread pool for preserving thread context + * @param client REST client for the remote cluster + */ + static void execute( + Request request, + BiFunction parser, + RejectAwareActionListener listener, + ThreadPool threadPool, + RestClient client + ) { + // Preserve the thread context so headers survive after the call + Supplier contextSupplier = threadPool.getThreadContext().newRestorableContext(true); + try { + client.performRequestAsync(request, new ResponseListener() { + @Override + public void onSuccess(Response response) { + // Restore the thread context to get the precious headers + try (ThreadContext.StoredContext ctx = contextSupplier.get()) { + assert ctx != null; // eliminates compiler warning + T parsedResponse; + try { + HttpEntity responseEntity = response.getEntity(); + InputStream content = responseEntity.getContent(); + XContentType xContentType = null; + if (responseEntity.getContentType() != null) { + final String mimeType = ContentType.parse(responseEntity.getContentType().getValue()).getMimeType(); + xContentType = XContentType.fromMediaType(mimeType); + } + if (xContentType == null) { + try { + throw new ElasticsearchException( + "Response didn't include Content-Type: " + bodyMessage(response.getEntity()) + ); + } catch (IOException e) { + ElasticsearchException ee = new ElasticsearchException("Error extracting body from response"); + ee.addSuppressed(e); + throw ee; + } + } + // EMPTY is safe here because we don't call namedObject + try ( + XContentParser xContentParser = xContentType.xContent() + .createParser( + XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE), + content + ) + ) { + parsedResponse = parser.apply(xContentParser, xContentType); + } catch 
(XContentParseException e) { + /* Because we're streaming the response we can't get a copy of it here. The best we can do is hint that it + * is totally wrong and we're probably not talking to Elasticsearch. */ + throw new ElasticsearchException( + "Error parsing the response, remote is likely not an Elasticsearch instance", + e + ); + } + } catch (IOException e) { + throw new ElasticsearchException( + "Error deserializing response, remote is likely not an Elasticsearch instance", + e + ); + } + listener.onResponse(parsedResponse); + } + } + + @Override + public void onFailure(Exception e) { + try (ThreadContext.StoredContext ctx = contextSupplier.get()) { + assert ctx != null; // eliminates compiler warning + if (e instanceof ResponseException re) { + int statusCode = re.getResponse().getStatusLine().getStatusCode(); + e = wrapExceptionToPreserveStatus(statusCode, re.getResponse().getEntity(), re); + if (RestStatus.TOO_MANY_REQUESTS.getStatus() == statusCode) { + listener.onRejection(e); + return; + } + } else if (e instanceof ContentTooLongException) { + e = new IllegalArgumentException( + "Remote responded with a chunk that was too large. Use a smaller batch size.", + e + ); + } + listener.onFailure(e); + } + } + }); + } catch (Exception e) { + listener.onFailure(e); + } + } + + /** + * Wrap the ResponseException in an exception that'll preserve its status code if possible, so we can send it back to the user. We might + * not have a constant for the status code, so in that case, we just use 500 instead. We also extract make sure to include the response + * body in the message so the user can figure out *why* the remote Elasticsearch service threw the error back to us. 
+ */ + static ElasticsearchStatusException wrapExceptionToPreserveStatus(int statusCode, @Nullable HttpEntity entity, Exception cause) { + RestStatus status = RestStatus.fromCode(statusCode); + String messagePrefix = ""; + if (status == null) { + messagePrefix = "Couldn't extract status [" + statusCode + "]. "; + status = RestStatus.INTERNAL_SERVER_ERROR; + } + try { + return new ElasticsearchStatusException(messagePrefix + bodyMessage(entity), status, cause); + } catch (IOException ioe) { + ElasticsearchStatusException e = new ElasticsearchStatusException(messagePrefix + "Failed to extract body.", status, cause); + e.addSuppressed(ioe); + return e; + } + } + + /** + * Extracts a readable string from an HTTP entity for use in error messages. + * + * @param entity HTTP entity, or null + * @return "No error body." if entity is null, otherwise "body=" + entity content + * @throws IOException if reading the entity fails + */ + static String bodyMessage(@Nullable HttpEntity entity) throws IOException { + if (entity == null) { + return "No error body."; + } else { + return "body=" + EntityUtils.toString(entity); + } + } +} diff --git a/modules/reindex/src/main/java/org/elasticsearch/reindex/remote/RemoteScrollablePaginatedHitSource.java b/modules/reindex/src/main/java/org/elasticsearch/reindex/remote/RemoteScrollablePaginatedHitSource.java index 1b3b6e89d7ee2..cf5dd07e111ec 100644 --- a/modules/reindex/src/main/java/org/elasticsearch/reindex/remote/RemoteScrollablePaginatedHitSource.java +++ b/modules/reindex/src/main/java/org/elasticsearch/reindex/remote/RemoteScrollablePaginatedHitSource.java @@ -9,25 +9,16 @@ package org.elasticsearch.reindex.remote; -import org.apache.http.ContentTooLongException; -import org.apache.http.HttpEntity; -import org.apache.http.entity.ContentType; -import org.apache.http.util.EntityUtils; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.util.Supplier; -import org.elasticsearch.ElasticsearchException; -import 
org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.Version; import org.elasticsearch.action.search.SearchRequest; -import org.elasticsearch.client.Request; import org.elasticsearch.client.ResponseException; import org.elasticsearch.client.ResponseListener; import org.elasticsearch.client.RestClient; import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.BackoffPolicy; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.util.concurrent.ThreadContext; -import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.reindex.PaginatedHitSource; @@ -35,23 +26,17 @@ import org.elasticsearch.index.reindex.RemoteInfo; import org.elasticsearch.index.reindex.ResumeInfo.ScrollWorkerResumeInfo; import org.elasticsearch.index.reindex.ResumeInfo.WorkerResumeInfo; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xcontent.XContentParseException; -import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xcontent.XContentType; import java.io.IOException; -import java.io.InputStream; import java.util.Optional; -import java.util.function.BiFunction; import java.util.function.Consumer; import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.core.TimeValue.timeValueMillis; import static org.elasticsearch.core.TimeValue.timeValueNanos; -import static org.elasticsearch.reindex.remote.RemoteResponseParsers.MAIN_ACTION_PARSER; +import static org.elasticsearch.reindex.remote.RemoteReindexingUtils.execute; +import static org.elasticsearch.reindex.remote.RemoteReindexingUtils.lookupRemoteVersion; import static org.elasticsearch.reindex.remote.RemoteResponseParsers.RESPONSE_PARSER; /** @@ -78,23 +63,51 @@ public 
RemoteScrollablePaginatedHitSource( RestClient client, RemoteInfo remoteInfo, SearchRequest searchRequest + ) { + this(logger, backoffPolicy, threadPool, countSearchRetry, onResponse, fail, client, remoteInfo, searchRequest, null); + } + + public RemoteScrollablePaginatedHitSource( + Logger logger, + BackoffPolicy backoffPolicy, + ThreadPool threadPool, + Runnable countSearchRetry, + Consumer onResponse, + Consumer fail, + RestClient client, + RemoteInfo remoteInfo, + SearchRequest searchRequest, + @Nullable Version initialRemoteVersion ) { super(logger, backoffPolicy, threadPool, countSearchRetry, onResponse, fail); this.remote = remoteInfo; this.searchRequest = searchRequest; this.client = client; + this.remoteVersion = initialRemoteVersion; } @Override protected void doStart(RejectAwareActionListener searchListener) { - lookupRemoteVersion(RejectAwareActionListener.withResponseHandler(searchListener, version -> { - remoteVersion = version; + if (remoteVersion != null) { execute( RemoteRequestBuilders.initialSearch(searchRequest, remote.getQuery(), remoteVersion), RESPONSE_PARSER, - RejectAwareActionListener.withResponseHandler(searchListener, r -> onStartResponse(searchListener, r)) + RejectAwareActionListener.withResponseHandler(searchListener, r -> onStartResponse(searchListener, r)), + threadPool, + client ); - })); + } else { + lookupRemoteVersion(RejectAwareActionListener.withResponseHandler(searchListener, version -> { + remoteVersion = version; + execute( + RemoteRequestBuilders.initialSearch(searchRequest, remote.getQuery(), remoteVersion), + RESPONSE_PARSER, + RejectAwareActionListener.withResponseHandler(searchListener, r -> onStartResponse(searchListener, r)), + threadPool, + client + ); + }), threadPool, client); + } } @Override @@ -110,11 +123,8 @@ public Optional remoteVersion() { return Optional.ofNullable(remoteVersion); } - void lookupRemoteVersion(RejectAwareActionListener listener) { - execute(new Request("GET", ""), MAIN_ACTION_PARSER, 
listener); - } - - private void onStartResponse(RejectAwareActionListener searchListener, Response response) { + // Exposed for testing + void onStartResponse(RejectAwareActionListener searchListener, Response response) { if (Strings.hasLength(response.getScrollId()) && response.getHits().isEmpty()) { logger.debug("First response looks like a scan response. Jumping right to the second. scroll=[{}]", response.getScrollId()); doStartNextScroll(response.getScrollId(), timeValueMillis(0), searchListener); @@ -126,7 +136,7 @@ private void onStartResponse(RejectAwareActionListener searchListener, @Override protected void doStartNextScroll(String scrollId, TimeValue extraKeepAlive, RejectAwareActionListener searchListener) { TimeValue keepAlive = timeValueNanos(searchRequest.scroll().nanos() + extraKeepAlive.nanos()); - execute(RemoteRequestBuilders.scroll(scrollId, keepAlive, remoteVersion), RESPONSE_PARSER, searchListener); + execute(RemoteRequestBuilders.scroll(scrollId, keepAlive, remoteVersion), RESPONSE_PARSER, searchListener, threadPool, client); } @Override @@ -180,120 +190,4 @@ protected void cleanup(Runnable onCompletion) { } }); } - - private void execute( - Request request, - BiFunction parser, - RejectAwareActionListener listener - ) { - // Preserve the thread context so headers survive after the call - java.util.function.Supplier contextSupplier = threadPool.getThreadContext().newRestorableContext(true); - try { - client.performRequestAsync(request, new ResponseListener() { - @Override - public void onSuccess(org.elasticsearch.client.Response response) { - // Restore the thread context to get the precious headers - try (ThreadContext.StoredContext ctx = contextSupplier.get()) { - assert ctx != null; // eliminates compiler warning - T parsedResponse; - try { - HttpEntity responseEntity = response.getEntity(); - InputStream content = responseEntity.getContent(); - XContentType xContentType = null; - if (responseEntity.getContentType() != null) { - final String 
mimeType = ContentType.parse(responseEntity.getContentType().getValue()).getMimeType(); - xContentType = XContentType.fromMediaType(mimeType); - } - if (xContentType == null) { - try { - throw new ElasticsearchException( - "Response didn't include Content-Type: " + bodyMessage(response.getEntity()) - ); - } catch (IOException e) { - ElasticsearchException ee = new ElasticsearchException("Error extracting body from response"); - ee.addSuppressed(e); - throw ee; - } - } - // EMPTY is safe here because we don't call namedObject - try ( - XContentParser xContentParser = xContentType.xContent() - .createParser( - XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE), - content - ) - ) { - parsedResponse = parser.apply(xContentParser, xContentType); - } catch (XContentParseException e) { - /* Because we're streaming the response we can't get a copy of it here. The best we can do is hint that it - * is totally wrong and we're probably not talking to Elasticsearch. */ - throw new ElasticsearchException( - "Error parsing the response, remote is likely not an Elasticsearch instance", - e - ); - } - } catch (IOException e) { - throw new ElasticsearchException( - "Error deserializing response, remote is likely not an Elasticsearch instance", - e - ); - } - listener.onResponse(parsedResponse); - } - } - - @Override - public void onFailure(Exception e) { - try (ThreadContext.StoredContext ctx = contextSupplier.get()) { - assert ctx != null; // eliminates compiler warning - if (e instanceof ResponseException re) { - int statusCode = re.getResponse().getStatusLine().getStatusCode(); - e = wrapExceptionToPreserveStatus(statusCode, re.getResponse().getEntity(), re); - if (RestStatus.TOO_MANY_REQUESTS.getStatus() == statusCode) { - listener.onRejection(e); - return; - } - } else if (e instanceof ContentTooLongException) { - e = new IllegalArgumentException( - "Remote responded with a chunk that was too large. 
Use a smaller batch size.", - e - ); - } - listener.onFailure(e); - } - } - }); - } catch (Exception e) { - listener.onFailure(e); - } - } - - /** - * Wrap the ResponseException in an exception that'll preserve its status code if possible so we can send it back to the user. We might - * not have a constant for the status code so in that case we just use 500 instead. We also extract make sure to include the response - * body in the message so the user can figure out *why* the remote Elasticsearch service threw the error back to us. - */ - static ElasticsearchStatusException wrapExceptionToPreserveStatus(int statusCode, @Nullable HttpEntity entity, Exception cause) { - RestStatus status = RestStatus.fromCode(statusCode); - String messagePrefix = ""; - if (status == null) { - messagePrefix = "Couldn't extract status [" + statusCode + "]. "; - status = RestStatus.INTERNAL_SERVER_ERROR; - } - try { - return new ElasticsearchStatusException(messagePrefix + bodyMessage(entity), status, cause); - } catch (IOException ioe) { - ElasticsearchStatusException e = new ElasticsearchStatusException(messagePrefix + "Failed to extract body.", status, cause); - e.addSuppressed(ioe); - return e; - } - } - - private static String bodyMessage(@Nullable HttpEntity entity) throws IOException { - if (entity == null) { - return "No error body."; - } else { - return "body=" + EntityUtils.toString(entity); - } - } } diff --git a/modules/reindex/src/test/java/org/elasticsearch/reindex/BulkByPaginatedSearchParallelizationHelperTests.java b/modules/reindex/src/test/java/org/elasticsearch/reindex/BulkByPaginatedSearchParallelizationHelperTests.java index 8b5e3f644eb1f..dad4e9885e0e4 100644 --- a/modules/reindex/src/test/java/org/elasticsearch/reindex/BulkByPaginatedSearchParallelizationHelperTests.java +++ b/modules/reindex/src/test/java/org/elasticsearch/reindex/BulkByPaginatedSearchParallelizationHelperTests.java @@ -9,19 +9,55 @@ package org.elasticsearch.reindex; +import 
org.elasticsearch.Version; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.node.DiscoveryNodeUtils; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.mapper.IdFieldMapper; +import org.elasticsearch.index.reindex.BulkByScrollResponse; +import org.elasticsearch.index.reindex.BulkByScrollTask; +import org.elasticsearch.index.reindex.ReindexAction; +import org.elasticsearch.index.reindex.ReindexRequest; import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.junit.After; +import org.junit.Before; import java.io.IOException; import java.util.Collections; +import java.util.concurrent.atomic.AtomicReference; +import static java.util.Collections.emptySet; +import static org.elasticsearch.reindex.BulkByPaginatedSearchParallelizationHelper.executeSlicedAction; import static org.elasticsearch.reindex.BulkByPaginatedSearchParallelizationHelper.sliceIntoSubRequests; import static org.elasticsearch.search.RandomSearchRequestGenerator.randomSearchRequest; import static org.elasticsearch.search.RandomSearchRequestGenerator.randomSearchSourceBuilder; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.nullValue; +import static org.hamcrest.Matchers.sameInstance; public class BulkByPaginatedSearchParallelizationHelperTests extends ESTestCase { + + private ThreadPool threadPool; + private TaskManager taskManager; + + @Before + public void setUpTaskManager() { + threadPool = new TestThreadPool(getTestName()); + taskManager = new TaskManager(Settings.EMPTY, threadPool, emptySet()); + } + + @After + public void 
tearDownTaskManager() { + terminate(threadPool); + } + public void testSliceIntoSubRequests() throws IOException { SearchRequest searchRequest = randomSearchRequest( () -> randomSearchSourceBuilder(() -> null, () -> null, () -> null, Collections::emptyList, () -> null, () -> null) @@ -48,4 +84,60 @@ public void testSliceIntoSubRequests() throws IOException { currentSliceId++; } } + + /** + * When the task is a worker, executeSlicedAction invokes the worker action with the given remote version. + */ + public void testExecuteSlicedActionWithWorkerAndNonNullVersion() { + ReindexRequest request = new ReindexRequest(); + BulkByScrollTask task = (BulkByScrollTask) taskManager.register("reindex", ReindexAction.NAME, request); + task.setWorker(request.getRequestsPerSecond(), null); + + Version version = Version.CURRENT; + AtomicReference capturedVersion = new AtomicReference<>(); + ActionListener listener = ActionListener.noop(); + Client client = null; + DiscoveryNode node = DiscoveryNodeUtils.builder("node").roles(emptySet()).build(); + + executeSlicedAction(task, request, ReindexAction.INSTANCE, listener, client, node, version, capturedVersion::set); + + assertThat(capturedVersion.get(), sameInstance(version)); + } + + /** + * When the task is a worker and remote version is null (local reindex), the worker action receives null. 
+ */ + public void testExecuteSlicedActionWithWorkerAndNullVersion() { + ReindexRequest request = new ReindexRequest(); + BulkByScrollTask task = (BulkByScrollTask) taskManager.register("reindex", ReindexAction.NAME, request); + task.setWorker(request.getRequestsPerSecond(), null); + + AtomicReference capturedVersion = new AtomicReference<>(Version.CURRENT); + ActionListener listener = ActionListener.noop(); + Client client = null; + DiscoveryNode node = DiscoveryNodeUtils.builder("node").roles(emptySet()).build(); + + executeSlicedAction(task, request, ReindexAction.INSTANCE, listener, client, node, null, capturedVersion::set); + + assertThat(capturedVersion.get(), nullValue()); + } + + /** + * When the task is neither a leader nor a worker (not initialized), executeSlicedAction throws. + */ + public void testExecuteSlicedActionThrowsWhenTaskNotInitialized() { + ReindexRequest request = new ReindexRequest(); + BulkByScrollTask task = (BulkByScrollTask) taskManager.register("reindex", ReindexAction.NAME, request); + // Do not call setWorker or setWorkerCount + + ActionListener listener = ActionListener.noop(); + Client client = null; + DiscoveryNode node = DiscoveryNodeUtils.builder("node").roles(emptySet()).build(); + + AssertionError e = expectThrows( + AssertionError.class, + () -> executeSlicedAction(task, request, ReindexAction.INSTANCE, listener, client, node, null, v -> {}) + ); + assertThat(e.getMessage(), containsString("initialized")); + } } diff --git a/modules/reindex/src/test/java/org/elasticsearch/reindex/ReindexIdTests.java b/modules/reindex/src/test/java/org/elasticsearch/reindex/ReindexIdTests.java index 824381ad08238..dbdff2c0a74b4 100644 --- a/modules/reindex/src/test/java/org/elasticsearch/reindex/ReindexIdTests.java +++ b/modules/reindex/src/test/java/org/elasticsearch/reindex/ReindexIdTests.java @@ -9,6 +9,7 @@ package org.elasticsearch.reindex; +import org.elasticsearch.Version; import org.elasticsearch.cluster.ClusterState; import 
org.elasticsearch.cluster.ProjectState; import org.elasticsearch.cluster.metadata.ComponentTemplate; @@ -109,6 +110,18 @@ protected ReindexRequest request() { } private Reindexer.AsyncIndexBySearchAction action(ProjectState state) { - return new Reindexer.AsyncIndexBySearchAction(task, logger, null, null, threadPool, null, state, null, request(), listener()); + return new Reindexer.AsyncIndexBySearchAction( + task, + logger, + null, + null, + threadPool, + null, + state, + null, + request(), + listener(), + randomBoolean() ? null : Version.CURRENT + ); } } diff --git a/modules/reindex/src/test/java/org/elasticsearch/reindex/ReindexMetadataTests.java b/modules/reindex/src/test/java/org/elasticsearch/reindex/ReindexMetadataTests.java index d7abf0e734b70..ea0b44de0078e 100644 --- a/modules/reindex/src/test/java/org/elasticsearch/reindex/ReindexMetadataTests.java +++ b/modules/reindex/src/test/java/org/elasticsearch/reindex/ReindexMetadataTests.java @@ -9,6 +9,7 @@ package org.elasticsearch.reindex; +import org.elasticsearch.Version; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.Metadata; @@ -82,7 +83,8 @@ private class TestAction extends Reindexer.AsyncIndexBySearchAction { ClusterState.EMPTY_STATE.projectState(Metadata.DEFAULT_PROJECT_ID), null, request(), - listener() + listener(), + randomBoolean() ? 
null : Version.CURRENT ); } diff --git a/modules/reindex/src/test/java/org/elasticsearch/reindex/ReindexScriptTests.java b/modules/reindex/src/test/java/org/elasticsearch/reindex/ReindexScriptTests.java index b95896bd0298f..31abb00374662 100644 --- a/modules/reindex/src/test/java/org/elasticsearch/reindex/ReindexScriptTests.java +++ b/modules/reindex/src/test/java/org/elasticsearch/reindex/ReindexScriptTests.java @@ -9,6 +9,7 @@ package org.elasticsearch.reindex; +import org.elasticsearch.Version; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.Metadata; @@ -101,7 +102,8 @@ protected Reindexer.AsyncIndexBySearchAction action(ScriptService scriptService, ClusterState.EMPTY_STATE.projectState(Metadata.DEFAULT_PROJECT_ID), sslConfig, request, - listener() + listener(), + randomBoolean() ? null : Version.CURRENT ); } } diff --git a/modules/reindex/src/test/java/org/elasticsearch/reindex/remote/RemoteReindexingUtilsTests.java b/modules/reindex/src/test/java/org/elasticsearch/reindex/remote/RemoteReindexingUtilsTests.java new file mode 100644 index 0000000000000..9969a01c99d07 --- /dev/null +++ b/modules/reindex/src/test/java/org/elasticsearch/reindex/remote/RemoteReindexingUtilsTests.java @@ -0,0 +1,466 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.reindex.remote; + +import org.apache.http.ContentTooLongException; +import org.apache.http.HttpEntity; +import org.apache.http.RequestLine; +import org.apache.http.StatusLine; +import org.apache.http.entity.ContentType; +import org.apache.http.entity.InputStreamEntity; +import org.apache.http.entity.StringEntity; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.Version; +import org.elasticsearch.client.Response; +import org.elasticsearch.client.ResponseException; +import org.elasticsearch.client.ResponseListener; +import org.elasticsearch.client.RestClient; +import org.elasticsearch.common.BackoffPolicy; +import org.elasticsearch.common.io.FileSystemUtils; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.reindex.RejectAwareActionListener; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.Scheduler; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.net.URL; +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.elasticsearch.reindex.remote.RemoteReindexingUtils.wrapExceptionToPreserveStatus; +import static org.hamcrest.Matchers.containsString; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class 
RemoteReindexingUtilsTests extends ESTestCase { + + private static final Logger logger = LogManager.getLogger(RemoteReindexingUtilsTests.class); + + private ThreadPool threadPool; + private RestClient client; + + @Before + public void setUp() throws Exception { + super.setUp(); + threadPool = new TestThreadPool(getTestName()) { + @Override + public ExecutorService executor(String name) { + return EsExecutors.DIRECT_EXECUTOR_SERVICE; + } + + @Override + public Scheduler.ScheduledCancellable schedule(Runnable command, TimeValue delay, Executor executor) { + command.run(); + return null; + } + }; + client = mock(RestClient.class); + } + + @After + public void tearDown() throws Exception { + super.tearDown(); + terminate(threadPool); + } + + /** + * Verifies that lookupRemoteVersion correctly parses historical and + * forward-compatible main action responses. + */ + public void testLookupRemoteVersion() throws Exception { + assertLookupRemoteVersion(Version.fromString("0.20.5"), "main/0_20_5.json"); + assertLookupRemoteVersion(Version.fromString("0.90.13"), "main/0_90_13.json"); + assertLookupRemoteVersion(Version.fromString("1.7.5"), "main/1_7_5.json"); + assertLookupRemoteVersion(Version.fromId(2030399), "main/2_3_3.json"); + assertLookupRemoteVersion(Version.fromId(5000099), "main/5_0_0_alpha_3.json"); + assertLookupRemoteVersion(Version.fromId(5000099), "main/with_unknown_fields.json"); + } + + private void assertLookupRemoteVersion(Version expected, String resource) throws Exception { + AtomicBoolean called = new AtomicBoolean(); + URL url = Thread.currentThread().getContextClassLoader().getResource("responses/" + resource); + assertNotNull("missing test resource [" + resource + "]", url); + + HttpEntity entity = new InputStreamEntity(FileSystemUtils.openFileURLStream(url), ContentType.APPLICATION_JSON); + Response response = mock(Response.class); + when(response.getEntity()).thenReturn(entity); + + mockSuccess(response); + 
RemoteReindexingUtils.lookupRemoteVersion(RejectAwareActionListener.wrap(v -> {
+            assertEquals(expected, v);
+            called.set(true);
+        }, e -> fail(), e -> fail()), threadPool, client);
+        assertTrue("listener was not called", called.get());
+    }
+
+    /**
+     * Verifies that lookupRemoteVersion fails when the response does not include
+     * a Content-Type header, and that the error message includes the response body.
+     */
+    public void testLookupRemoteVersionFailsWithoutContentType() throws Exception {
+        URL url = Thread.currentThread().getContextClassLoader().getResource("responses/main/0_20_5.json");
+        assertNotNull(url);
+
+        HttpEntity entity = new InputStreamEntity(
+            FileSystemUtils.openFileURLStream(url),
+            // intentionally no Content-Type
+            null
+        );
+
+        Response response = mock(Response.class);
+        when(response.getEntity()).thenReturn(entity);
+        mockSuccess(response);
+
+        AtomicBoolean failed = new AtomicBoolean();
+        try {
+            RemoteReindexingUtils.lookupRemoteVersion(
+                RejectAwareActionListener.wrap(
+                    v -> fail("Expected an exception yet one was not thrown"),
+                    // The failure may be delivered through the listener rather than thrown synchronously
+                    e -> { assertThat(e.getMessage(), containsString("Response didn't include Content-Type: body={")); failed.set(true); },
+                    e -> { assertThat(e.getMessage(), containsString("Response didn't include Content-Type: body={")); failed.set(true); }
+                ),
+                threadPool,
+                client
+            );
+        } catch (RuntimeException e) {
+            assertThat(e.getMessage(), containsString("Response didn't include Content-Type: body={")); failed.set(true);
+        }
+        assertTrue("Expected lookupRemoteVersion to fail when Content-Type is missing", failed.get());
+    }
+
+    /**
+     * Verifies that HTTP 429 responses are routed to onRejection rather than onFailure.
+ */ + public void testLookupRemoteVersionTooManyRequestsTriggersRejection() throws Exception { + AtomicBoolean rejected = new AtomicBoolean(); + Response response = mock(Response.class); + when(response.getEntity()).thenReturn(null); + + StatusLine statusLine = mock(StatusLine.class); + when(statusLine.getStatusCode()).thenReturn(RestStatus.TOO_MANY_REQUESTS.getStatus()); + when(response.getStatusLine()).thenReturn(statusLine); + + // Mocks used in the ResponseException constructor + RequestLine requestLine = mock(RequestLine.class); + when(requestLine.getMethod()).thenReturn("mock"); + when(response.getRequestLine()).thenReturn(requestLine); + mockFailure(new ResponseException(response)); + + RemoteReindexingUtils.lookupRemoteVersion( + RejectAwareActionListener.wrap(v -> fail("unexpected success"), e -> fail("unexpected failure"), e -> rejected.set(true)), + threadPool, + client + ); + assertTrue("onRejection was not called", rejected.get()); + } + + /** + * Verifies that non-429 HTTP errors are routed to onFailure. 
+ */ + public void testLookupRemoteVersionHttpErrorTriggersFailure() throws Exception { + StatusLine statusLine = mock(StatusLine.class); + when(statusLine.getStatusCode()).thenReturn(RestStatus.BAD_REQUEST.getStatus()); + Response response = mock(Response.class); + when(response.getStatusLine()).thenReturn(statusLine); + when(response.getEntity()).thenReturn(new StringEntity("bad request", ContentType.TEXT_PLAIN)); + + // Mocks used in the ResponseException constructor + RequestLine requestLine = mock(RequestLine.class); + when(requestLine.getMethod()).thenReturn("mock"); + when(response.getRequestLine()).thenReturn(requestLine); + mockFailure(new ResponseException(response)); + + RemoteReindexingUtils.lookupRemoteVersion(RejectAwareActionListener.wrap(v -> fail(), ex -> { + assertTrue(ex instanceof ElasticsearchException); + assertEquals(RestStatus.BAD_REQUEST, ((ElasticsearchStatusException) ex).status()); + }, ex -> fail()), threadPool, client); + } + + /** + * Verifies that ContentTooLongException is translated into a user-facing IllegalArgumentException. 
+ */ + public void testContentTooLongExceptionIsWrapped() { + mockFailure(new ContentTooLongException("too large")); + + RemoteReindexingUtils.lookupRemoteVersion(RejectAwareActionListener.wrap(v -> fail(), ex -> { + assertTrue(ex instanceof IllegalArgumentException); + assertThat(ex.getMessage(), containsString("Remote responded with a chunk that was too large")); + }, ex -> fail()), threadPool, client); + } + + public void testInvalidJsonThrowsElasticsearchException() { + HttpEntity entity = new StringEntity("this is not json", ContentType.APPLICATION_JSON); + Response response = mock(Response.class); + when(response.getEntity()).thenReturn(entity); + mockSuccess(response); + + RemoteReindexingUtils.lookupRemoteVersion(RejectAwareActionListener.wrap(v -> fail(), ex -> { + assertTrue(ex instanceof ElasticsearchException); + assertThat(ex.getMessage(), containsString("remote is likely not an Elasticsearch instance")); + }, ex -> fail()), threadPool, client); + } + + /** + * Verifies that IOExceptions during response deserialization are surfaced correctly. 
+     */
+    public void testIOExceptionDuringDeserialization() throws Exception {
+        HttpEntity entity = mock(HttpEntity.class);
+        when(entity.getContent()).thenThrow(new IOException("boom"));
+        Response response = mock(Response.class);
+        when(response.getEntity()).thenReturn(entity);
+        mockSuccess(response);
+
+        RemoteReindexingUtils.lookupRemoteVersion(RejectAwareActionListener.wrap(v -> fail(), ex -> {
+            assertTrue(ex instanceof ElasticsearchException);
+            assertThat(ex.getMessage(), containsString("Error deserializing response"));
+        }, ex -> fail()), threadPool, client);
+    }
+
+    public void testWrapExceptionToPreserveStatus() throws IOException {
+        Exception cause = new Exception();
+
+        // Successfully get the status without a body
+        RestStatus status = randomFrom(RestStatus.values());
+        ElasticsearchStatusException wrapped = wrapExceptionToPreserveStatus(status.getStatus(), null, cause);
+        assertEquals(status, wrapped.status());
+        assertEquals(cause, wrapped.getCause());
+        assertEquals("No error body.", wrapped.getMessage());
+
+        // Successfully get the status with a body
+        HttpEntity okEntity = new StringEntity("test body", ContentType.TEXT_PLAIN);
+        wrapped = wrapExceptionToPreserveStatus(status.getStatus(), okEntity, cause);
+        assertEquals(status, wrapped.status());
+        assertEquals(cause, wrapped.getCause());
+        assertEquals("body=test body", wrapped.getMessage());
+
+        // Successfully get the status with a broken body
+        IOException badEntityException = new IOException();
+        HttpEntity badEntity = mock(HttpEntity.class);
+        when(badEntity.getContent()).thenThrow(badEntityException);
+        wrapped = wrapExceptionToPreserveStatus(status.getStatus(), badEntity, cause);
+        assertEquals(status, wrapped.status());
+        assertEquals(cause, wrapped.getCause());
+        assertEquals("Failed to extract body.", wrapped.getMessage());
+        assertEquals(badEntityException, wrapped.getSuppressed()[0]);
+
+        // Fail to get the status without a body
+        int notAnHttpStatus = -1;
+        assertNull(RestStatus.fromCode(notAnHttpStatus));
+        wrapped = wrapExceptionToPreserveStatus(notAnHttpStatus, null, cause);
+        assertEquals(RestStatus.INTERNAL_SERVER_ERROR, wrapped.status());
+        assertEquals(cause, wrapped.getCause());
+        assertEquals("Couldn't extract status [" + notAnHttpStatus + "]. No error body.", wrapped.getMessage());
+
+        // Fail to get the status with a body
+        wrapped = wrapExceptionToPreserveStatus(notAnHttpStatus, okEntity, cause);
+        assertEquals(RestStatus.INTERNAL_SERVER_ERROR, wrapped.status());
+        assertEquals(cause, wrapped.getCause());
+        assertEquals("Couldn't extract status [" + notAnHttpStatus + "]. body=test body", wrapped.getMessage());
+
+        // Fail to get the status with a broken body
+        wrapped = wrapExceptionToPreserveStatus(notAnHttpStatus, badEntity, cause);
+        assertEquals(RestStatus.INTERNAL_SERVER_ERROR, wrapped.status());
+        assertEquals(cause, wrapped.getCause());
+        assertEquals("Couldn't extract status [" + notAnHttpStatus + "]. Failed to extract body.", wrapped.getMessage());
+        assertEquals(badEntityException, wrapped.getSuppressed()[0]);
+    }
+
+    public void testBodyMessageWithNullEntity() throws Exception {
+        String message = RemoteReindexingUtils.bodyMessage(null);
+        assertEquals("No error body.", message);
+    }
+
+    public void testBodyMessageWithReadableEntity() throws Exception {
+        String testBody = randomAlphanumericOfLength(10);
+        HttpEntity entity = new StringEntity(testBody, ContentType.TEXT_PLAIN);
+
+        String message = RemoteReindexingUtils.bodyMessage(entity);
+
+        assertEquals("body=" + testBody, message);
+    }
+
+    public void testBodyMessageWithIOException() throws Exception {
+        IOException expected = new IOException("Exception");
+
+        HttpEntity entity = mock(HttpEntity.class);
+        when(entity.getContent()).thenThrow(expected);
+
+        IOException actual = expectThrows(IOException.class, () -> RemoteReindexingUtils.bodyMessage(entity));
+
+        assertSame(expected, actual);
+    }
+
+    /**
+     * Verifies that 
lookupRemoteVersionWithRetries retries on 429 and eventually succeeds. + */ + public void testLookupRemoteVersionWithRetriesSucceedsOnRetry() throws Exception { + Response successResponse = successResponse("main/1_7_5.json"); + Response rejectionResponse = rejectionResponse429(); + AtomicInteger callCount = new AtomicInteger(0); + + doAnswer(inv -> { + ResponseListener listener = inv.getArgument(1); + if (callCount.getAndIncrement() == 0) { + listener.onFailure(new ResponseException(rejectionResponse)); + } else { + listener.onSuccess(successResponse); + } + return null; + }).when(client).performRequestAsync(any(), any()); + + AtomicBoolean success = new AtomicBoolean(false); + + RemoteReindexingUtils.lookupRemoteVersionWithRetries( + logger, + BackoffPolicy.constantBackoff(TimeValue.ZERO, 1), + threadPool, + client, + RejectAwareActionListener.wrap(v -> { + assertEquals(Version.fromString("1.7.5"), v); + success.set(true); + }, e -> fail("unexpected failure"), e -> fail("unexpected rejection")) + ); + + assertTrue("listener should have received success", success.get()); + assertEquals("performRequestAsync should be called twice (initial + 1 retry)", 2, callCount.get()); + } + + /** + * Verifies that lookupRemoteVersionWithRetries propagates failure when retries are exhausted. 
+ */ + public void testLookupRemoteVersionWithRetriesExhaustedPropagatesFailure() throws Exception { + Response rejectionResponse = rejectionResponse429(); + doAnswer(inv -> { + ((ResponseListener) inv.getArgument(1)).onFailure(new ResponseException(rejectionResponse)); + return null; + }).when(client).performRequestAsync(any(), any()); + + AtomicBoolean failed = new AtomicBoolean(false); + + RemoteReindexingUtils.lookupRemoteVersionWithRetries( + logger, + BackoffPolicy.constantBackoff(TimeValue.ZERO, 1), + threadPool, + client, + RejectAwareActionListener.wrap(v -> fail("unexpected success"), e -> { + assertTrue(e instanceof ElasticsearchStatusException); + assertEquals(RestStatus.TOO_MANY_REQUESTS, ((ElasticsearchStatusException) e).status()); + failed.set(true); + }, e -> fail("should have propagated as failure after retries exhausted")) + ); + + assertTrue("listener should have received failure", failed.get()); + verify(client, times(2)).performRequestAsync(any(), any()); + } + + /** + * Verifies that non-429 errors do not trigger retries. 
+     */
+    public void testLookupRemoteVersionWithRetriesNon429DoesNotRetry() throws Exception {
+        Response serverErrorResponse = mock(Response.class);
+        StatusLine statusLine = mock(StatusLine.class);
+        when(statusLine.getStatusCode()).thenReturn(RestStatus.INTERNAL_SERVER_ERROR.getStatus());
+        when(serverErrorResponse.getStatusLine()).thenReturn(statusLine);
+        when(serverErrorResponse.getEntity()).thenReturn(new StringEntity("error", ContentType.TEXT_PLAIN));
+        RequestLine requestLine = mock(RequestLine.class);
+        when(requestLine.getMethod()).thenReturn("GET");
+        when(serverErrorResponse.getRequestLine()).thenReturn(requestLine);
+
+        mockFailure(new ResponseException(serverErrorResponse));
+
+        RemoteReindexingUtils.lookupRemoteVersionWithRetries(
+            logger,
+            BackoffPolicy.constantBackoff(TimeValue.ZERO, 5),
+            threadPool,
+            client,
+            RejectAwareActionListener.wrap(v -> fail(), e -> {
+                assertTrue(e instanceof ElasticsearchStatusException);
+                assertEquals(RestStatus.INTERNAL_SERVER_ERROR, ((ElasticsearchStatusException) e).status());
+            }, e -> fail())
+        );
+
+        verify(client, times(1)).performRequestAsync(any(), any());
+    }
+
+    /**
+     * Verifies that success on the first attempt does not invoke countRetry.
+ */ + public void testLookupRemoteVersionWithRetriesSucceedsOnFirstCall() throws Exception { + Response successResponse = successResponse("main/2_3_3.json"); + mockSuccess(successResponse); + + AtomicBoolean success = new AtomicBoolean(false); + + RemoteReindexingUtils.lookupRemoteVersionWithRetries( + logger, + BackoffPolicy.constantBackoff(TimeValue.ZERO, 5), + threadPool, + client, + RejectAwareActionListener.wrap(v -> { + assertEquals(Version.fromString("2.3.3"), v); + success.set(true); + }, e -> fail(), e -> fail()) + ); + + assertTrue("listener should have received success", success.get()); + verify(client, times(1)).performRequestAsync(any(), any()); + } + + private Response successResponse(String resource) throws Exception { + URL url = Thread.currentThread().getContextClassLoader().getResource("responses/" + resource); + assertNotNull("missing test resource [" + resource + "]", url); + HttpEntity entity = new InputStreamEntity(FileSystemUtils.openFileURLStream(url), ContentType.APPLICATION_JSON); + Response response = mock(Response.class); + when(response.getEntity()).thenReturn(entity); + return response; + } + + private Response rejectionResponse429() { + Response response = mock(Response.class); + when(response.getEntity()).thenReturn(null); + StatusLine statusLine = mock(StatusLine.class); + when(statusLine.getStatusCode()).thenReturn(RestStatus.TOO_MANY_REQUESTS.getStatus()); + when(response.getStatusLine()).thenReturn(statusLine); + RequestLine requestLine = mock(RequestLine.class); + when(requestLine.getMethod()).thenReturn("GET"); + when(response.getRequestLine()).thenReturn(requestLine); + return response; + } + + private void mockSuccess(Response response) { + doAnswer(inv -> { + ((ResponseListener) inv.getArgument(1)).onSuccess(response); + return null; + }).when(client).performRequestAsync(any(), any()); + } + + private void mockFailure(Exception e) { + doAnswer(inv -> { + ((ResponseListener) inv.getArgument(1)).onFailure(e); + return null; + 
}).when(client).performRequestAsync(any(), any()); + } +} diff --git a/modules/reindex/src/test/java/org/elasticsearch/reindex/remote/RemoteScrollablePaginatedHitSourceTests.java b/modules/reindex/src/test/java/org/elasticsearch/reindex/remote/RemoteScrollablePaginatedHitSourceTests.java index da3b26bba17d5..90b748767bf8f 100644 --- a/modules/reindex/src/test/java/org/elasticsearch/reindex/remote/RemoteScrollablePaginatedHitSourceTests.java +++ b/modules/reindex/src/test/java/org/elasticsearch/reindex/remote/RemoteScrollablePaginatedHitSourceTests.java @@ -10,7 +10,6 @@ package org.elasticsearch.reindex.remote; import org.apache.http.ContentTooLongException; -import org.apache.http.HttpEntity; import org.apache.http.HttpEntityEnclosingRequest; import org.apache.http.HttpHost; import org.apache.http.HttpResponse; @@ -20,14 +19,12 @@ import org.apache.http.concurrent.FutureCallback; import org.apache.http.entity.ContentType; import org.apache.http.entity.InputStreamEntity; -import org.apache.http.entity.StringEntity; import org.apache.http.impl.nio.client.CloseableHttpAsyncClient; import org.apache.http.impl.nio.client.HttpAsyncClientBuilder; import org.apache.http.message.BasicHttpResponse; import org.apache.http.message.BasicStatusLine; import org.apache.http.nio.protocol.HttpAsyncRequestProducer; import org.apache.http.nio.protocol.HttpAsyncResponseConsumer; -import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.Version; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.client.HeapBufferedAsyncResponseConsumer; @@ -56,7 +53,6 @@ import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; -import java.io.IOException; import java.io.InputStreamReader; import java.net.URL; import java.nio.charset.StandardCharsets; @@ -67,13 +63,13 @@ import java.util.concurrent.Future; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.atomic.AtomicBoolean; -import 
java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.stream.Stream; import static org.elasticsearch.core.TimeValue.timeValueMillis; import static org.elasticsearch.core.TimeValue.timeValueMinutes; -import static org.hamcrest.Matchers.containsString; +import static org.elasticsearch.reindex.remote.RemoteReindexingUtils.execute; +import static org.elasticsearch.reindex.remote.RemoteResponseParsers.RESPONSE_PARSER; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; @@ -130,26 +126,6 @@ public void validateAllConsumed() { assertThat(responseQueue, empty()); } - public void testLookupRemoteVersion() throws Exception { - assertLookupRemoteVersion(Version.fromString("0.20.5"), "main/0_20_5.json"); - assertLookupRemoteVersion(Version.fromString("0.90.13"), "main/0_90_13.json"); - assertLookupRemoteVersion(Version.fromString("1.7.5"), "main/1_7_5.json"); - assertLookupRemoteVersion(Version.fromId(2030399), "main/2_3_3.json"); - // assert for V_5_0_0 (no qualifier) since we no longer consider qualifier in Version since 7 - assertLookupRemoteVersion(Version.fromId(5000099), "main/5_0_0_alpha_3.json"); - // V_5_0_0 since we no longer consider qualifier in Version - assertLookupRemoteVersion(Version.fromId(5000099), "main/with_unknown_fields.json"); - } - - private void assertLookupRemoteVersion(Version expected, String s) throws Exception { - AtomicBoolean called = new AtomicBoolean(); - sourceWithMockedRemoteCall(false, ContentType.APPLICATION_JSON, s).lookupRemoteVersion(wrapAsListener(v -> { - assertEquals(expected, v); - called.set(true); - })); - assertTrue(called.get()); - } - public void testParseStartOk() throws Exception { AtomicBoolean called = new AtomicBoolean(); sourceWithMockedRemoteCall("start_ok.json").doStart(wrapAsListener(r -> { @@ -300,6 +276,23 @@ public void testParseFailureWithStatus() throws Exception { 
assertTrue(called.get()); } + /** + * When constructed with a non-null initial remote version, doStart skips the version lookup and issues + * the initial search directly. Only one HTTP request is made (the search), not two (version + search). + */ + public void testDoStartSkipsVersionLookupWhenInitialRemoteVersionSet() throws Exception { + AtomicBoolean called = new AtomicBoolean(); + RemoteScrollablePaginatedHitSource hitSource = sourceWithInitialRemoteVersion(Version.CURRENT, "start_ok.json"); + hitSource.doStart(wrapAsListener(r -> { + assertFalse(r.isTimedOut()); + assertEquals(FAKE_SCROLL_ID, r.getScrollId()); + assertEquals(4, r.getTotalHits()); + assertThat(r.getHits(), hasSize(1)); + called.set(true); + })); + assertTrue(called.get()); + } + public void testParseRequestFailure() throws Exception { AtomicBoolean called = new AtomicBoolean(); Consumer checkResponse = r -> { @@ -373,59 +366,6 @@ public void testThreadContextRestored() throws Exception { assertTrue(called.get()); } - public void testWrapExceptionToPreserveStatus() throws IOException { - Exception cause = new Exception(); - - // Successfully get the status without a body - RestStatus status = randomFrom(RestStatus.values()); - ElasticsearchStatusException wrapped = RemoteScrollablePaginatedHitSource.wrapExceptionToPreserveStatus( - status.getStatus(), - null, - cause - ); - assertEquals(status, wrapped.status()); - assertEquals(cause, wrapped.getCause()); - assertEquals("No error body.", wrapped.getMessage()); - - // Successfully get the status without a body - HttpEntity okEntity = new StringEntity("test body", ContentType.TEXT_PLAIN); - wrapped = RemoteScrollablePaginatedHitSource.wrapExceptionToPreserveStatus(status.getStatus(), okEntity, cause); - assertEquals(status, wrapped.status()); - assertEquals(cause, wrapped.getCause()); - assertEquals("body=test body", wrapped.getMessage()); - - // Successfully get the status with a broken body - IOException badEntityException = new IOException(); 
- HttpEntity badEntity = mock(HttpEntity.class); - when(badEntity.getContent()).thenThrow(badEntityException); - wrapped = RemoteScrollablePaginatedHitSource.wrapExceptionToPreserveStatus(status.getStatus(), badEntity, cause); - assertEquals(status, wrapped.status()); - assertEquals(cause, wrapped.getCause()); - assertEquals("Failed to extract body.", wrapped.getMessage()); - assertEquals(badEntityException, wrapped.getSuppressed()[0]); - - // Fail to get the status without a body - int notAnHttpStatus = -1; - assertNull(RestStatus.fromCode(notAnHttpStatus)); - wrapped = RemoteScrollablePaginatedHitSource.wrapExceptionToPreserveStatus(notAnHttpStatus, null, cause); - assertEquals(RestStatus.INTERNAL_SERVER_ERROR, wrapped.status()); - assertEquals(cause, wrapped.getCause()); - assertEquals("Couldn't extract status [" + notAnHttpStatus + "]. No error body.", wrapped.getMessage()); - - // Fail to get the status without a body - wrapped = RemoteScrollablePaginatedHitSource.wrapExceptionToPreserveStatus(notAnHttpStatus, okEntity, cause); - assertEquals(RestStatus.INTERNAL_SERVER_ERROR, wrapped.status()); - assertEquals(cause, wrapped.getCause()); - assertEquals("Couldn't extract status [" + notAnHttpStatus + "]. body=test body", wrapped.getMessage()); - - // Fail to get the status with a broken body - wrapped = RemoteScrollablePaginatedHitSource.wrapExceptionToPreserveStatus(notAnHttpStatus, badEntity, cause); - assertEquals(RestStatus.INTERNAL_SERVER_ERROR, wrapped.status()); - assertEquals(cause, wrapped.getCause()); - assertEquals("Couldn't extract status [" + notAnHttpStatus + "]. 
Failed to extract body.", wrapped.getMessage()); - assertEquals(badEntityException, wrapped.getSuppressed()[0]); - } - @SuppressWarnings({ "unchecked", "rawtypes" }) public void testTooLargeResponse() throws Exception { ContentTooLongException tooLong = new ContentTooLongException("too long!"); @@ -458,15 +398,6 @@ public Future answer(InvocationOnMock invocationOnMock) throws Thr assertTrue(responseQueue.isEmpty()); } - public void testNoContentTypeIsError() { - RuntimeException e = expectListenerFailure( - RuntimeException.class, - (RejectAwareActionListener listener) -> sourceWithMockedRemoteCall(false, null, "main/0_20_5.json") - .lookupRemoteVersion(listener) - ); - assertThat(e.getMessage(), containsString("Response didn't include Content-Type: body={")); - } - public void testInvalidJsonThinksRemoteIsNotES() throws Exception { sourceWithMockedRemoteCall("some_text.txt").start(); Throwable e = failureQueue.poll(); @@ -483,7 +414,8 @@ public void testUnexpectedJsonThinksRemoteIsNotES() throws Exception { public void testCleanupSuccessful() throws Exception { AtomicBoolean cleanupCallbackCalled = new AtomicBoolean(); RestClient client = mock(RestClient.class); - TestRemoteScrollablePaginatedHitSource paginatedHitSource = new TestRemoteScrollablePaginatedHitSource(client); + RemoteInfo remoteInfo = remoteInfo(); + TestRemoteScrollablePaginatedHitSource paginatedHitSource = new TestRemoteScrollablePaginatedHitSource(client, remoteInfo); paginatedHitSource.cleanup(() -> cleanupCallbackCalled.set(true)); verify(client).close(); assertTrue(cleanupCallbackCalled.get()); @@ -493,7 +425,8 @@ public void testCleanupFailure() throws Exception { AtomicBoolean cleanupCallbackCalled = new AtomicBoolean(); RestClient client = mock(RestClient.class); doThrow(new RuntimeException("test")).when(client).close(); - TestRemoteScrollablePaginatedHitSource paginatedHitSource = new TestRemoteScrollablePaginatedHitSource(client); + RemoteInfo remoteInfo = remoteInfo(); + 
TestRemoteScrollablePaginatedHitSource paginatedHitSource = new TestRemoteScrollablePaginatedHitSource(client, remoteInfo); paginatedHitSource.cleanup(() -> cleanupCallbackCalled.set(true)); verify(client).close(); assertTrue(cleanupCallbackCalled.get()); @@ -572,13 +505,22 @@ private RemoteScrollablePaginatedHitSource sourceWithMockedClient(boolean mockRe .setHttpClientConfigCallback(httpClientBuilder -> clientBuilder) .build(); - TestRemoteScrollablePaginatedHitSource paginatedHitSource = new TestRemoteScrollablePaginatedHitSource(restClient) { + RemoteInfo remoteInfo = remoteInfo(); + TestRemoteScrollablePaginatedHitSource paginatedHitSource = new TestRemoteScrollablePaginatedHitSource(restClient, remoteInfo) { @Override - void lookupRemoteVersion(RejectAwareActionListener listener) { + protected void doStart(RejectAwareActionListener searchListener) { + // Short‑circuit version lookup by setting it to current if (mockRemoteVersion) { - listener.onResponse(Version.CURRENT); + remoteVersion = Version.CURRENT; + execute( + RemoteRequestBuilders.initialSearch(searchRequest, remoteInfo.getQuery(), remoteVersion), + RESPONSE_PARSER, + RejectAwareActionListener.withResponseHandler(searchListener, r -> onStartResponse(searchListener, r)), + threadPool, + restClient + ); } else { - super.lookupRemoteVersion(listener); + super.doStart(searchListener); } } }; @@ -588,6 +530,102 @@ void lookupRemoteVersion(RejectAwareActionListener listener) { return paginatedHitSource; } + /** + * Creates a RemoteScrollablePaginatedHitSource with a pre-resolved initial remote version so that doStart skips the version lookup. + * The mock client serves only the given response paths (one request = one path when using initial version). + */ + private RemoteScrollablePaginatedHitSource sourceWithInitialRemoteVersion(Version initialRemoteVersion, String... 
paths) + throws Exception { + return sourceWithInitialRemoteVersion(initialRemoteVersion, ContentType.APPLICATION_JSON, paths); + } + + @SuppressWarnings("unchecked") + private RemoteScrollablePaginatedHitSource sourceWithInitialRemoteVersion( + Version initialRemoteVersion, + ContentType contentType, + String... paths + ) throws Exception { + URL[] resources = new URL[paths.length]; + for (int i = 0; i < paths.length; i++) { + resources[i] = Thread.currentThread().getContextClassLoader().getResource("responses/" + paths[i].replace("fail:", "")); + if (resources[i] == null) { + throw new IllegalArgumentException("Couldn't find [" + paths[i] + "]"); + } + } + + CloseableHttpAsyncClient httpClient = mock(CloseableHttpAsyncClient.class); + when( + httpClient.execute( + any(HttpAsyncRequestProducer.class), + any(HttpAsyncResponseConsumer.class), + any(HttpClientContext.class), + any(FutureCallback.class) + ) + ).thenAnswer(new Answer>() { + int responseCount = 0; + + @Override + public Future answer(InvocationOnMock invocationOnMock) throws Throwable { + threadPool.getThreadContext().stashContext(); + FutureCallback futureCallback = (FutureCallback) invocationOnMock.getArguments()[3]; + HttpAsyncRequestProducer requestProducer = (HttpAsyncRequestProducer) invocationOnMock.getArguments()[0]; + HttpEntityEnclosingRequest request = (HttpEntityEnclosingRequest) requestProducer.generateRequest(); + URL resource = resources[responseCount]; + String path = paths[responseCount++]; + ProtocolVersion protocolVersion = new ProtocolVersion("http", 1, 1); + if (path.startsWith("fail:")) { + String body = Streams.copyToString(new InputStreamReader(request.getEntity().getContent(), StandardCharsets.UTF_8)); + if (path.equals("fail:rejection.json")) { + StatusLine statusLine = new BasicStatusLine(protocolVersion, RestStatus.TOO_MANY_REQUESTS.getStatus(), ""); + futureCallback.completed(new BasicHttpResponse(statusLine)); + } else { + futureCallback.failed(new RuntimeException(body)); 
+ } + } else { + StatusLine statusLine = new BasicStatusLine(protocolVersion, 200, ""); + HttpResponse httpResponse = new BasicHttpResponse(statusLine); + httpResponse.setEntity(new InputStreamEntity(FileSystemUtils.openFileURLStream(resource), contentType)); + futureCallback.completed(httpResponse); + } + return null; + } + }); + + HttpAsyncClientBuilder clientBuilder = mock(HttpAsyncClientBuilder.class); + when(clientBuilder.build()).thenReturn(httpClient); + RestClient restClient = RestClient.builder(new HttpHost("localhost", 9200)) + .setHttpClientConfigCallback(httpClientBuilder -> clientBuilder) + .build(); + + return new RemoteScrollablePaginatedHitSource( + logger, + backoff(), + threadPool, + this::countRetry, + responseQueue::add, + failureQueue::add, + restClient, + remoteInfo(), + searchRequest, + initialRemoteVersion + ); + } + + private RemoteInfo remoteInfo() { + return new RemoteInfo( + "http", + randomAlphaOfLength(8), + randomIntBetween(4000, 9000), + null, + new BytesArray("{}"), + null, + null, + Map.of(), + TimeValue.timeValueSeconds(randomIntBetween(5, 30)), + TimeValue.timeValueSeconds(randomIntBetween(5, 30)) + ); + } + private BackoffPolicy backoff() { return BackoffPolicy.constantBackoff(timeValueMillis(0), retriesAllowed); } @@ -597,7 +635,7 @@ private void countRetry() { } private class TestRemoteScrollablePaginatedHitSource extends RemoteScrollablePaginatedHitSource { - TestRemoteScrollablePaginatedHitSource(RestClient client) { + TestRemoteScrollablePaginatedHitSource(RestClient client, RemoteInfo remoteInfo) { super( RemoteScrollablePaginatedHitSourceTests.this.logger, backoff(), @@ -606,19 +644,9 @@ private class TestRemoteScrollablePaginatedHitSource extends RemoteScrollablePag responseQueue::add, failureQueue::add, client, - new RemoteInfo( - "http", - randomAlphaOfLength(8), - randomIntBetween(4000, 9000), - null, - new BytesArray("{}"), - null, - null, - Map.of(), - TimeValue.timeValueSeconds(randomIntBetween(5, 30)), - 
TimeValue.timeValueSeconds(randomIntBetween(5, 30)) - ), - RemoteScrollablePaginatedHitSourceTests.this.searchRequest + remoteInfo, + RemoteScrollablePaginatedHitSourceTests.this.searchRequest, + randomBoolean() ? Version.CURRENT : null ); } } @@ -626,15 +654,4 @@ private class TestRemoteScrollablePaginatedHitSource extends RemoteScrollablePag private RejectAwareActionListener wrapAsListener(Consumer consumer) { return RejectAwareActionListener.wrap(consumer::accept, ESTestCase::fail, ESTestCase::fail); } - - @SuppressWarnings("unchecked") - private T expectListenerFailure(Class expectedException, Consumer> subject) { - AtomicReference exception = new AtomicReference<>(); - subject.accept(RejectAwareActionListener.wrap(r -> fail(), e -> { - assertThat(e, instanceOf(expectedException)); - assertTrue(exception.compareAndSet(null, (T) e)); - }, e -> fail())); - assertNotNull(exception.get()); - return exception.get(); - } }