Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Optimize embedding generation in Text Embedding Processor ([#1191](https://github.com/opensearch-project/neural-search/pull/1191))
- Optimize embedding generation in Sparse Encoding Processor ([#1246](https://github.com/opensearch-project/neural-search/pull/1246))
- Optimize embedding generation in Text/Image Embedding Processor ([#1249](https://github.com/opensearch-project/neural-search/pull/1249))
- Inner hits support with hybrid query ([#1253](https://github.com/opensearch-project/neural-search/pull/1253))

### Enhancements

Expand Down
1 change: 1 addition & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ dependencies {
testFixturesCompileOnly group: 'com.google.guava', name: 'guava', version:'32.1.3-jre'
testFixturesImplementation fileTree(dir: knnJarDirectory, include: ["opensearch-knn-${opensearch_build}.jar", "remote-index-build-client-${opensearch_build}.jar"])
testImplementation fileTree(dir: knnJarDirectory, include: ["opensearch-knn-${opensearch_build}.jar", "remote-index-build-client-${opensearch_build}.jar"])
testImplementation "org.opensearch.plugin:parent-join-client:${opensearch_version}"
}

// In order to add the jar to the classpath, we need to unzip the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.util.List;
import java.util.ListIterator;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

Expand All @@ -25,6 +26,7 @@
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.IndexSettings;
import org.opensearch.index.query.AbstractQueryBuilder;
import org.opensearch.index.query.InnerHitContextBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
Expand Down Expand Up @@ -392,4 +394,19 @@ public void visit(QueryBuilderVisitor visitor) {
subQueryBuilder.visit(subVisitor);
}
}

/**
 * Collects the inner hit definitions declared by the sub-queries of this hybrid query.
 * Every sub-query is handed to {@link InnerHitContextBuilder#extractInnerHits} in turn, and all of them
 * write into the same map, so the result covers the whole hybrid query rather than a single clause.
 * Child inner hits are inlined into the inner hit builder they belong to while extracting.
 *
 * @param innerHits shared accumulator for the extracted inner hit contexts; the key is the inner hit name
 *                  and the value is the matching inner hit context builder
 */
@Override
protected void extractInnerHitBuilders(Map<String, InnerHitContextBuilder> innerHits) {
    // One shared map for all sub-queries: each delegate call adds its own entries to it.
    queries.forEach(subQuery -> InnerHitContextBuilder.extractInnerHits(subQuery, innerHits));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,7 @@ private void prepareResourcesForNestegAggregationsScenario(String index) throws
index,
buildIndexConfiguration(
List.of(new KNNFieldConfig("location", 2, TEST_SPACE_TYPE)),
Map.of(),
List.of(),
List.of(),
List.of(FLOAT_FIELD_NAME_IMDB),
Expand Down Expand Up @@ -701,105 +702,123 @@ private void initializeIndexIfNotExist(String indexName) throws IOException {
&& !indexExists(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS)) {
createIndexWithConfiguration(
indexName,
buildIndexConfiguration(List.of(), List.of(), List.of(INTEGER_FIELD_1), List.of(KEYWORD_FIELD_1), List.of(DATE_FIELD_1), 3),
buildIndexConfiguration(List.of(), Map.of(), List.of(INTEGER_FIELD_1), List.of(KEYWORD_FIELD_1), List.of(DATE_FIELD_1), 3),
""
);

addKnnDoc(
indexTheDocument(
indexName,
"1",
List.of(),
List.of(),
Collections.singletonList(TEST_TEXT_FIELD_NAME_1),
Collections.singletonList(TEST_DOC_TEXT1),
List.of(),
List.of(),
Map.of(),
List.of(INTEGER_FIELD_1, INTEGER_FIELD_PRICE),
List.of(INTEGER_FIELD_1_VALUE, INTEGER_FIELD_PRICE_1_VALUE),
List.of(KEYWORD_FIELD_1),
List.of(KEYWORD_FIELD_1_VALUE),
List.of(DATE_FIELD_1),
List.of(DATE_FIELD_1_VALUE)
List.of(DATE_FIELD_1_VALUE),
List.of(),
List.of(),
null
);
addKnnDoc(
indexTheDocument(
indexName,
"2",
List.of(),
List.of(),
Collections.singletonList(TEST_TEXT_FIELD_NAME_1),
Collections.singletonList(TEST_DOC_TEXT3),
List.of(),
List.of(),
Map.of(),
List.of(INTEGER_FIELD_1, INTEGER_FIELD_PRICE),
List.of(INTEGER_FIELD_2_VALUE, INTEGER_FIELD_PRICE_2_VALUE),
List.of(),
List.of(),
List.of(DATE_FIELD_1),
List.of(DATE_FIELD_2_VALUE)
List.of(DATE_FIELD_2_VALUE),
List.of(),
List.of(),
null
);
addKnnDoc(
indexTheDocument(
indexName,
"3",
List.of(),
List.of(),
Collections.singletonList(TEST_TEXT_FIELD_NAME_1),
Collections.singletonList(TEST_DOC_TEXT2),
List.of(),
List.of(),
Map.of(),
List.of(INTEGER_FIELD_PRICE),
List.of(INTEGER_FIELD_PRICE_3_VALUE),
List.of(KEYWORD_FIELD_1),
List.of(KEYWORD_FIELD_2_VALUE),
List.of(DATE_FIELD_1),
List.of(DATE_FIELD_3_VALUE)
List.of(DATE_FIELD_3_VALUE),
List.of(),
List.of(),
null
);
addKnnDoc(
indexTheDocument(
indexName,
"4",
List.of(),
List.of(),
Collections.singletonList(TEST_TEXT_FIELD_NAME_1),
Collections.singletonList(TEST_DOC_TEXT4),
List.of(),
List.of(),
Map.of(),
List.of(INTEGER_FIELD_1, INTEGER_FIELD_PRICE),
List.of(INTEGER_FIELD_3_VALUE, INTEGER_FIELD_PRICE_4_VALUE),
List.of(KEYWORD_FIELD_1),
List.of(KEYWORD_FIELD_3_VALUE),
List.of(DATE_FIELD_1),
List.of(DATE_FIELD_2_VALUE)
List.of(DATE_FIELD_2_VALUE),
List.of(),
List.of(),
null
);
addKnnDoc(
indexTheDocument(
indexName,
"5",
List.of(),
List.of(),
Collections.singletonList(TEST_TEXT_FIELD_NAME_1),
Collections.singletonList(TEST_DOC_TEXT5),
List.of(),
List.of(),
Map.of(),
List.of(INTEGER_FIELD_1, INTEGER_FIELD_PRICE),
List.of(INTEGER_FIELD_3_VALUE, INTEGER_FIELD_PRICE_5_VALUE),
List.of(KEYWORD_FIELD_1),
List.of(KEYWORD_FIELD_4_VALUE),
List.of(DATE_FIELD_1),
List.of(DATE_FIELD_4_VALUE)
List.of(DATE_FIELD_4_VALUE),
List.of(),
List.of(),
null
);
addKnnDoc(
indexTheDocument(
indexName,
"6",
List.of(),
List.of(),
Collections.singletonList(TEST_TEXT_FIELD_NAME_1),
Collections.singletonList(TEST_DOC_TEXT6),
List.of(),
List.of(),
Map.of(),
List.of(INTEGER_FIELD_1, INTEGER_FIELD_PRICE),
List.of(INTEGER_FIELD_4_VALUE, INTEGER_FIELD_PRICE_6_VALUE),
List.of(KEYWORD_FIELD_1),
List.of(KEYWORD_FIELD_4_VALUE),
List.of(DATE_FIELD_1),
List.of(DATE_FIELD_4_VALUE)
List.of(DATE_FIELD_4_VALUE),
List.of(),
List.of(),
null
);
}
}
Expand All @@ -809,42 +828,48 @@ private void initializeIndexWithOneShardIfNotExists(String indexName) {
if (!indexExists(indexName)) {
createIndexWithConfiguration(
indexName,
buildIndexConfiguration(List.of(), List.of(), List.of(INTEGER_FIELD_1), List.of(KEYWORD_FIELD_1), List.of(), 1),
buildIndexConfiguration(List.of(), Map.of(), List.of(INTEGER_FIELD_1), List.of(KEYWORD_FIELD_1), List.of(), 1),
""
);

addKnnDoc(
indexTheDocument(
indexName,
"1",
List.of(),
List.of(),
Collections.singletonList(TEST_TEXT_FIELD_NAME_1),
Collections.singletonList(TEST_DOC_TEXT1),
List.of(),
List.of(),
Map.of(),
List.of(INTEGER_FIELD_1),
List.of(INTEGER_FIELD_1_VALUE),
List.of(),
List.of(),
List.of(),
List.of()
List.of(),
List.of(),
List.of(),
null
);

addKnnDoc(
indexTheDocument(
indexName,
"2",
List.of(),
List.of(),
Collections.singletonList(TEST_TEXT_FIELD_NAME_1),
Collections.singletonList(TEST_DOC_TEXT3),
List.of(),
List.of(),
Map.of(),
List.of(INTEGER_FIELD_1),
List.of(INTEGER_FIELD_2_VALUE),
List.of(),
List.of(),
List.of(),
List.of()
List.of(),
List.of(),
List.of(),
null
);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.MODEL_ID_FIELD;
import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.QUERY_TEXT_FIELD;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
Expand All @@ -31,6 +32,7 @@
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.join.ScoreMode;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.Version;
Expand All @@ -56,8 +58,12 @@
import org.opensearch.index.IndexSettings;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.mapper.TextFieldMapper;
import org.opensearch.index.query.InnerHitBuilder;
import org.opensearch.index.query.InnerHitContextBuilder;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.MatchQueryBuilder;
import org.opensearch.index.query.NestedQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.QueryShardContext;
Expand Down Expand Up @@ -1096,6 +1102,48 @@ public void testFilter() {
assertEquals(new MatchAllQueryBuilder(), updatedNeuralSparseQueryBuilder.filter().get(0));
}

/**
 * Checks that extracting inner hits from a hybrid query only yields entries for sub-queries
 * that actually declare an inner hit: the nested query on "path1" has one, the nested query
 * on "path2" does not, so exactly one entry keyed by "path1" is expected.
 */
public void testExtractInnerHitsBuilders() {
    // Sub-query that declares an inner hit.
    MatchQueryBuilder matchOnFirstField = new MatchQueryBuilder("testFieldName1", "testValue1");
    NestedQueryBuilder nestedWithInnerHit = new NestedQueryBuilder("path1", matchOnFirstField, ScoreMode.Max);
    nestedWithInnerHit.innerHit(new InnerHitBuilder());

    // Sub-query without any inner hit; it must not contribute to the extracted map.
    MatchQueryBuilder matchOnSecondField = new MatchQueryBuilder("testFieldName2", "testValue2");
    NestedQueryBuilder nestedWithoutInnerHit = new NestedQueryBuilder("path2", matchOnSecondField, ScoreMode.Max);

    HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder().add(nestedWithInnerHit).add(nestedWithoutInnerHit);

    Map<String, InnerHitContextBuilder> extractedInnerHits = new HashMap<>();
    hybridQueryBuilder.extractInnerHitBuilders(extractedInnerHits);

    String firstKey = extractedInnerHits.keySet().iterator().next();
    assertEquals("path1", firstKey);
    assertEquals(1, extractedInnerHits.size());
}

/**
 * Checks that extraction fails when two sub-queries of a hybrid query register an inner hit
 * under the same key: both nested queries target "path1" and share one inner hit builder,
 * so the second registration is expected to be rejected with an IllegalArgumentException.
 */
public void testExtractInnerHitsBuilders_whenMultipleInnerHitsOnSamePath_thenFail() {
    // The same builder instance is attached to both nested queries on purpose.
    InnerHitBuilder sharedInnerHit = new InnerHitBuilder();

    MatchQueryBuilder firstMatch = new MatchQueryBuilder("testFieldName1", "testValue1");
    NestedQueryBuilder firstNested = new NestedQueryBuilder("path1", firstMatch, ScoreMode.Max);
    firstNested.innerHit(sharedInnerHit);

    MatchQueryBuilder secondMatch = new MatchQueryBuilder("testFieldName1", "testValue2");
    NestedQueryBuilder secondNested = new NestedQueryBuilder("path1", secondMatch, ScoreMode.Max);
    secondNested.innerHit(sharedInnerHit);

    HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder().add(firstNested).add(secondNested);
    Map<String, InnerHitContextBuilder> extractedInnerHits = new HashMap<>();

    IllegalArgumentException exception = expectThrows(
        IllegalArgumentException.class,
        () -> hybridQueryBuilder.extractInnerHitBuilders(extractedInnerHits)
    );
    assertEquals("[inner_hits] already contains an entry for key [path1]", exception.getMessage());
}

private Map<String, Object> getInnerMap(Object innerObject, String queryName, String fieldName) {
if (!(innerObject instanceof Map)) {
fail("field name does not map to nested object");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -858,7 +858,7 @@ private void initializeIndexIfNotExist(String indexName) {
indexName,
buildIndexConfiguration(
Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)),
List.of(TEST_NESTED_TYPE_FIELD_NAME_1),
Map.of(TEST_NESTED_TYPE_FIELD_NAME_1, Map.of()),
1
),
""
Expand All @@ -871,7 +871,7 @@ private void initializeIndexIfNotExist(String indexName) {
indexName,
buildIndexConfiguration(
Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)),
List.of(),
Map.of(),
1
),
""
Expand Down
Loading
Loading