diff --git a/CHANGELOG.md b/CHANGELOG.md index 74e801a89c0fa..58630dda401bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Add getWrappedScorer method to ProfileScorer for plugin access to wrapped scorers ([#20548](https://github.com/opensearch-project/OpenSearch/issues/20548)) - Support expected cluster name with validation in CCS Sniff mode ([#20532](https://github.com/opensearch-project/OpenSearch/pull/20532)) - Add security policy to allow `accessUnixDomainSocket` in `transport-grpc` module ([#20463](https://github.com/opensearch-project/OpenSearch/pull/20463)) +- [Workload Management] Enhance Scroll API support for autotagging ([#20151](https://github.com/opensearch-project/OpenSearch/pull/20151)) ### Changed - Move Randomness from server to libs/common ([#20570](https://github.com/opensearch-project/OpenSearch/pull/20570)) diff --git a/plugins/workload-management/src/internalClusterTest/java/org/opensearch/plugin/wlm/WlmAutoTaggingIT.java b/plugins/workload-management/src/internalClusterTest/java/org/opensearch/plugin/wlm/WlmAutoTaggingIT.java index 4f8dfa89027ee..ae3c8a1acf1d6 100644 --- a/plugins/workload-management/src/internalClusterTest/java/org/opensearch/plugin/wlm/WlmAutoTaggingIT.java +++ b/plugins/workload-management/src/internalClusterTest/java/org/opensearch/plugin/wlm/WlmAutoTaggingIT.java @@ -439,6 +439,50 @@ public void testDeleteRuleForNonexistentId() throws Exception { assertTrue("Expected error message for nonexistent rule ID", exception.getMessage().contains("no such index")); } + public void testScrollRequestsAreAlsoTagged() throws Exception { + String workloadGroupId = "wlm_auto_tag_scroll"; + String ruleId = "wlm_auto_tag_scroll_rule"; + String indexName = "scroll_tagged_index"; + + setWlmMode("enabled"); + + WorkloadGroup workloadGroup = createWorkloadGroup("scroll_tagging_group", workloadGroupId); + updateWorkloadGroupInClusterState(PUT, 
workloadGroup); + + FeatureType featureType = AutoTaggingRegistry.getFeatureType(WorkloadGroupFeatureType.NAME); + createRule(ruleId, "scroll tagging rule", indexName, featureType, workloadGroupId); + + indexDocument(indexName); + + assertBusy(() -> { + int completionsBefore = getCompletions(workloadGroupId); + + SearchResponse initial = client().prepareSearch(indexName) + .setQuery(QueryBuilders.matchAllQuery()) + .setScroll(TimeValue.timeValueMinutes(1)) + .setSize(1) + .get(); + String scrollId = initial.getScrollId(); + assertNotNull("scrollId must not be null", scrollId); + + try { + int afterInitialSearch = getCompletions(workloadGroupId); + assertTrue("Expected completions to increase after initial search with scroll", afterInitialSearch > completionsBefore); + + SearchResponse scrollResp = client().prepareSearchScroll(scrollId).setScroll(TimeValue.timeValueMinutes(1)).get(); + String nextScrollId = scrollResp.getScrollId(); + if (nextScrollId != null && !nextScrollId.isEmpty()) { + scrollId = nextScrollId; + } + + int afterScroll = getCompletions(workloadGroupId); + assertTrue("Expected completions to increase after scroll request as well", afterScroll > afterInitialSearch); + } finally { + clearScroll(scrollId); + } + }); + } + // Helper functions private void createRule(String ruleId, String ruleName, String indexPattern, FeatureType featureType, String workloadGroupId) throws Exception { diff --git a/plugins/workload-management/src/main/java/org/opensearch/plugin/wlm/AutoTaggingActionFilter.java b/plugins/workload-management/src/main/java/org/opensearch/plugin/wlm/AutoTaggingActionFilter.java index c6294ed7ac242..bea19e17073ec 100644 --- a/plugins/workload-management/src/main/java/org/opensearch/plugin/wlm/AutoTaggingActionFilter.java +++ b/plugins/workload-management/src/main/java/org/opensearch/plugin/wlm/AutoTaggingActionFilter.java @@ -11,6 +11,7 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.IndicesRequest; import 
org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchScrollRequest; import org.opensearch.action.support.ActionFilter; import org.opensearch.action.support.ActionFilterChain; import org.opensearch.action.support.ActionRequestMetadata; @@ -19,6 +20,7 @@ import org.opensearch.plugin.wlm.rule.attribute_extractor.IndicesExtractor; import org.opensearch.plugin.wlm.spi.AttributeExtractorExtension; import org.opensearch.rule.InMemoryRuleProcessingService; +import org.opensearch.rule.RuleAttribute; import org.opensearch.rule.attribute_extractor.AttributeExtractor; import org.opensearch.rule.autotagging.Attribute; import org.opensearch.rule.autotagging.FeatureType; @@ -28,6 +30,7 @@ import org.opensearch.wlm.WorkloadGroupTask; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.Optional; @@ -80,14 +83,39 @@ public void app ActionListener listener, ActionFilterChain chain ) { - final boolean isValidRequest = request instanceof SearchRequest; + final boolean isSearchRequest = request instanceof SearchRequest; + final boolean isSearchScrollRequest = request instanceof SearchScrollRequest; + final boolean isValidRequest = isSearchRequest || isSearchScrollRequest; if (!isValidRequest || wlmClusterSettingValuesProvider.getWlmMode() == WlmMode.DISABLED) { chain.proceed(task, action, request, listener); return; } List> attributeExtractors = new ArrayList<>(); - attributeExtractors.add(new IndicesExtractor((IndicesRequest) request)); + if (isSearchRequest) { + attributeExtractors.add(new IndicesExtractor((IndicesRequest) request)); + } else { + // Scroll: recover the original user-provided indices from ParsedScrollId + final String[] originalIndices = ((SearchScrollRequest) request).originalIndicesOrEmpty(); + if (originalIndices.length > 0) { + attributeExtractors.add(new AttributeExtractor<>() { + @Override + public Attribute getAttribute() { + return RuleAttribute.INDEX_PATTERN; + } 
+ + @Override + public Iterable extract() { + return Arrays.asList(originalIndices); + } + + @Override + public LogicalOperator getLogicalOperator() { + return LogicalOperator.AND; + } + }); + } + } if (featureType.getAllowedAttributesRegistry().containsKey(PRINCIPAL_ATTRIBUTE_NAME)) { Attribute attribute = featureType.getAllowedAttributesRegistry().get(PRINCIPAL_ATTRIBUTE_NAME); diff --git a/plugins/workload-management/src/test/java/org/opensearch/plugin/wlm/AutoTaggingActionFilterTests.java b/plugins/workload-management/src/test/java/org/opensearch/plugin/wlm/AutoTaggingActionFilterTests.java index ed5e8e25843ea..40995d70c0848 100644 --- a/plugins/workload-management/src/test/java/org/opensearch/plugin/wlm/AutoTaggingActionFilterTests.java +++ b/plugins/workload-management/src/test/java/org/opensearch/plugin/wlm/AutoTaggingActionFilterTests.java @@ -11,6 +11,7 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest; import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchScrollRequest; import org.opensearch.action.support.ActionFilterChain; import org.opensearch.action.support.ActionRequestMetadata; import org.opensearch.common.util.concurrent.ThreadContext; @@ -18,6 +19,8 @@ import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.rule.InMemoryRuleProcessingService; +import org.opensearch.rule.RuleAttribute; +import org.opensearch.rule.attribute_extractor.AttributeExtractor; import org.opensearch.rule.autotagging.Attribute; import org.opensearch.rule.autotagging.FeatureType; import org.opensearch.rule.storage.AttributeValueStoreFactory; @@ -29,11 +32,15 @@ import org.opensearch.wlm.WorkloadGroupTask; import java.io.IOException; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Optional; +import static 
org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.anyList; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; @@ -93,6 +100,39 @@ public void testApplyForInValidRequest() { verify(ruleProcessingService, times(0)).evaluateLabel(anyList()); } + public void testApplyForScrollRequestWithOriginalIndices() { + SearchScrollRequest request = mock(SearchScrollRequest.class); + ActionFilterChain chain = mock(TestActionFilterChain.class); + + @SuppressWarnings("unchecked") + ActionRequestMetadata metadata = mock(ActionRequestMetadata.class); + when(request.originalIndicesOrEmpty()).thenReturn(new String[] { "logs-scroll-index" }); + + try (ThreadContext.StoredContext ctx = threadPool.getThreadContext().stashContext()) { + doAnswer(inv -> { + @SuppressWarnings("unchecked") + List> extractors = inv.getArgument(0); + + assertNotNull(extractors); + assertEquals(1, extractors.size()); + + AttributeExtractor ex = extractors.get(0); + assertEquals(RuleAttribute.INDEX_PATTERN, ex.getAttribute()); + + List values = new ArrayList<>(); + ex.extract().forEach(values::add); + assertEquals(List.of("logs-scroll-index"), values); + + return Optional.of("ScrollQG_ID"); + }).when(ruleProcessingService).evaluateLabel(any()); + + autoTaggingActionFilter.apply(mock(Task.class), "Test", request, metadata, null, chain); + + assertEquals("ScrollQG_ID", threadPool.getThreadContext().getHeader(WorkloadGroupTask.WORKLOAD_GROUP_ID_HEADER)); + verify(ruleProcessingService, times(1)).evaluateLabel(anyList()); + } + } + public enum WLMFeatureType implements FeatureType { WLM; diff --git a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java index 59bb88b0f6f67..00b6280af7323 100644 --- a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java 
+++ b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java @@ -774,7 +774,9 @@ public void sendSearchResponse(InternalSearchResponse internalSearchResponse, At raisePhaseFailure(new SearchPhaseExecutionException("", "Shard failures", null, failures)); } else { final Version minNodeVersion = clusterState.nodes().getMinNodeVersion(); - final String scrollId = request.scroll() != null ? TransportSearchHelper.buildScrollId(queryResults, minNodeVersion) : null; + final String scrollId = request.scroll() != null + ? TransportSearchHelper.buildScrollId(queryResults, request.indices(), minNodeVersion) + : null; final String searchContextId; if (buildPointInTimeFromSearchResults()) { searchContextId = SearchContextId.encode(queryResults.asList(), aliasFilter, minNodeVersion); diff --git a/server/src/main/java/org/opensearch/action/search/ParsedScrollId.java b/server/src/main/java/org/opensearch/action/search/ParsedScrollId.java index b723b97b5c413..82009af3b0cd1 100644 --- a/server/src/main/java/org/opensearch/action/search/ParsedScrollId.java +++ b/server/src/main/java/org/opensearch/action/search/ParsedScrollId.java @@ -53,11 +53,13 @@ public class ParsedScrollId { private final String type; private final SearchContextIdForNode[] context; + private final String[] originalIndices; /* indices named by the originating search request; never null */ - ParsedScrollId(String source, String type, SearchContextIdForNode[] context) { + ParsedScrollId(String source, String type, SearchContextIdForNode[] context, String[] originalIndices) { this.source = source; this.type = type; this.context = context; + this.originalIndices = originalIndices == null ? new String[0] : originalIndices; /* an absent optional array decodes to null; normalise it */ } public String getSource() { @@ -72,6 +74,10 @@ public SearchContextIdForNode[] getContext() { return context; } + public String[] getOriginalIndices() { /* empty when the scroll id carries no indices */ + return originalIndices; + } + public boolean hasLocalIndices() { return Arrays.stream(context).anyMatch(c -> c.getClusterAlias() == null); } diff --git
a/server/src/main/java/org/opensearch/action/search/SearchScrollRequest.java b/server/src/main/java/org/opensearch/action/search/SearchScrollRequest.java index 044efdc36d04f..991c29508bd60 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchScrollRequest.java +++ b/server/src/main/java/org/opensearch/action/search/SearchScrollRequest.java @@ -36,6 +36,7 @@ import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.annotation.PublicApi; import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.common.Strings; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.tasks.TaskId; @@ -61,6 +62,7 @@ public class SearchScrollRequest extends ActionRequest implements ToXContentObje private String scrollId; private Scroll scroll; + private transient ParsedScrollId parsedScrollId; /* cached parse of scrollId; valid only while its source equals scrollId */ public SearchScrollRequest() {} @@ -103,7 +105,20 @@ public SearchScrollRequest scrollId(String scrollId) { } public ParsedScrollId parseScrollId() { - return TransportSearchHelper.parseScrollId(scrollId); + if (scrollId != null && (parsedScrollId == null || scrollId.equals(parsedScrollId.getSource()) == false)) { /* re-parse when scrollId(String) replaced the id after the last parse */ + parsedScrollId = TransportSearchHelper.parseScrollId(scrollId); + } + return scrollId == null ? null : parsedScrollId; /* never hand back a parse of an id that has since been cleared */ + } + + public String[] originalIndicesOrEmpty() { + try { + ParsedScrollId parsed = parseScrollId(); + String[] orig = parsed == null ? null : parsed.getOriginalIndices(); + return orig == null || orig.length == 0 ?
Strings.EMPTY_ARRAY : orig; + } catch (IllegalArgumentException e) { + return Strings.EMPTY_ARRAY; + } } /** diff --git a/server/src/main/java/org/opensearch/action/search/TransportSearchHelper.java b/server/src/main/java/org/opensearch/action/search/TransportSearchHelper.java index 5c260e02e7275..779a0ac0c6590 100644 --- a/server/src/main/java/org/opensearch/action/search/TransportSearchHelper.java +++ b/server/src/main/java/org/opensearch/action/search/TransportSearchHelper.java @@ -35,6 +35,7 @@ import org.opensearch.Version; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.util.concurrent.AtomicArray; +import org.opensearch.core.common.Strings; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.io.stream.BytesStreamInput; import org.opensearch.search.SearchPhaseResult; @@ -56,11 +57,17 @@ final class TransportSearchHelper { private static final String INCLUDE_CONTEXT_UUID = "include_context_uuid"; + public static final Version INDICES_IN_SCROLL_ID_VERSION = Version.V_3_6_0; + static InternalScrollSearchRequest internalScrollSearchRequest(ShardSearchContextId id, SearchScrollRequest request) { return new InternalScrollSearchRequest(request, id); } static String buildScrollId(AtomicArray searchPhaseResults, Version version) { + return buildScrollId(searchPhaseResults, null, version); + } + + static String buildScrollId(AtomicArray searchPhaseResults, String[] originalIndices, Version version) { try { BytesStreamOutput out = new BytesStreamOutput(); out.writeString(INCLUDE_CONTEXT_UUID); @@ -78,6 +85,13 @@ static String buildScrollId(AtomicArray searchPhase out.writeString(searchShardTarget.getNodeId()); } } + + if (version.onOrAfter(INDICES_IN_SCROLL_ID_VERSION)) { + // To keep autotagging consistent between the initial SearchRequest + // and subsequent SearchScrollRequests, we store exactly the original indices + // received during the "search" phase +
out.writeOptionalStringArray(originalIndices); + } byte[] bytes = BytesReference.toBytes(out.bytes()); return Base64.getUrlEncoder().encodeToString(bytes); } catch (IOException e) { @@ -114,10 +128,13 @@ static ParsedScrollId parseScrollId(String scrollId) { } context[i] = new SearchContextIdForNode(clusterAlias, target, new ShardSearchContextId(contextUUID, id)); } + + final String[] originalIndices = in.getPosition() < bytes.length ? in.readOptionalStringArray() : Strings.EMPTY_ARRAY; + if (in.getPosition() != bytes.length) { throw new IllegalArgumentException("Not all bytes were read"); } - return new ParsedScrollId(scrollId, type, context); + return new ParsedScrollId(scrollId, type, context, originalIndices); } catch (Exception e) { throw new IllegalArgumentException("Cannot parse scroll id", e); } diff --git a/server/src/main/java/org/opensearch/action/search/TransportSearchScrollAction.java b/server/src/main/java/org/opensearch/action/search/TransportSearchScrollAction.java index b0f98a4c1703b..c6383acb3d767 100644 --- a/server/src/main/java/org/opensearch/action/search/TransportSearchScrollAction.java +++ b/server/src/main/java/org/opensearch/action/search/TransportSearchScrollAction.java @@ -79,7 +79,7 @@ protected void doExecute(Task task, SearchScrollRequest request, ActionListener< ((WorkloadGroupTask) task).setWorkloadGroupId(threadPool.getThreadContext()); } - ParsedScrollId scrollId = TransportSearchHelper.parseScrollId(request.scrollId()); + ParsedScrollId scrollId = request.parseScrollId(); Runnable action; switch (scrollId.getType()) { case ParsedScrollId.QUERY_THEN_FETCH_TYPE: diff --git a/server/src/test/java/org/opensearch/action/search/ParsedScrollIdTests.java b/server/src/test/java/org/opensearch/action/search/ParsedScrollIdTests.java index 2d90bf9ba1bdd..0985cb5308802 100644 --- a/server/src/test/java/org/opensearch/action/search/ParsedScrollIdTests.java +++ b/server/src/test/java/org/opensearch/action/search/ParsedScrollIdTests.java @@
-50,7 +50,12 @@ public void testHasLocalIndices() { new ShardSearchContextId(randomAlphaOfLength(8), randomLong()) ); } - final ParsedScrollId parsedScrollId = new ParsedScrollId(randomAlphaOfLength(8), randomAlphaOfLength(8), searchContextIdForNodes); + final ParsedScrollId parsedScrollId = new ParsedScrollId( + randomAlphaOfLength(8), + randomAlphaOfLength(8), + searchContextIdForNodes, + new String[0] + ); assertEquals(hasLocal, parsedScrollId.hasLocalIndices()); } diff --git a/server/src/test/java/org/opensearch/action/search/SearchScrollAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/SearchScrollAsyncActionTests.java index 12ab735c4d324..f5ceef0885520 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchScrollAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/SearchScrollAsyncActionTests.java @@ -481,7 +481,7 @@ protected void onFirstPhaseResult(int shardId, SearchAsyncActionTests.TestSearch private static ParsedScrollId getParsedScrollId(SearchContextIdForNode... idsForNodes) { List searchContextIdForNodes = Arrays.asList(idsForNodes); Collections.shuffle(searchContextIdForNodes, random()); - return new ParsedScrollId("", "test", searchContextIdForNodes.toArray(new SearchContextIdForNode[0])); + return new ParsedScrollId("", "test", searchContextIdForNodes.toArray(new SearchContextIdForNode[0]), new String[0]); } private ActionListener dummyListener() {