diff --git a/build.gradle b/build.gradle index 04f84a525e..3247ec495e 100644 --- a/build.gradle +++ b/build.gradle @@ -315,7 +315,7 @@ opensearchplugin { // zip file name and plugin name in ${opensearch.plugin.name} read by OpenSearch when plugin loading description 'OpenSearch k-NN plugin' classname 'org.opensearch.knn.plugin.KNNPlugin' - extendedPlugins = ['lang-painless'] + extendedPlugins = ['lang-painless', 'transport-grpc'] licenseFile = rootProject.file('LICENSE.txt') noticeFile = rootProject.file('NOTICE.txt') } @@ -340,10 +340,20 @@ dependencies { api "org.opensearch:opensearch:${opensearch_version}" api project(":remote-index-build-client") compileOnly "org.opensearch.plugin:opensearch-scripting-painless-spi:${versions.opensearch}" - api group: 'com.google.guava', name: 'failureaccess', version:'1.0.1' - api group: 'com.google.guava', name: 'guava', version:'32.1.3-jre' + // TODO migrate this to a dependency to 'transport-grpc-spi' once it is ready, to avoid excluding all these unnecessary dependencies + compileOnly("org.opensearch.plugin:transport-grpc:${opensearch_version}") { + exclude group: 'com.google.guava', module: 'guava' + exclude group: 'com.google.guava', module: 'failureaccess' + exclude group: 'com.google.errorprone', module: 'error_prone_annotations' + } + compileOnly "org.opensearch:protobufs:0.6.0" + compileOnly group: 'com.google.guava', name: 'failureaccess', version:'1.0.2' + compileOnly group: 'com.google.guava', name: 'guava', version:'33.2.1-jre' api group: 'commons-lang', name: 'commons-lang', version: '2.6' testFixturesImplementation "org.opensearch.test:framework:${opensearch_version}" + // Add Guava for test configurations since it's needed for testing but excluded from runtime + testImplementation group: 'com.google.guava', name: 'guava', version:'33.2.1-jre' + testFixturesImplementation group: 'com.google.guava', name: 'guava', version:'33.2.1-jre' testImplementation group: 'net.bytebuddy', name: 'byte-buddy', version: "${versions.bytebuddy}" testImplementation group: 'org.objenesis', name: 'objenesis', version: '3.3' testImplementation group: 'net.bytebuddy', name: 'byte-buddy-agent', version: "${versions.bytebuddy}" diff --git a/qa/build.gradle b/qa/build.gradle index 34cd484d7e..d20505e61b 100644 --- a/qa/build.gradle +++ b/qa/build.gradle @@ -25,6 +25,8 @@ dependencies { api "org.apache.logging.log4j:log4j-api:${versions.log4j}" api "org.apache.logging.log4j:log4j-core:${versions.log4j}" + // Guava needed for BWC test compilation (ImmutableList, ImmutableMap usage) + testImplementation group: 'com.google.guava', name: 'guava', version:'33.2.1-jre' testImplementation "org.opensearch.test:framework:${opensearch_version}" testImplementation(testFixtures(rootProject)) } diff --git a/qa/restart-upgrade/build.gradle b/qa/restart-upgrade/build.gradle index 7775c89f15..4188232485 100644 --- a/qa/restart-upgrade/build.gradle +++ b/qa/restart-upgrade/build.gradle @@ -8,6 +8,11 @@ import org.apache.tools.ant.taskdefs.condition.Os apply from : "$rootDir/qa/build.gradle" +dependencies { + // Ensure Guava is available for BWC test compilation + testImplementation group: 'com.google.guava', name: 'guava', version:'33.2.1-jre' +} + String default_bwc_version = System.getProperty("bwc.version") String knn_bwc_version = System.getProperty("tests.bwc.version", default_bwc_version) boolean isSnapshot = knn_bwc_version.contains("-SNAPSHOT") @@ -137,7 +142,7 @@ testClusters { excludeTestsMatching "org.opensearch.knn.bwc.ScriptScoringIT.testNonKNNIndex_withMethodParams_withLuceneEngine" } } - + if (knn_bwc_version.startsWith("1.") || knn_bwc_version.startsWith("2.0.") || knn_bwc_version.startsWith("2.1.") || diff --git a/qa/rolling-upgrade/build.gradle b/qa/rolling-upgrade/build.gradle index 746025545b..e9ae0d5ea1 100644 --- a/qa/rolling-upgrade/build.gradle +++ b/qa/rolling-upgrade/build.gradle @@ -8,6 +8,11 @@ import org.apache.tools.ant.taskdefs.condition.Os apply from : "$rootDir/qa/build.gradle" +dependencies { + // Ensure Guava is available for BWC test compilation + testImplementation group: 'com.google.guava', name: 'guava', version:'33.2.1-jre' +} + String default_bwc_version = System.getProperty("bwc.version") String knn_bwc_version = System.getProperty("tests.bwc.version", default_bwc_version) boolean isSnapshot = knn_bwc_version.contains("-SNAPSHOT") @@ -106,4 +111,3 @@ task testRollingUpgrade(type: StandaloneRestIntegTestTask) { nonInputProperties.systemProperty('tests.clustername', "${-> testClusters."${baseName}".getName()}") systemProperty 'tests.security.manager', 'false' } - diff --git a/release-notes/opensearch-k-NN.release-notes-3.2.0.0.md b/release-notes/opensearch-k-NN.release-notes-3.2.0.0.md index 2375499a87..f55105da22 100644 --- a/release-notes/opensearch-k-NN.release-notes-3.2.0.0.md +++ b/release-notes/opensearch-k-NN.release-notes-3.2.0.0.md @@ -12,6 +12,7 @@ Compatible with OpenSearch 3.2.0 * Support GPU indexing for FP16, Byte and Binary [#2819](https://github.com/opensearch-project/k-NN/pull/2819) * Add random rotation feature to binary encoder for improving recall on certain datasets [#2718](https://github.com/opensearch-project/k-NN/pull/2718) * Asymmetric Distance Computation (ADC) for binary quantized faiss indices [#2733](https://github.com/opensearch-project/k-NN/pull/2733) +* Extend transport-grpc module to support GRPC KNN queries [#2817](https://github.com/opensearch-project/k-NN/pull/2817) ### Enhancements * Add KNN timing info to core profiler [#2785](https://github.com/opensearch-project/k-NN/pull/2785) diff --git a/src/main/java/org/opensearch/knn/grpc/proto/request/search/query/KNNQueryBuilderProtoConverter.java b/src/main/java/org/opensearch/knn/grpc/proto/request/search/query/KNNQueryBuilderProtoConverter.java new file mode 100644 index 0000000000..8dc7c3e216 --- /dev/null +++ b/src/main/java/org/opensearch/knn/grpc/proto/request/search/query/KNNQueryBuilderProtoConverter.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.grpc.proto.request.search.query; + +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.transport.grpc.proto.request.search.query.QueryBuilderProtoConverter; +import org.opensearch.protobufs.QueryContainer; + +/** + * Converter for KNN queries. + * This class implements the QueryBuilderProtoConverter interface to provide KNN query support + * for the gRPC transport plugin. + */ +public class KNNQueryBuilderProtoConverter implements QueryBuilderProtoConverter { + + @Override + public QueryContainer.QueryContainerCase getHandledQueryCase() { + return QueryContainer.QueryContainerCase.KNN; + } + + @Override + public QueryBuilder fromProto(QueryContainer queryContainer) { + if (queryContainer == null || queryContainer.getQueryContainerCase() != QueryContainer.QueryContainerCase.KNN) { + throw new IllegalArgumentException("QueryContainer does not contain a KNN query"); + } + + return KNNQueryBuilderProtoUtils.fromProto(queryContainer.getKnn()); + } +} diff --git a/src/main/java/org/opensearch/knn/grpc/proto/request/search/query/KNNQueryBuilderProtoUtils.java b/src/main/java/org/opensearch/knn/grpc/proto/request/search/query/KNNQueryBuilderProtoUtils.java new file mode 100644 index 0000000000..46738122a7 --- /dev/null +++ b/src/main/java/org/opensearch/knn/grpc/proto/request/search/query/KNNQueryBuilderProtoUtils.java @@ -0,0 +1,226 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.grpc.proto.request.search.query; + +import lombok.experimental.UtilityClass; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.knn.index.query.KNNQueryBuilder; +import org.opensearch.knn.index.query.parser.KNNQueryBuilderParser; +import org.opensearch.knn.index.query.request.MethodParameter; +import org.opensearch.knn.index.query.rescore.RescoreContext; +import org.opensearch.transport.grpc.proto.request.search.query.QueryBuilderProtoConverterRegistry; +import org.opensearch.protobufs.KnnQuery; +import org.opensearch.protobufs.KnnQueryRescore; +import org.opensearch.protobufs.QueryContainer; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Utility class for converting KNN Protocol Buffers to OpenSearch objects. + * This class provides methods to transform Protocol Buffer representations of KNN queries + * into their corresponding OpenSearch KNNQueryBuilder implementations for search operations. + */ +@UtilityClass +public class KNNQueryBuilderProtoUtils { + + // Registry for query conversion + private static QueryBuilderProtoConverterRegistry REGISTRY = new QueryBuilderProtoConverterRegistry(); + + /** + * Sets the registry for testing purposes. + * + * @param registry The registry to use + */ + void setRegistry(QueryBuilderProtoConverterRegistry registry) { + REGISTRY = registry; + } + + /** + * Gets the current registry. + * + * @return The current registry + */ + QueryBuilderProtoConverterRegistry getRegistry() { + return REGISTRY; + } + + /** + * Converts a Protocol Buffer KnnQuery to an OpenSearch KNNQueryBuilder. + * This method follows the exact same pattern as {@link KNNQueryBuilderParser#fromXContent(XContentParser)} + * to ensure parsing consistency and compatibility. + * + * @param knnQueryProto The Protocol Buffer KnnQuery to convert + * @return A configured KNNQueryBuilder instance + */ + public QueryBuilder fromProto(KnnQuery knnQueryProto) { + // Create builder using the internal parser pattern like XContent parsing + KNNQueryBuilder.Builder builder = KNNQueryBuilder.builder(); + + // Set field name (equivalent to fieldName parsing in XContent) + builder.fieldName(knnQueryProto.getField()); + + // Set vector (equivalent to VECTOR_FIELD parsing) + builder.vector(convertVector(knnQueryProto.getVectorList())); + + // Set k if present (equivalent to K_FIELD parsing) + if (knnQueryProto.getK() > 0) { + builder.k(knnQueryProto.getK()); + } + + // Set maxDistance if present (equivalent to MAX_DISTANCE_FIELD parsing) + else if (knnQueryProto.hasMaxDistance()) { + builder.maxDistance(knnQueryProto.getMaxDistance()); + } + + // Set minScore if present (equivalent to MIN_SCORE_FIELD parsing) + else if (knnQueryProto.hasMinScore()) { + builder.minScore(knnQueryProto.getMinScore()); + } + + // Set method parameters (equivalent to METHOD_PARAMS_FIELD parsing) + if (knnQueryProto.hasMethodParameters()) { + Map methodParameters = convertMethodParameters(knnQueryProto.getMethodParameters()); + builder.methodParameters(methodParameters); + } + + // Set filter (equivalent to FILTER_FIELD parsing) + if (knnQueryProto.hasFilter()) { + QueryContainer filterQueryContainer = knnQueryProto.getFilter(); + builder.filter(REGISTRY.fromProto(filterQueryContainer)); + } + + // Set rescore (equivalent to RESCORE_FIELD parsing) + if (knnQueryProto.hasRescore()) { + RescoreContext rescoreContext = convertRescoreContext(knnQueryProto.getRescore()); + builder.rescoreContext(rescoreContext); + } + + // Set boost (equivalent to BOOST_FIELD parsing) + if (knnQueryProto.hasBoost()) { + builder.boost(knnQueryProto.getBoost()); + } + + // Set query name (equivalent to NAME_FIELD parsing) + if (knnQueryProto.hasUnderscoreName()) { + builder.queryName(knnQueryProto.getUnderscoreName()); + } + + // Set expandNested (equivalent to EXPAND_NESTED_FIELD parsing) + if (knnQueryProto.hasExpandNestedDocs()) { + builder.expandNested(knnQueryProto.getExpandNestedDocs()); + } + + return builder.build(); + } + + /** + * Converts a Protocol Buffer vector list to a float array. + * + * @param vectorList The Protocol Buffer vector list + * @return The converted float array + */ + private float[] convertVector(List vectorList) { + float[] vector = new float[vectorList.size()]; + for (int i = 0; i < vectorList.size(); i++) { + vector[i] = vectorList.get(i); + } + return vector; + } + + /** + * Converts Protocol Buffer method parameters following the exact same pattern as + * {@link MethodParametersParser#fromXContent(XContentParser)} to ensure consistency. + * + * @param objectMap The Protocol Buffer ObjectMap to convert + * @return The converted method parameters Map + */ + private Map convertMethodParameters(org.opensearch.protobufs.ObjectMap objectMap) { + // First convert Protocol Buffer to raw Map (equivalent to parser.map()) + Map rawMethodParameters = new HashMap<>(); + for (Map.Entry entry : objectMap.getFieldsMap().entrySet()) { + String key = entry.getKey(); + Object value = convertObjectMapValue(entry.getValue()); + if (value != null) { + rawMethodParameters.put(key, value); + } + } + + // Then process through MethodParameter.parse() exactly like XContent parsing does + Map processedMethodParameters = new HashMap<>(); + for (Map.Entry entry : rawMethodParameters.entrySet()) { + String name = entry.getKey(); + Object value = entry.getValue(); + + // Find the MethodParameter enum (same as XContent parsing) + MethodParameter parameter = MethodParameter.enumOf(name); + if (parameter == null) { + throw new IllegalArgumentException("unknown method parameter found [" + name + "]"); + } + + try { + // Parse using MethodParameter.parse() - this handles type conversion properly + Object parsedValue = parameter.parse(value); + processedMethodParameters.put(name, parsedValue); + } catch (Exception exception) { + throw new IllegalArgumentException("Error parsing method parameter [" + name + "]: " + exception.getMessage()); + } + } + + return processedMethodParameters.isEmpty() ? null : processedMethodParameters; + } + + /** + * Converts a Protocol Buffer ObjectMap.Value to a Java Object. + * + * @param value The Protocol Buffer Value to convert + * @return The converted Java Object, or null if unsupported type + */ + private Object convertObjectMapValue(org.opensearch.protobufs.ObjectMap.Value value) { + switch (value.getValueCase()) { + case INT32: + return value.getInt32(); + case INT64: + return value.getInt64(); + case FLOAT: + return value.getFloat(); + case DOUBLE: + return value.getDouble(); + case STRING: + return value.getString(); + case BOOL: + return value.getBool(); + default: + // Skip unsupported types + return null; + } + } + + /** + * Converts a Protocol Buffer KnnQueryRescore to a RescoreContext. + * + * @param rescoreProto The Protocol Buffer KnnQueryRescore to convert + * @return The converted RescoreContext + */ + private RescoreContext convertRescoreContext(KnnQueryRescore rescoreProto) { + switch (rescoreProto.getKnnQueryRescoreCase()) { + case ENABLE: + return rescoreProto.getEnable() ? RescoreContext.getDefault() : RescoreContext.EXPLICITLY_DISABLED_RESCORE_CONTEXT; + + case CONTEXT: + org.opensearch.protobufs.RescoreContext contextProto = rescoreProto.getContext(); + return contextProto.hasOversampleFactor() + ? RescoreContext.builder().oversampleFactor(contextProto.getOversampleFactor()).build() + : RescoreContext.getDefault(); + + default: + return RescoreContext.getDefault(); + } + } + +} diff --git a/src/main/resources/META-INF/services/org.opensearch.transport.grpc.proto.request.search.query.QueryBuilderProtoConverter b/src/main/resources/META-INF/services/org.opensearch.transport.grpc.proto.request.search.query.QueryBuilderProtoConverter new file mode 100644 index 0000000000..78475875cc --- /dev/null +++ b/src/main/resources/META-INF/services/org.opensearch.transport.grpc.proto.request.search.query.QueryBuilderProtoConverter @@ -0,0 +1 @@ +org.opensearch.knn.grpc.proto.request.search.query.KNNQueryBuilderProtoConverter diff --git a/src/test/java/org/opensearch/knn/grpc/proto/request/search/query/KNNQueryBuilderProtoConverterTests.java b/src/test/java/org/opensearch/knn/grpc/proto/request/search/query/KNNQueryBuilderProtoConverterTests.java new file mode 100644 index 0000000000..7105af93f2 --- /dev/null +++ b/src/test/java/org/opensearch/knn/grpc/proto/request/search/query/KNNQueryBuilderProtoConverterTests.java @@ -0,0 +1,89 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.grpc.proto.request.search.query; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.knn.index.query.KNNQueryBuilder; +import org.opensearch.protobufs.KnnQuery; +import org.opensearch.protobufs.QueryContainer; +import org.opensearch.test.OpenSearchTestCase; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class KNNQueryBuilderProtoConverterTests extends OpenSearchTestCase { + + private KNNQueryBuilderProtoConverter converter; + private QueryContainer queryContainer; + private KnnQuery knnQuery; + + @Before + public void setup() { + converter = new KNNQueryBuilderProtoConverter(); + queryContainer = mock(QueryContainer.class); + knnQuery = mock(KnnQuery.class); + } + + @Test + public void testGetHandledQueryCase() { + assertEquals(QueryContainer.QueryContainerCase.KNN, converter.getHandledQueryCase()); + } + + @Test + public void testFromProto_validQuery() { + // Setup + when(queryContainer.getQueryContainerCase()).thenReturn(QueryContainer.QueryContainerCase.KNN); + when(queryContainer.getKnn()).thenReturn(knnQuery); + + // Mock the KnnQuery to provide required fields + when(knnQuery.getField()).thenReturn("test_field"); + when(knnQuery.getVectorList()).thenReturn(java.util.Arrays.asList(1.0f, 2.0f, 3.0f)); + when(knnQuery.getK()).thenReturn(10); + + // Mock optional fields that may be checked + when(knnQuery.hasMaxDistance()).thenReturn(false); + when(knnQuery.hasMinScore()).thenReturn(false); + when(knnQuery.hasMethodParameters()).thenReturn(false); + when(knnQuery.hasFilter()).thenReturn(false); + when(knnQuery.hasRescore()).thenReturn(false); + when(knnQuery.hasBoost()).thenReturn(false); + when(knnQuery.hasUnderscoreName()).thenReturn(false); + when(knnQuery.hasExpandNestedDocs()).thenReturn(false); + + // Test + QueryBuilder result = converter.fromProto(queryContainer); + + // Verify + assertNotNull(result); + assertTrue(result instanceof KNNQueryBuilder); + } + + @Test + public void testFromProto_nullQueryContainer() { + // Test + IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> converter.fromProto(null)); + + // Verify + assertEquals("QueryContainer does not contain a KNN query", exception.getMessage()); + } + + @Test + public void testFromProto_wrongQueryContainerCase() { + // Setup + when(queryContainer.getQueryContainerCase()).thenReturn(QueryContainer.QueryContainerCase.BOOL); + + // Test + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> converter.fromProto(queryContainer) + ); + + // Verify + assertEquals("QueryContainer does not contain a KNN query", exception.getMessage()); + } +} diff --git a/src/test/java/org/opensearch/knn/grpc/proto/request/search/query/KNNQueryBuilderProtoUtilsTests.java b/src/test/java/org/opensearch/knn/grpc/proto/request/search/query/KNNQueryBuilderProtoUtilsTests.java new file mode 100644 index 0000000000..e82fac5b5a --- /dev/null +++ b/src/test/java/org/opensearch/knn/grpc/proto/request/search/query/KNNQueryBuilderProtoUtilsTests.java @@ -0,0 +1,281 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.grpc.proto.request.search.query; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.knn.index.query.KNNQueryBuilder; +import org.opensearch.knn.index.query.rescore.RescoreContext; +import org.opensearch.transport.grpc.proto.request.search.query.QueryBuilderProtoConverter; +import org.opensearch.transport.grpc.proto.request.search.query.QueryBuilderProtoConverterRegistry; +import org.opensearch.protobufs.KnnQuery; +import org.opensearch.protobufs.KnnQueryRescore; +import org.opensearch.protobufs.ObjectMap; +import org.opensearch.protobufs.QueryContainer; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Map; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +public class KNNQueryBuilderProtoUtilsTests extends OpenSearchTestCase { + + @Mock + private QueryBuilderProtoConverterRegistry mockRegistry; + + @Mock + private QueryBuilderProtoConverter mockConverter; + + @Mock + private QueryBuilder mockQueryBuilder; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + } + + @Test + public void testFromProto_basicFields() { + KnnQuery knnQuery = KnnQuery.newBuilder().setField("test_field").addVector(1.0f).addVector(2.0f).addVector(3.0f).setK(5).build(); + + QueryBuilder result = KNNQueryBuilderProtoUtils.fromProto(knnQuery); + + assertTrue(result instanceof KNNQueryBuilder); + KNNQueryBuilder knnQueryBuilder = (KNNQueryBuilder) result; + assertEquals("test_field", knnQueryBuilder.fieldName()); + assertArrayEquals(new float[] { 1.0f, 2.0f, 3.0f }, (float[]) knnQueryBuilder.vector(), 0.001f); + assertEquals(5, knnQueryBuilder.getK()); + } + + @Test + public void testFromProto_withBoost() { + KnnQuery knnQuery = KnnQuery.newBuilder() + .setField("test_field") + .addVector(1.0f) + .addVector(2.0f) + .addVector(3.0f) + .setK(5) + .setBoost(2.5f) + .build(); + + QueryBuilder result = KNNQueryBuilderProtoUtils.fromProto(knnQuery); + + assertTrue(result instanceof KNNQueryBuilder); + KNNQueryBuilder knnQueryBuilder = (KNNQueryBuilder) result; + assertEquals(2.5f, knnQueryBuilder.boost(), 0.001f); + } + + @Test + public void testFromProto_withMaxDistance() { + KnnQuery knnQuery = KnnQuery.newBuilder() + .setField("test_field") + .addVector(1.0f) + .addVector(2.0f) + .addVector(3.0f) + .setMaxDistance(0.75f) + .build(); + + QueryBuilder result = KNNQueryBuilderProtoUtils.fromProto(knnQuery); + + assertTrue(result instanceof KNNQueryBuilder); + KNNQueryBuilder knnQueryBuilder = (KNNQueryBuilder) result; + assertEquals(0.75f, knnQueryBuilder.getMaxDistance(), 0.001f); + } + + @Test + public void testFromProto_withMinScore() { + KnnQuery knnQuery = KnnQuery.newBuilder() + .setField("test_field") + .addVector(1.0f) + .addVector(2.0f) + .addVector(3.0f) + .setMinScore(0.85f) + .build(); + + QueryBuilder result = KNNQueryBuilderProtoUtils.fromProto(knnQuery); + assertTrue(result instanceof KNNQueryBuilder); + KNNQueryBuilder knnQueryBuilder = (KNNQueryBuilder) result; + assertEquals(0.85f, knnQueryBuilder.getMinScore(), 0.001f); + } + + @Test + public void testFromProto_withQueryName() { + KnnQuery knnQuery = KnnQuery.newBuilder() + .setField("test_field") + .addVector(1.0f) + .addVector(2.0f) + .addVector(3.0f) + .setK(5) + .setUnderscoreName("test_query") + .build(); + + QueryBuilder result = KNNQueryBuilderProtoUtils.fromProto(knnQuery); + assertTrue(result instanceof KNNQueryBuilder); + KNNQueryBuilder knnQueryBuilder = (KNNQueryBuilder) result; + assertEquals("test_query", knnQueryBuilder.queryName()); + } + + @Test + public void testFromProto_withExpandNested() { + KnnQuery knnQuery = KnnQuery.newBuilder() + .setField("test_field") + .addVector(1.0f) + .addVector(2.0f) + .addVector(3.0f) + .setK(5) + .setExpandNestedDocs(true) + .build(); + + QueryBuilder result = KNNQueryBuilderProtoUtils.fromProto(knnQuery); + assertTrue(result instanceof KNNQueryBuilder); + KNNQueryBuilder knnQueryBuilder = (KNNQueryBuilder) result; + assertTrue(knnQueryBuilder.getExpandNested()); + } + + @Test + public void testFromProto_withMethodParameters() { + ObjectMap.Value efSearchValue = ObjectMap.Value.newBuilder().setInt32(100).build(); + ObjectMap.Value nprobesValue = ObjectMap.Value.newBuilder().setInt32(10).build(); + ObjectMap methodParams = ObjectMap.newBuilder().putFields("ef_search", efSearchValue).putFields("nprobes", nprobesValue).build(); + + KnnQuery knnQuery = KnnQuery.newBuilder() + .setField("test_field") + .addVector(1.0f) + .addVector(2.0f) + .addVector(3.0f) + .setK(5) + .setMethodParameters(methodParams) + .build(); + + QueryBuilder result = KNNQueryBuilderProtoUtils.fromProto(knnQuery); + assertTrue(result instanceof KNNQueryBuilder); + KNNQueryBuilder knnQueryBuilder = (KNNQueryBuilder) result; + Map methodParameters = knnQueryBuilder.getMethodParameters(); + assertNotNull(methodParameters); + assertEquals(2, methodParameters.size()); + assertEquals(100, methodParameters.get("ef_search")); + assertEquals(10, methodParameters.get("nprobes")); + } + + @Test + public void testFromProto_withFilter() { + QueryContainer filterContainer = QueryContainer.newBuilder().build(); + KnnQuery knnQuery = KnnQuery.newBuilder() + .setField("test_field") + .addVector(1.0f) + .addVector(2.0f) + .addVector(3.0f) + .setK(5) + .setFilter(filterContainer) + .build(); + + QueryBuilderProtoConverterRegistry originalRegistry = KNNQueryBuilderProtoUtils.getRegistry(); + + try { + KNNQueryBuilderProtoUtils.setRegistry(mockRegistry); + when(mockRegistry.fromProto(any())).thenReturn(mockQueryBuilder); + + QueryBuilder result = KNNQueryBuilderProtoUtils.fromProto(knnQuery); + + assertTrue(result instanceof KNNQueryBuilder); + KNNQueryBuilder knnQueryBuilder = (KNNQueryBuilder) result; + assertNotNull(knnQueryBuilder.getFilter()); + assertEquals(mockQueryBuilder, knnQueryBuilder.getFilter()); + } finally { + KNNQueryBuilderProtoUtils.setRegistry(originalRegistry); + } + } + + @Test + public void testFromProto_withRescoreEnable() { + KnnQueryRescore rescore = KnnQueryRescore.newBuilder().setEnable(true).build(); + + KnnQuery knnQuery = KnnQuery.newBuilder() + .setField("test_field") + .addVector(1.0f) + .addVector(2.0f) + .addVector(3.0f) + .setK(5) + .setRescore(rescore) + .build(); + + QueryBuilder result = KNNQueryBuilderProtoUtils.fromProto(knnQuery); + assertTrue(result instanceof KNNQueryBuilder); + KNNQueryBuilder knnQueryBuilder = (KNNQueryBuilder) result; + assertNotNull(knnQueryBuilder.getRescoreContext()); + assertEquals(RescoreContext.getDefault(), knnQueryBuilder.getRescoreContext()); + } + + @Test + public void testFromProto_withRescoreDisable() { + KnnQueryRescore rescore = KnnQueryRescore.newBuilder().setEnable(false).build(); + + KnnQuery knnQuery = KnnQuery.newBuilder() + .setField("test_field") + .addVector(1.0f) + .addVector(2.0f) + .addVector(3.0f) + .setK(5) + .setRescore(rescore) + .build(); + + QueryBuilder result = KNNQueryBuilderProtoUtils.fromProto(knnQuery); + assertTrue(result instanceof KNNQueryBuilder); + KNNQueryBuilder knnQueryBuilder = (KNNQueryBuilder) result; + assertNotNull(knnQueryBuilder.getRescoreContext()); + assertEquals(RescoreContext.EXPLICITLY_DISABLED_RESCORE_CONTEXT, knnQueryBuilder.getRescoreContext()); + } + + @Test + public void testFromProto_withRescoreContext() { + org.opensearch.protobufs.RescoreContext rescoreContext = org.opensearch.protobufs.RescoreContext.newBuilder() + .setOversampleFactor(3.5f) + .build(); + + KnnQueryRescore rescore = KnnQueryRescore.newBuilder().setContext(rescoreContext).build(); + + KnnQuery knnQuery = KnnQuery.newBuilder() + .setField("test_field") + .addVector(1.0f) + .addVector(2.0f) + .addVector(3.0f) + .setK(5) + .setRescore(rescore) + .build(); + + QueryBuilder result = KNNQueryBuilderProtoUtils.fromProto(knnQuery); + assertTrue(result instanceof KNNQueryBuilder); + KNNQueryBuilder knnQueryBuilder = (KNNQueryBuilder) result; + assertNotNull(knnQueryBuilder.getRescoreContext()); + assertEquals(3.5f, knnQueryBuilder.getRescoreContext().getOversampleFactor(), 0.001f); + } + + @Test + public void testFromProto_withRescoreContextNoOversampleFactor() { + org.opensearch.protobufs.RescoreContext rescoreContext = org.opensearch.protobufs.RescoreContext.newBuilder().build(); + + KnnQueryRescore rescore = KnnQueryRescore.newBuilder().setContext(rescoreContext).build(); + + KnnQuery knnQuery = KnnQuery.newBuilder() + .setField("test_field") + .addVector(1.0f) + .addVector(2.0f) + .addVector(3.0f) + .setK(5) + .setRescore(rescore) + .build(); + + QueryBuilder result = KNNQueryBuilderProtoUtils.fromProto(knnQuery); + assertTrue(result instanceof KNNQueryBuilder); + KNNQueryBuilder knnQueryBuilder = (KNNQueryBuilder) result; + assertNotNull(knnQueryBuilder.getRescoreContext()); + assertEquals(RescoreContext.getDefault(), knnQueryBuilder.getRescoreContext()); + } +}