diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/query/RescoreKnnVectorQueryIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/query/RescoreKnnVectorQueryIT.java new file mode 100644 index 0000000000000..c8812cfc109f2 --- /dev/null +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/query/RescoreKnnVectorQueryIT.java @@ -0,0 +1,238 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.query; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.elasticsearch.action.index.IndexRequestBuilder; +import org.elasticsearch.action.search.SearchRequestBuilder; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.VectorIndexType; +import org.elasticsearch.index.mapper.vectors.DenseVectorScriptDocValues; +import org.elasticsearch.index.query.MatchAllQueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.index.query.functionscore.ScriptScoreQueryBuilder; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.script.MockScriptPlugin; +import org.elasticsearch.script.Script; +import org.elasticsearch.script.ScriptType; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.retriever.KnnRetrieverBuilder; +import org.elasticsearch.search.vectors.KnnSearchBuilder; +import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; +import org.elasticsearch.search.vectors.RescoreVectorBuilder; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.junit.Before; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailuresAndResponse; +import static org.hamcrest.Matchers.equalTo; + +public class RescoreKnnVectorQueryIT extends ESIntegTestCase { + + public static final String INDEX_NAME = "test"; + public static final String VECTOR_FIELD = "vector"; + public static final String VECTOR_SCORE_SCRIPT = "vector_scoring"; + public static final String QUERY_VECTOR_PARAM = "query_vector"; + + @Override + protected Collection> nodePlugins() { + return Collections.singleton(CustomScriptPlugin.class); + } + + public static class CustomScriptPlugin extends MockScriptPlugin { + private static final VectorSimilarityFunction SIMILARITY_FUNCTION = DenseVectorFieldMapper.VectorSimilarity.L2_NORM + .vectorSimilarityFunction(IndexVersion.current(), DenseVectorFieldMapper.ElementType.FLOAT); + + @Override + protected Map, Object>> pluginScripts() { + return Map.of(VECTOR_SCORE_SCRIPT, vars -> { + Map doc = (Map) vars.get("doc"); + return SIMILARITY_FUNCTION.compare( + ((DenseVectorScriptDocValues) doc.get(VECTOR_FIELD)).getVectorValue(), + (float[]) vars.get(QUERY_VECTOR_PARAM) + ); + }); + } + } + + @Before + public void setup() throws IOException { + String type = randomFrom( + Arrays.stream(VectorIndexType.values()) + .filter(VectorIndexType::isQuantized) + .map(t -> t.name().toLowerCase(Locale.ROOT)) + .collect(Collectors.toCollection(ArrayList::new)) + ); + XContentBuilder mapping = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(VECTOR_FIELD) + .field("type", "dense_vector") + .field("similarity", "l2_norm") + .startObject("index_options") + .field("type", type) + .endObject() + .endObject() + .endObject() + .endObject(); + + Settings settings = Settings.builder() + .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 5)) + .build(); + prepareCreate(INDEX_NAME).setMapping(mapping).setSettings(settings).get(); + ensureGreen(INDEX_NAME); + } + + private record TestParams( + int numDocs, + int numDims, + float[] queryVector, + int k, + int numCands, + RescoreVectorBuilder rescoreVectorBuilder + ) { + public static TestParams generate() { + int numDims = randomIntBetween(32, 512) * 2; // Ensure even dimensions + int numDocs = randomIntBetween(10, 100); + int k = randomIntBetween(1, numDocs - 5); + return new TestParams( + numDocs, + numDims, + randomVector(numDims), + k, + (int) (k * randomFloatBetween(1.0f, 10.0f, true)), + new RescoreVectorBuilder(randomFloatBetween(1.0f, 100f, true)) + ); + } + } + + public void testKnnSearchRescore() { + BiFunction knnSearchGenerator = (testParams, requestBuilder) -> { + KnnSearchBuilder knnSearch = new KnnSearchBuilder( + VECTOR_FIELD, + testParams.queryVector, + testParams.k, + testParams.numCands, + testParams.rescoreVectorBuilder, + null + ); + return requestBuilder.setKnnSearch(List.of(knnSearch)); + }; + testKnnRescore(knnSearchGenerator); + } + + public void testKnnQueryRescore() { + BiFunction knnQueryGenerator = (testParams, requestBuilder) -> { + KnnVectorQueryBuilder knnQuery = new KnnVectorQueryBuilder( + VECTOR_FIELD, + testParams.queryVector, + testParams.k, + testParams.numCands, + testParams.rescoreVectorBuilder, + null + ); + return requestBuilder.setQuery(knnQuery); + }; + testKnnRescore(knnQueryGenerator); + } + + public void testKnnRetriever() { + BiFunction knnQueryGenerator = (testParams, requestBuilder) -> { + KnnRetrieverBuilder knnRetriever = new KnnRetrieverBuilder( + VECTOR_FIELD, + testParams.queryVector, + null, + testParams.k, + testParams.numCands, + testParams.rescoreVectorBuilder, + null + ); + return requestBuilder.setSource(new SearchSourceBuilder().retriever(knnRetriever)); + }; + testKnnRescore(knnQueryGenerator); + } + + private void testKnnRescore(BiFunction searchRequestGenerator) { + TestParams testParams = TestParams.generate(); + + int numDocs = testParams.numDocs; + IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs]; + + for (int i = 0; i < numDocs; i++) { + docs[i] = prepareIndex(INDEX_NAME).setId("" + i).setSource(VECTOR_FIELD, randomVector(testParams.numDims)); + } + indexRandom(true, docs); + + float[] queryVector = testParams.queryVector; + float oversample = randomFloatBetween(1.0f, 100f, true); + RescoreVectorBuilder rescoreVectorBuilder = new RescoreVectorBuilder(oversample); + + SearchRequestBuilder requestBuilder = searchRequestGenerator.apply( + testParams, + prepareSearch(INDEX_NAME).setSize(numDocs).setTrackTotalHits(randomBoolean()) + ); + + assertNoFailuresAndResponse(requestBuilder, knnResponse -> { compareWithExactSearch(knnResponse, queryVector, numDocs); }); + } + + private static void compareWithExactSearch(SearchResponse knnResponse, float[] queryVector, int docCount) { + // Do an exact query and compare + Script script = new Script( + ScriptType.INLINE, + CustomScriptPlugin.NAME, + VECTOR_SCORE_SCRIPT, + Map.of(QUERY_VECTOR_PARAM, queryVector) + ); + ScriptScoreQueryBuilder scriptScoreQueryBuilder = QueryBuilders.scriptScoreQuery(new MatchAllQueryBuilder(), script); + assertNoFailuresAndResponse(prepareSearch(INDEX_NAME).setQuery(scriptScoreQueryBuilder).setSize(docCount), exactResponse -> { + assertHitCount(exactResponse, docCount); + + int i = 0; + SearchHit[] exactHits = exactResponse.getHits().getHits(); + for (SearchHit knnHit : knnResponse.getHits().getHits()) { + while (i < exactHits.length && exactHits[i].getId().equals(knnHit.getId()) == false) { + i++; + } + if (i >= exactHits.length) { + fail("Knn doc not found in exact search"); + } + assertThat("Real score is not the same as rescored score", knnHit.getScore(), equalTo(exactHits[i].getScore())); + } + }); + } + + private static float[] randomVector(int numDimensions) { + float[] vector = new float[numDimensions]; + for (int j = 0; j < numDimensions; j++) { + vector[j] = randomFloatBetween(0, 1, true); + } + return vector; + } +} diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index ce41c2164e205..193b2f8d90433 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -1225,7 +1225,7 @@ public final int hashCode() { } } - private enum VectorIndexType { + public enum VectorIndexType { HNSW("hnsw", false) { @Override public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java index 3d13f3cd82b9c..35906940a6418 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java @@ -17,6 +17,7 @@ import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.ScorerSupplier; @@ -24,6 +25,7 @@ import java.io.IOException; import java.util.Arrays; +import java.util.Comparator; import java.util.Objects; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; @@ -31,9 +33,8 @@ /** * A query that matches the provided docs with their scores. * - * Note: this query was adapted from Lucene's DocAndScoreQuery from the class + * Note: this query was originally adapted from Lucene's DocAndScoreQuery from the class * {@link org.apache.lucene.search.KnnFloatVectorQuery}, which is package-private. - * There are no changes to the behavior, just some renames. */ public class KnnScoreDocQuery extends Query { private final int[] docs; @@ -50,13 +51,18 @@ public class KnnScoreDocQuery extends Query { /** * Creates a query. * - * @param docs the global doc IDs of documents that match, in ascending order - * @param scores the scores of the matching documents + * @param scoreDocs an array of ScoreDocs to use for the query * @param reader IndexReader */ - KnnScoreDocQuery(int[] docs, float[] scores, IndexReader reader) { - this.docs = docs; - this.scores = scores; + KnnScoreDocQuery(ScoreDoc[] scoreDocs, IndexReader reader) { + // Ensure that the docs are sorted by docId, as they are later searched using binary search + Arrays.sort(scoreDocs, Comparator.comparingInt(scoreDoc -> scoreDoc.doc)); + this.docs = new int[scoreDocs.length]; + this.scores = new float[scoreDocs.length]; + for (int i = 0; i < scoreDocs.length; i++) { + docs[i] = scoreDocs[i].doc; + scores[i] = scoreDocs[i].score; + } this.segmentStarts = findSegmentStarts(reader, docs); this.contextIdentity = reader.getContext().id(); } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java index 6fa83ccfb6ac2..1a81f4b984e93 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java @@ -141,15 +141,7 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep @Override protected Query doToQuery(SearchExecutionContext context) throws IOException { - int numDocs = scoreDocs.length; - int[] docs = new int[numDocs]; - float[] scores = new float[numDocs]; - for (int i = 0; i < numDocs; i++) { - docs[i] = scoreDocs[i].doc; - scores[i] = scoreDocs[i].score; - } - - return new KnnScoreDocQuery(docs, scores, context.getIndexReader()); + return new KnnScoreDocQuery(scoreDocs, context.getIndexReader()); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java index 31d9767e9a857..99568a507ffb9 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java @@ -16,14 +16,12 @@ import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; -import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource; import org.elasticsearch.search.profile.query.QueryProfiler; import java.io.IOException; import java.util.Arrays; -import java.util.Comparator; import java.util.Objects; /** @@ -60,16 +58,7 @@ public Query rewrite(IndexSearcher searcher) throws IOException { // Retrieve top k documents from the rescored query TopDocs topDocs = searcher.search(query, k); vectorOperations = topDocs.totalHits.value(); - ScoreDoc[] scoreDocs = topDocs.scoreDocs; - Arrays.sort(scoreDocs, Comparator.comparingInt(scoreDoc -> scoreDoc.doc)); - int[] docIds = new int[scoreDocs.length]; - float[] scores = new float[scoreDocs.length]; - for (int i = 0; i < scoreDocs.length; i++) { - docIds[i] = scoreDocs[i].doc; - scores[i] = scoreDocs[i].score; - } - - return new KnnScoreDocQuery(docIds, scores, searcher.getIndexReader()); + return new KnnScoreDocQuery(topDocs.scoreDocs, searcher.getIndexReader()); } public Query innerQuery() { diff --git a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java index 861a8b11db567..05b7bc9ef4f82 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java @@ -9,36 +9,39 @@ package org.elasticsearch.search.vectors; +import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.document.Document; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.DirectoryReader; -import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; -import org.apache.lucene.index.KnnVectorValues; -import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.queries.function.FunctionScoreQuery; +import org.apache.lucene.search.DoubleValuesSource; import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.Weight; import org.apache.lucene.store.Directory; +import org.elasticsearch.index.codec.Elasticsearch900Lucene101Codec; +import org.elasticsearch.index.codec.vectors.ES813Int8FlatVectorFormat; +import org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat; +import org.elasticsearch.index.codec.zstd.Zstd814StoredFieldsFormat; +import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource; import org.elasticsearch.search.profile.query.QueryProfiler; import org.elasticsearch.test.ESTestCase; import java.io.IOException; import java.io.UnsupportedEncodingException; -import java.util.Arrays; -import java.util.Collection; -import java.util.HashSet; -import java.util.Map; -import java.util.PriorityQueue; -import java.util.stream.Collectors; - -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; @@ -59,51 +62,45 @@ public void testRescoreDocs() throws Exception { // Use a RescoreKnnVectorQuery with a match all query, to ensure we get scoring of 1 from the inner query // and thus we're rescoring the top k docs. float[] queryVector = randomVector(numDims); + Query innerQuery; + if (randomBoolean()) { + innerQuery = new KnnFloatVectorQuery(FIELD_NAME, queryVector, (int) (k * randomFloatBetween(1.0f, 10.0f, true))); + } else { + innerQuery = new MatchAllDocsQuery(); + } RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery( FIELD_NAME, queryVector, VectorSimilarityFunction.COSINE, k, - new MatchAllDocsQuery() + innerQuery ); IndexSearcher searcher = newSearcher(reader, true, false); - TopDocs docs = searcher.search(rescoreKnnVectorQuery, numDocs); - Map rescoredDocs = Arrays.stream(docs.scoreDocs) - .collect(Collectors.toMap(scoreDoc -> scoreDoc.doc, scoreDoc -> scoreDoc.score)); - - assertThat(rescoredDocs.size(), equalTo(k)); - - Collection rescoredScores = new HashSet<>(rescoredDocs.values()); - - // Collect all docs sequentially, and score them using the similarity function to get the top K scores - PriorityQueue topK = new PriorityQueue<>((o1, o2) -> Float.compare(o2, o1)); - - for (LeafReaderContext leafReaderContext : reader.leaves()) { - FloatVectorValues vectorValues = leafReaderContext.reader().getFloatVectorValues(FIELD_NAME); - KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); - while (iterator.nextDoc() != NO_MORE_DOCS) { - float[] vectorData = vectorValues.vectorValue(iterator.docID()); - float score = VectorSimilarityFunction.COSINE.compare(queryVector, vectorData); - topK.add(score); - int docId = iterator.docID(); - // If the doc has been retrieved from the RescoreKnnVectorQuery, check the score is the same and remove it - // to ensure we found them all - if (rescoredDocs.containsKey(docId)) { - assertThat(rescoredDocs.get(docId), equalTo(score)); - rescoredDocs.remove(docId); - } - } - } - - assertThat(rescoredDocs.size(), equalTo(0)); + TopDocs rescoredDocs = searcher.search(rescoreKnnVectorQuery, numDocs); + assertThat(rescoredDocs.scoreDocs.length, equalTo(k)); - // Check top scoring docs are contained in rescored docs - for (int i = 0; i < k; i++) { - Float topScore = topK.poll(); - if (rescoredScores.contains(topScore) == false) { - fail("Top score " + topScore + " not contained in rescored doc scores " + rescoredScores); + // Get real scores + DoubleValuesSource valueSource = new VectorSimilarityFloatValueSource( + FIELD_NAME, + queryVector, + VectorSimilarityFunction.COSINE + ); + FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(new MatchAllDocsQuery(), valueSource); + TopDocs realScoreTopDocs = searcher.search(functionScoreQuery, numDocs); + + int i = 0; + ScoreDoc[] realScoreDocs = realScoreTopDocs.scoreDocs; + for (ScoreDoc rescoreDoc : rescoredDocs.scoreDocs) { + // There are docs that won't be found in the rescored search, but every doc found must be in the same order + // and have the same score + while (i < realScoreDocs.length && realScoreDocs[i].doc != rescoreDoc.doc) { + i++; + } + if (i >= realScoreDocs.length) { + fail("Rescored doc not found in real score docs"); } + assertThat("Real score is not the same as rescored score", rescoreDoc.score, equalTo(realScoreDocs[i].score)); } } } @@ -205,16 +202,33 @@ public void profile(QueryProfiler queryProfiler) { } private static void addRandomDocuments(int numDocs, Directory d, int numDims) throws IOException { + IndexWriterConfig iwc = new IndexWriterConfig(); + // Pick codec from quantized vector formats to ensure scores use real scores when using knn rescore + KnnVectorsFormat format = randomFrom( + new ES818BinaryQuantizedVectorsFormat(), + new ES818HnswBinaryQuantizedVectorsFormat(), + new ES813Int8FlatVectorFormat(), + new ES813Int8FlatVectorFormat(), + new ES814HnswScalarQuantizedVectorsFormat() + ); + iwc.setCodec(new Elasticsearch900Lucene101Codec(randomFrom(Zstd814StoredFieldsFormat.Mode.values())) { + @Override + public KnnVectorsFormat getKnnVectorsFormatForField(String field) { + return format; + } + }); try (IndexWriter w = new IndexWriter(d, newIndexWriterConfig())) { for (int i = 0; i < numDocs; i++) { Document document = new Document(); float[] vector = randomVector(numDims); - KnnFloatVectorField vectorField = new KnnFloatVectorField(FIELD_NAME, vector); + KnnFloatVectorField vectorField = new KnnFloatVectorField(FIELD_NAME, vector, VectorSimilarityFunction.COSINE); document.add(vectorField); w.addDocument(document); + if (randomBoolean() && (i % 10 == 0)) { + w.commit(); + } } w.commit(); - w.forceMerge(1); } } }