Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport-2.x] Added support for msearch API to pass search pipeline name #16085

Merged
merged 2 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Add successfulSearchShardIndices in searchRequestContext ([#15967](https://github.com/opensearch-project/OpenSearch/pull/15967))
- Remove identity-related feature flagged code from the RestController ([#15430](https://github.com/opensearch-project/OpenSearch/pull/15430))
- Fallback to Remote cluster-state on Term-Version check mismatch - ([#15424](https://github.com/opensearch-project/OpenSearch/pull/15424))
- Add support for msearch API to pass search pipeline name - ([#15923](https://github.com/opensearch-project/OpenSearch/pull/15923))

### Dependencies
- Bump `org.apache.logging.log4j:log4j-core` from 2.23.1 to 2.24.0 ([#15858](https://github.com/opensearch-project/OpenSearch/pull/15858))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,10 @@
) {
consumer.accept(searchRequest, parser);
}

if (searchRequest.source() != null && searchRequest.source().pipeline() != null) {
searchRequest.pipeline(searchRequest.source().pipeline());

Check warning on line 315 in server/src/main/java/org/opensearch/action/search/MultiSearchRequest.java

View check run for this annotation

Codecov / codecov/patch

server/src/main/java/org/opensearch/action/search/MultiSearchRequest.java#L315

Added line #L315 was not covered by tests
}
// move pointers
from = nextMarker + 1;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ public static void parseSearchRequest(
searchRequest.routing(request.param("routing"));
searchRequest.preference(request.param("preference"));
searchRequest.indicesOptions(IndicesOptions.fromRequest(request, searchRequest.indicesOptions()));
searchRequest.pipeline(request.param("search_pipeline"));
searchRequest.pipeline(request.param("search_pipeline", searchRequest.source().pipeline()));

checkRestTotalHits(request, searchRequest);
request.paramAsBoolean(INCLUDE_NAMED_QUERIES_SCORE_PARAM, false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,8 @@ public static HighlightBuilder highlight() {

private Map<String, Object> searchPipelineSource = null;

private String searchPipeline;

/**
* Constructs a new search source builder.
*/
Expand Down Expand Up @@ -306,6 +308,9 @@ public SearchSourceBuilder(StreamInput in) throws IOException {
derivedFields = in.readList(DerivedField::new);
}
}
if (in.getVersion().onOrAfter(Version.V_2_18_0)) {
searchPipeline = in.readOptionalString();
}
}

@Override
Expand Down Expand Up @@ -394,6 +399,9 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeList(derivedFields);
}
}
if (out.getVersion().onOrAfter(Version.V_2_18_0)) {
out.writeOptionalString(searchPipeline);
}
}

/**
Expand Down Expand Up @@ -1128,6 +1136,13 @@ public Map<String, Object> searchPipelineSource() {
return searchPipelineSource;
}

/**
* @return a search pipeline name defined within the search source (see {@link org.opensearch.search.pipeline.SearchPipelineService})
*/
public String pipeline() {
return searchPipeline;
}

/**
* Define a search pipeline to process this search request and/or its response. See {@link org.opensearch.search.pipeline.SearchPipelineService}.
*/
Expand All @@ -1136,6 +1151,14 @@ public SearchSourceBuilder searchPipelineSource(Map<String, Object> searchPipeli
return this;
}

/**
* Define a search pipeline name to process this search request and/or its response. See {@link org.opensearch.search.pipeline.SearchPipelineService}.
*/
public SearchSourceBuilder pipeline(String searchPipeline) {
this.searchPipeline = searchPipeline;
return this;
}

/**
* Rewrites this search source builder into its primitive form. e.g. by
* rewriting the QueryBuilder. If the builder did not change the identity
Expand Down Expand Up @@ -1233,6 +1256,7 @@ private SearchSourceBuilder shallowCopy(
rewrittenBuilder.pointInTimeBuilder = pointInTimeBuilder;
rewrittenBuilder.derivedFieldsObject = derivedFieldsObject;
rewrittenBuilder.derivedFields = derivedFields;
rewrittenBuilder.searchPipeline = searchPipeline;
return rewrittenBuilder;
}

Expand Down Expand Up @@ -1300,6 +1324,8 @@ public void parseXContent(XContentParser parser, boolean checkTrailingTokens) th
sort(parser.text());
} else if (PROFILE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
profile = parser.booleanValue();
} else if (SEARCH_PIPELINE.match(currentFieldName, parser.getDeprecationHandler())) {
searchPipeline = parser.text();
} else {
throw new ParsingException(
parser.getTokenLocation(),
Expand Down Expand Up @@ -1629,6 +1655,10 @@ public XContentBuilder innerToXContent(XContentBuilder builder, Params params) t

}

if (searchPipeline != null) {
builder.field(SEARCH_PIPELINE.getPreferredName(), searchPipeline);
}

return builder;
}

Expand Down Expand Up @@ -1906,7 +1936,8 @@ public int hashCode() {
trackTotalHitsUpTo,
pointInTimeBuilder,
derivedFieldsObject,
derivedFields
derivedFields,
searchPipeline
);
}

Expand Down Expand Up @@ -1951,7 +1982,8 @@ public boolean equals(Object obj) {
&& Objects.equals(trackTotalHitsUpTo, other.trackTotalHitsUpTo)
&& Objects.equals(pointInTimeBuilder, other.pointInTimeBuilder)
&& Objects.equals(derivedFieldsObject, other.derivedFieldsObject)
&& Objects.equals(derivedFields, other.derivedFields);
&& Objects.equals(derivedFields, other.derivedFields)
&& Objects.equals(searchPipeline, other.searchPipeline);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,21 +43,27 @@
import org.opensearch.geometry.LinearRing;
import org.opensearch.index.query.GeoShapeQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.search.RestSearchAction;
import org.opensearch.search.AbstractSearchTestCase;
import org.opensearch.search.Scroll;
import org.opensearch.search.builder.PointInTimeBuilder;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.rescore.QueryRescorerBuilder;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.test.VersionUtils;
import org.opensearch.test.rest.FakeRestRequest;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.function.IntConsumer;

import static java.util.Collections.emptyMap;
import static org.opensearch.action.search.SearchType.DFS_QUERY_THEN_FETCH;
import static org.opensearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode;
import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.mock;

public class SearchRequestTests extends AbstractSearchTestCase {

Expand Down Expand Up @@ -222,6 +228,19 @@ public void testCopyConstructor() throws IOException {
assertNotSame(deserializedRequest, searchRequest);
}

public void testParseSearchRequestWithUnsupportedSearchType() throws IOException {
RestRequest restRequest = new FakeRestRequest();
SearchRequest searchRequest = createSearchRequest();
IntConsumer setSize = mock(IntConsumer.class);
restRequest.params().put("search_type", "query_and_fetch");

IllegalArgumentException exception = expectThrows(
IllegalArgumentException.class,
() -> RestSearchAction.parseSearchRequest(searchRequest, restRequest, null, namedWriteableRegistry, setSize)
);
assertEquals("Unsupported search type [query_and_fetch]", exception.getMessage());
}

public void testEqualsAndHashcode() throws IOException {
checkEqualsAndHashCode(createSearchRequest(), SearchRequest::new, this::mutate);
}
Expand All @@ -248,10 +267,7 @@ private SearchRequest mutate(SearchRequest searchRequest) {
);
mutators.add(
() -> mutation.searchType(
randomValueOtherThan(
searchRequest.searchType(),
() -> randomFrom(SearchType.DFS_QUERY_THEN_FETCH, SearchType.QUERY_THEN_FETCH)
)
randomValueOtherThan(searchRequest.searchType(), () -> randomFrom(DFS_QUERY_THEN_FETCH, SearchType.QUERY_THEN_FETCH))
)
);
mutators.add(() -> mutation.source(randomValueOtherThan(searchRequest.source(), this::createSearchSourceBuilder)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,27 @@ public void testDerivedFieldsParsingAndSerializationObjectType() throws IOExcept
}
}

public void testSearchPipelineParsingAndSerialization() throws IOException {
String restContent = "{ \"query\": { \"match_all\": {} }, \"from\": 0, \"size\": 10, \"search_pipeline\": \"my_pipeline\" }";
String expectedContent = "{\"from\":0,\"size\":10,\"query\":{\"match_all\":{\"boost\":1.0}},\"search_pipeline\":\"my_pipeline\"}";

try (XContentParser parser = createParser(JsonXContent.jsonXContent, restContent)) {
SearchSourceBuilder searchSourceBuilder = SearchSourceBuilder.fromXContent(parser);
searchSourceBuilder = rewrite(searchSourceBuilder);

try (BytesStreamOutput output = new BytesStreamOutput()) {
searchSourceBuilder.writeTo(output);
try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), namedWriteableRegistry)) {
SearchSourceBuilder deserializedBuilder = new SearchSourceBuilder(in);
String actualContent = deserializedBuilder.toString();
assertEquals(expectedContent, actualContent);
assertEquals(searchSourceBuilder.hashCode(), deserializedBuilder.hashCode());
assertNotSame(searchSourceBuilder, deserializedBuilder);
}
}
}
}

public void testAggsParsing() throws IOException {
{
String restContent = "{\n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -969,6 +969,64 @@ public void testInlinePipeline() throws Exception {
}
}

public void testInlineDefinedPipeline() throws Exception {
SearchPipelineService searchPipelineService = createWithProcessors();

SearchPipelineMetadata metadata = new SearchPipelineMetadata(
Map.of(
"p1",
new PipelineConfiguration(
"p1",
new BytesArray(
"{"
+ "\"request_processors\": [{ \"scale_request_size\": { \"scale\" : 2 } }],"
+ "\"response_processors\": [{ \"fixed_score\": { \"score\" : 2 } }]"
+ "}"
),
MediaTypeRegistry.JSON
)
)

);
ClusterState clusterState = ClusterState.builder(new ClusterName("_name")).build();
ClusterState previousState = clusterState;
clusterState = ClusterState.builder(clusterState)
.metadata(Metadata.builder().putCustom(SearchPipelineMetadata.TYPE, metadata))
.build();
searchPipelineService.applyClusterState(new ClusterChangedEvent("", clusterState, previousState));

SearchSourceBuilder sourceBuilder = SearchSourceBuilder.searchSource().size(100).pipeline("p1");
SearchRequest searchRequest = new SearchRequest().source(sourceBuilder);
searchRequest.pipeline(searchRequest.source().pipeline());

// Verify pipeline
PipelinedRequest pipelinedRequest = syncTransformRequest(
searchPipelineService.resolvePipeline(searchRequest, indexNameExpressionResolver)
);
Pipeline pipeline = pipelinedRequest.getPipeline();
assertEquals("p1", pipeline.getId());
assertEquals(1, pipeline.getSearchRequestProcessors().size());
assertEquals(1, pipeline.getSearchResponseProcessors().size());

// Verify that pipeline transforms request
assertEquals(200, pipelinedRequest.source().size());

int size = 10;
SearchHit[] hits = new SearchHit[size];
for (int i = 0; i < size; i++) {
hits[i] = new SearchHit(i, "doc" + i, Collections.emptyMap(), Collections.emptyMap());
hits[i].score(i);
}
SearchHits searchHits = new SearchHits(hits, new TotalHits(size * 2, TotalHits.Relation.EQUAL_TO), size);
SearchResponseSections searchResponseSections = new SearchResponseSections(searchHits, null, null, false, false, null, 0);
SearchResponse searchResponse = new SearchResponse(searchResponseSections, null, 1, 1, 0, 10, null, null);

SearchResponse transformedResponse = syncTransformResponse(pipelinedRequest, searchResponse);
for (int i = 0; i < size; i++) {
assertEquals(2.0, transformedResponse.getHits().getHits()[i].getScore(), 0.0001);
}
}

public void testInfo() {
SearchPipelineService searchPipelineService = createWithProcessors();
SearchPipelineInfo info = searchPipelineService.info();
Expand Down
Loading