@@ -439,6 +439,50 @@ public void testDeleteRuleForNonexistentId() throws Exception {
assertTrue("Expected error message for nonexistent rule ID", exception.getMessage().contains("no such index"));
}

public void testScrollRequestsAreAlsoTagged() throws Exception {
String workloadGroupId = "wlm_auto_tag_scroll";
String ruleId = "wlm_auto_tag_scroll_rule";
String indexName = "scroll_tagged_index";

setWlmMode("enabled");

WorkloadGroup workloadGroup = createWorkloadGroup("scroll_tagging_group", workloadGroupId);
updateWorkloadGroupInClusterState(PUT, workloadGroup);

FeatureType featureType = AutoTaggingRegistry.getFeatureType(WorkloadGroupFeatureType.NAME);
createRule(ruleId, "scroll tagging rule", indexName, featureType, workloadGroupId);

indexDocument(indexName);

assertBusy(() -> {
int completionsBefore = getCompletions(workloadGroupId);

SearchResponse initial = client().prepareSearch(indexName)
.setQuery(QueryBuilders.matchAllQuery())
.setScroll(TimeValue.timeValueMinutes(1))
.setSize(1)
.get();
String scrollId = initial.getScrollId();
assertNotNull("scrollId must not be null", scrollId);

try {
int afterInitialSearch = getCompletions(workloadGroupId);
assertTrue("Expected completions to increase after initial search with scroll", afterInitialSearch > completionsBefore);

SearchResponse scrollResp = client().prepareSearchScroll(scrollId).setScroll(TimeValue.timeValueMinutes(1)).get();
String nextScrollId = scrollResp.getScrollId();
if (nextScrollId != null && !nextScrollId.isEmpty()) {
scrollId = nextScrollId;
}

int afterScroll = getCompletions(workloadGroupId);
assertTrue("Expected completions to increase after scroll request as well", afterScroll > afterInitialSearch);
} finally {
clearScroll(scrollId);
}
});
}

// Helper functions
private void createRule(String ruleId, String ruleName, String indexPattern, FeatureType featureType, String workloadGroupId)
throws Exception {
@@ -11,6 +11,7 @@
import org.opensearch.action.ActionRequest;
import org.opensearch.action.IndicesRequest;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchScrollRequest;
import org.opensearch.action.support.ActionFilter;
import org.opensearch.action.support.ActionFilterChain;
import org.opensearch.action.support.ActionRequestMetadata;
@@ -19,6 +20,7 @@
import org.opensearch.plugin.wlm.rule.attribute_extractor.IndicesExtractor;
import org.opensearch.plugin.wlm.spi.AttributeExtractorExtension;
import org.opensearch.rule.InMemoryRuleProcessingService;
import org.opensearch.rule.RuleAttribute;
import org.opensearch.rule.attribute_extractor.AttributeExtractor;
import org.opensearch.rule.autotagging.Attribute;
import org.opensearch.rule.autotagging.FeatureType;
@@ -28,6 +30,7 @@
import org.opensearch.wlm.WorkloadGroupTask;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -80,14 +83,39 @@ public <Request extends ActionRequest, Response extends ActionResponse> void app
ActionListener<Response> listener,
ActionFilterChain<Request, Response> chain
) {
final boolean isValidRequest = request instanceof SearchRequest;
final boolean isSearchRequest = request instanceof SearchRequest;
final boolean isSearchScrollRequest = request instanceof SearchScrollRequest;
final boolean isValidRequest = isSearchRequest || isSearchScrollRequest;

if (!isValidRequest || wlmClusterSettingValuesProvider.getWlmMode() == WlmMode.DISABLED) {
chain.proceed(task, action, request, listener);
return;
}
List<AttributeExtractor<String>> attributeExtractors = new ArrayList<>();
attributeExtractors.add(new IndicesExtractor((IndicesRequest) request));
if (isSearchRequest) {
attributeExtractors.add(new IndicesExtractor((IndicesRequest) request));
} else {
// Scroll: recover the original user-provided indices from ParsedScrollId
final String[] originalIndices = ((SearchScrollRequest) request).originalIndicesOrEmpty();
if (originalIndices.length > 0) {
attributeExtractors.add(new AttributeExtractor<>() {
@Override
public Attribute getAttribute() {
return RuleAttribute.INDEX_PATTERN;
}

@Override
public Iterable<String> extract() {
return Arrays.asList(originalIndices);
}

@Override
public LogicalOperator getLogicalOperator() {
return LogicalOperator.AND;
}
});
}
}

if (featureType.getAllowedAttributesRegistry().containsKey(PRINCIPAL_ATTRIBUTE_NAME)) {
Attribute attribute = featureType.getAllowedAttributesRegistry().get(PRINCIPAL_ATTRIBUTE_NAME);
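The inline anonymous extractor in the scroll branch above can also be read as a small named helper. The sketch below is not part of this change; it only restates the same logic with the AttributeExtractor, RuleAttribute, Attribute, LogicalOperator, and Arrays types this file already imports, and the class name ScrollIndicesExtractor is hypothetical:

// Hypothetical helper, equivalent to the inline anonymous class in the scroll branch.
final class ScrollIndicesExtractor implements AttributeExtractor<String> {
    private final String[] originalIndices;

    ScrollIndicesExtractor(String[] originalIndices) {
        this.originalIndices = originalIndices;
    }

    @Override
    public Attribute getAttribute() {
        return RuleAttribute.INDEX_PATTERN;
    }

    @Override
    public Iterable<String> extract() {
        return Arrays.asList(originalIndices);
    }

    @Override
    public LogicalOperator getLogicalOperator() {
        return LogicalOperator.AND;
    }
}

With such a helper the scroll branch would reduce to attributeExtractors.add(new ScrollIndicesExtractor(originalIndices)); the behavior is identical either way.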
@@ -11,13 +11,16 @@
import org.opensearch.action.ActionRequest;
import org.opensearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchScrollRequest;
import org.opensearch.action.support.ActionFilterChain;
import org.opensearch.action.support.ActionRequestMetadata;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.rule.InMemoryRuleProcessingService;
import org.opensearch.rule.RuleAttribute;
import org.opensearch.rule.attribute_extractor.AttributeExtractor;
import org.opensearch.rule.autotagging.Attribute;
import org.opensearch.rule.autotagging.FeatureType;
import org.opensearch.rule.storage.AttributeValueStoreFactory;
@@ -29,11 +32,15 @@
import org.opensearch.wlm.WorkloadGroupTask;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.anyList;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
@@ -93,6 +100,39 @@ public void testApplyForInValidRequest() {
verify(ruleProcessingService, times(0)).evaluateLabel(anyList());
}

public void testApplyForScrollRequestWithOriginalIndices() {
SearchScrollRequest request = mock(SearchScrollRequest.class);
ActionFilterChain<ActionRequest, ActionResponse> chain = mock(TestActionFilterChain.class);

@SuppressWarnings("unchecked")
ActionRequestMetadata<ActionRequest, ActionResponse> metadata = mock(ActionRequestMetadata.class);
when(request.originalIndicesOrEmpty()).thenReturn(new String[] { "logs-scroll-index" });

try (ThreadContext.StoredContext ctx = threadPool.getThreadContext().stashContext()) {
doAnswer(inv -> {
@SuppressWarnings("unchecked")
List<AttributeExtractor<String>> extractors = inv.getArgument(0);

assertNotNull(extractors);
assertEquals(1, extractors.size());

AttributeExtractor<String> ex = extractors.get(0);
assertEquals(RuleAttribute.INDEX_PATTERN, ex.getAttribute());

List<String> values = new ArrayList<>();
ex.extract().forEach(values::add);
assertEquals(List.of("logs-scroll-index"), values);

return Optional.of("ScrollQG_ID");
}).when(ruleProcessingService).evaluateLabel(any());

autoTaggingActionFilter.apply(mock(Task.class), "Test", request, metadata, null, chain);

assertEquals("ScrollQG_ID", threadPool.getThreadContext().getHeader(WorkloadGroupTask.WORKLOAD_GROUP_ID_HEADER));
verify(ruleProcessingService, times(1)).evaluateLabel(anyList());
}
}

public enum WLMFeatureType implements FeatureType {
WLM;

@@ -774,7 +774,9 @@ public void sendSearchResponse(InternalSearchResponse internalSearchResponse, At
raisePhaseFailure(new SearchPhaseExecutionException("", "Shard failures", null, failures));
} else {
final Version minNodeVersion = clusterState.nodes().getMinNodeVersion();
final String scrollId = request.scroll() != null ? TransportSearchHelper.buildScrollId(queryResults, minNodeVersion) : null;
final String scrollId = request.scroll() != null
? TransportSearchHelper.buildScrollId(queryResults, request.indices(), minNodeVersion)
: null;
final String searchContextId;
if (buildPointInTimeFromSearchResults()) {
searchContextId = SearchContextId.encode(queryResults.asList(), aliasFilter, minNodeVersion);
@@ -53,11 +53,13 @@ public class ParsedScrollId {
private final String type;

private final SearchContextIdForNode[] context;
private final String[] originalIndices;

ParsedScrollId(String source, String type, SearchContextIdForNode[] context) {
ParsedScrollId(String source, String type, SearchContextIdForNode[] context, String[] originalIndices) {
this.source = source;
this.type = type;
this.context = context;
this.originalIndices = originalIndices;
}

public String getSource() {
@@ -72,6 +74,10 @@ public SearchContextIdForNode[] getContext() {
return context;
}

public String[] getOriginalIndices() {
return originalIndices;
}

public boolean hasLocalIndices() {
return Arrays.stream(context).anyMatch(c -> c.getClusterAlias() == null);
}
@@ -36,6 +36,7 @@
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.common.annotation.PublicApi;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.common.Strings;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.tasks.TaskId;
@@ -61,6 +62,7 @@ public class SearchScrollRequest extends ActionRequest implements ToXContentObje

private String scrollId;
private Scroll scroll;
private transient ParsedScrollId parsedScrollId;

public SearchScrollRequest() {}

@@ -103,7 +105,20 @@ public SearchScrollRequest scrollId(String scrollId) {
}

public ParsedScrollId parseScrollId() {
return TransportSearchHelper.parseScrollId(scrollId);
if (parsedScrollId == null && scrollId != null) {
parsedScrollId = TransportSearchHelper.parseScrollId(scrollId);
}
return parsedScrollId;
}

public String[] originalIndicesOrEmpty() {
try {
ParsedScrollId parsed = parseScrollId();
String[] orig = parsed == null ? null : parsed.getOriginalIndices();
return orig == null || orig.length == 0 ? Strings.EMPTY_ARRAY : orig;
} catch (IllegalArgumentException e) {
return Strings.EMPTY_ARRAY;
}
}

/**
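Because originalIndicesOrEmpty() swallows IllegalArgumentException and maps null or empty results to Strings.EMPTY_ARRAY, callers can branch on array length alone. A minimal usage sketch, assuming the existing SearchScrollRequest(String scrollId) constructor; the variable scrollIdFromClient is illustrative:

// Illustrative only: no null check or try/catch is needed around originalIndicesOrEmpty().
SearchScrollRequest scrollRequest = new SearchScrollRequest(scrollIdFromClient);
String[] indices = scrollRequest.originalIndicesOrEmpty();
if (indices.length > 0) {
    // scroll id was built by a coordinator that embeds the original indices
} else {
    // pre-upgrade or unparseable scroll id: no index information is available
}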
@@ -35,6 +35,7 @@
import org.opensearch.Version;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.util.concurrent.AtomicArray;
import org.opensearch.core.common.Strings;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.io.stream.BytesStreamInput;
import org.opensearch.search.SearchPhaseResult;
@@ -56,11 +57,17 @@ final class TransportSearchHelper {

private static final String INCLUDE_CONTEXT_UUID = "include_context_uuid";

public static final Version INDICES_IN_SCROLL_ID_VERSION = Version.V_3_4_0;
Review comment (Member): This will need to be V_3_6_0

static InternalScrollSearchRequest internalScrollSearchRequest(ShardSearchContextId id, SearchScrollRequest request) {
return new InternalScrollSearchRequest(request, id);
}

static String buildScrollId(AtomicArray<? extends SearchPhaseResult> searchPhaseResults, Version version) {
return buildScrollId(searchPhaseResults, null, version);
}

static String buildScrollId(AtomicArray<? extends SearchPhaseResult> searchPhaseResults, String[] originalIndices, Version version) {
try {
BytesStreamOutput out = new BytesStreamOutput();
out.writeString(INCLUDE_CONTEXT_UUID);
@@ -78,6 +85,13 @@ static String buildScrollId(AtomicArray<? extends SearchPhaseResult> searchPhase
out.writeString(searchShardTarget.getNodeId());
}
}

if (version.onOrAfter(INDICES_IN_SCROLL_ID_VERSION)) {
// To keep autotagging consistent between the initial SearchRequest
// and subsequent SearchScrollRequests, we store exactly the original indices
// received during the "search" phase
out.writeOptionalStringArray(originalIndices);
}
byte[] bytes = BytesReference.toBytes(out.bytes());
return Base64.getUrlEncoder().encodeToString(bytes);
} catch (IOException e) {
@@ -114,10 +128,13 @@ static ParsedScrollId parseScrollId(String scrollId) {
}
context[i] = new SearchContextIdForNode(clusterAlias, target, new ShardSearchContextId(contextUUID, id));
}

final String[] originalIndices = in.getPosition() < bytes.length ? in.readOptionalStringArray() : Strings.EMPTY_ARRAY;

if (in.getPosition() != bytes.length) {
throw new IllegalArgumentException("Not all bytes were read");
}
return new ParsedScrollId(scrollId, type, context);
return new ParsedScrollId(scrollId, type, context, originalIndices);
} catch (Exception e) {
throw new IllegalArgumentException("Cannot parse scroll id", e);
}
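The position check before readOptionalStringArray() is what keeps pre-upgrade scroll ids parseable: an old id simply runs out of bytes before the optional array, so parsing falls back to Strings.EMPTY_ARRAY instead of failing. A minimal round-trip sketch of that pattern, using only stream calls that already appear in this diff (BytesStreamOutput, BytesStreamInput, BytesReference, Strings); the method name and literal values are illustrative:

// Illustrative sketch of the "optional trailing field" pattern used by buildScrollId/parseScrollId.
static void optionalTrailingFieldSketch() throws IOException {
    BytesStreamOutput out = new BytesStreamOutput();
    out.writeString("queryThenFetch");                        // payload an old scroll id already carried
    byte[] oldFormat = BytesReference.toBytes(out.bytes());    // snapshot taken before the new field exists

    out.writeOptionalStringArray(new String[] { "logs-*" });   // trailing field written on or after INDICES_IN_SCROLL_ID_VERSION
    byte[] newFormat = BytesReference.toBytes(out.bytes());

    // Old format: nothing is left after the known fields, so the reader falls back to an empty array.
    BytesStreamInput oldIn = new BytesStreamInput(oldFormat);
    oldIn.readString();
    String[] fromOld = oldIn.getPosition() < oldFormat.length ? oldIn.readOptionalStringArray() : Strings.EMPTY_ARRAY;
    assert fromOld.length == 0;

    // New format: the trailing array is present and round-trips intact.
    BytesStreamInput newIn = new BytesStreamInput(newFormat);
    newIn.readString();
    String[] fromNew = newIn.getPosition() < newFormat.length ? newIn.readOptionalStringArray() : Strings.EMPTY_ARRAY;
    assert fromNew.length == 1 && "logs-*".equals(fromNew[0]);
}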
@@ -79,7 +79,7 @@ protected void doExecute(Task task, SearchScrollRequest request, ActionListener<
((WorkloadGroupTask) task).setWorkloadGroupId(threadPool.getThreadContext());
}

ParsedScrollId scrollId = TransportSearchHelper.parseScrollId(request.scrollId());
ParsedScrollId scrollId = request.parseScrollId();
Runnable action;
switch (scrollId.getType()) {
case ParsedScrollId.QUERY_THEN_FETCH_TYPE:
@@ -50,7 +50,12 @@ public void testHasLocalIndices() {
new ShardSearchContextId(randomAlphaOfLength(8), randomLong())
);
}
final ParsedScrollId parsedScrollId = new ParsedScrollId(randomAlphaOfLength(8), randomAlphaOfLength(8), searchContextIdForNodes);
final ParsedScrollId parsedScrollId = new ParsedScrollId(
randomAlphaOfLength(8),
randomAlphaOfLength(8),
searchContextIdForNodes,
new String[0]
);

assertEquals(hasLocal, parsedScrollId.hasLocalIndices());
}
@@ -481,7 +481,7 @@ protected void onFirstPhaseResult(int shardId, SearchAsyncActionTests.TestSearch
private static ParsedScrollId getParsedScrollId(SearchContextIdForNode... idsForNodes) {
List<SearchContextIdForNode> searchContextIdForNodes = Arrays.asList(idsForNodes);
Collections.shuffle(searchContextIdForNodes, random());
return new ParsedScrollId("", "test", searchContextIdForNodes.toArray(new SearchContextIdForNode[0]));
return new ParsedScrollId("", "test", searchContextIdForNodes.toArray(new SearchContextIdForNode[0]), new String[0]);
}

private ActionListener<SearchResponse> dummyListener() {