Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.search.query;

import org.apache.lucene.index.VectorSimilarityFunction;
import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.IndexVersion;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.VectorIndexType;
import org.elasticsearch.index.mapper.vectors.DenseVectorScriptDocValues;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.functionscore.ScriptScoreQueryBuilder;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.script.MockScriptPlugin;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.retriever.KnnRetrieverBuilder;
import org.elasticsearch.search.vectors.KnnSearchBuilder;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
import org.elasticsearch.search.vectors.RescoreVectorBuilder;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.junit.Before;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailuresAndResponse;
import static org.hamcrest.Matchers.equalTo;

public class RescoreKnnVectorQueryIT extends ESIntegTestCase {

// Name of the index created in setup() and targeted by every search in this test.
public static final String INDEX_NAME = "test";
// Name of the dense_vector field defined in the index mapping.
public static final String VECTOR_FIELD = "vector";
// Key under which CustomScriptPlugin registers its mock scoring script.
public static final String VECTOR_SCORE_SCRIPT = "vector_scoring";
// Script parameter name that carries the query vector into the mock scoring script.
public static final String QUERY_VECTOR_PARAM = "query_vector";

/**
 * Returns the plugin classes for this test: only {@link CustomScriptPlugin},
 * which supplies the mock scoring script used for the exact-search comparison.
 */
@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
    Collection<Class<? extends Plugin>> plugins = Collections.singleton(CustomScriptPlugin.class);
    return plugins;
}

/**
 * Mock script plugin exposing a single script, {@link #VECTOR_SCORE_SCRIPT}, that scores a document
 * by comparing its stored vector against the query vector passed as a script parameter.
 */
public static class CustomScriptPlugin extends MockScriptPlugin {
    // L2-norm similarity for float vectors, resolved for the current index version.
    private static final VectorSimilarityFunction SIMILARITY_FUNCTION = DenseVectorFieldMapper.VectorSimilarity.L2_NORM
        .vectorSimilarityFunction(IndexVersion.current(), DenseVectorFieldMapper.ElementType.FLOAT);

    @Override
    protected Map<String, Function<Map<String, Object>, Object>> pluginScripts() {
        Function<Map<String, Object>, Object> vectorScorer = scriptVars -> {
            Map<?, ?> docValues = (Map<?, ?>) scriptVars.get("doc");
            // Read the document's vector first, then the query vector, and compare the two.
            float[] storedVector = ((DenseVectorScriptDocValues) docValues.get(VECTOR_FIELD)).getVectorValue();
            float[] targetVector = (float[]) scriptVars.get(QUERY_VECTOR_PARAM);
            return SIMILARITY_FUNCTION.compare(storedVector, targetVector);
        };
        return Map.of(VECTOR_SCORE_SCRIPT, vectorScorer);
    }
}

/**
 * Creates the test index before each test. The vector field is a {@code dense_vector} with
 * {@code l2_norm} similarity whose {@code index_options.type} is picked at random among the
 * {@link VectorIndexType} values that report themselves as quantized. The index has no replicas
 * and between 1 and 5 shards.
 */
@Before
public void setup() throws IOException {
    // Collect the lower-cased names of all quantized index types and pick one at random.
    List<String> quantizedTypeNames = new ArrayList<>();
    for (VectorIndexType indexType : VectorIndexType.values()) {
        if (indexType.isQuantized()) {
            quantizedTypeNames.add(indexType.name().toLowerCase(Locale.ROOT));
        }
    }
    String indexOptionsType = randomFrom(quantizedTypeNames);

    XContentBuilder mapping = XContentFactory.jsonBuilder();
    mapping.startObject();
    {
        mapping.startObject("properties");
        {
            mapping.startObject(VECTOR_FIELD);
            {
                mapping.field("type", "dense_vector");
                mapping.field("similarity", "l2_norm");
                mapping.startObject("index_options");
                {
                    mapping.field("type", indexOptionsType);
                }
                mapping.endObject();
            }
            mapping.endObject();
        }
        mapping.endObject();
    }
    mapping.endObject();

    Settings.Builder settingsBuilder = Settings.builder();
    settingsBuilder.put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0);
    settingsBuilder.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 5));
    Settings indexSettings = settingsBuilder.build();

    prepareCreate(INDEX_NAME).setMapping(mapping).setSettings(indexSettings).get();
    ensureGreen(INDEX_NAME);
}

/**
 * Randomized inputs shared by the indexing step and the kNN request of a single test run.
 *
 * @param numDocs              number of documents to index
 * @param numDims              dimension count of every vector (always even)
 * @param queryVector          vector used for the kNN search and for the exact comparison
 * @param k                    number of nearest neighbours requested
 * @param numCands             number of candidates requested
 * @param rescoreVectorBuilder rescoring configuration passed to the kNN request
 */
private record TestParams(
    int numDocs,
    int numDims,
    float[] queryVector,
    int k,
    int numCands,
    RescoreVectorBuilder rescoreVectorBuilder
) {
    /**
     * Draws a fresh random parameter set: 64 to 1024 even dimensions, 10 to 100 documents,
     * {@code k} between 1 and {@code numDocs - 5}, a candidate count of 1x to 10x {@code k},
     * and an oversample factor between 1 and 100.
     */
    public static TestParams generate() {
        // Doubling guarantees an even number of dimensions.
        int dimensions = 2 * randomIntBetween(32, 512);
        int documents = randomIntBetween(10, 100);
        int neighbours = randomIntBetween(1, documents - 5);
        float[] query = randomVector(dimensions);
        int candidates = (int) (neighbours * randomFloatBetween(1.0f, 10.0f, true));
        RescoreVectorBuilder rescore = new RescoreVectorBuilder(randomFloatBetween(1.0f, 100f, true));
        return new TestParams(documents, dimensions, query, neighbours, candidates, rescore);
    }
}

/**
 * Runs the rescoring check with the kNN request expressed as a top-level kNN search.
 */
public void testKnnSearchRescore() {
    testKnnRescore((params, searchRequest) -> {
        KnnSearchBuilder knn = new KnnSearchBuilder(
            VECTOR_FIELD,
            params.queryVector,
            params.k,
            params.numCands,
            params.rescoreVectorBuilder,
            null
        );
        List<KnnSearchBuilder> knnSearches = List.of(knn);
        return searchRequest.setKnnSearch(knnSearches);
    });
}

/**
 * Runs the rescoring check with the kNN request expressed as a kNN vector query.
 */
public void testKnnQueryRescore() {
    testKnnRescore((params, searchRequest) -> {
        KnnVectorQueryBuilder vectorQuery = new KnnVectorQueryBuilder(
            VECTOR_FIELD,
            params.queryVector,
            params.k,
            params.numCands,
            params.rescoreVectorBuilder,
            null
        );
        return searchRequest.setQuery(vectorQuery);
    });
}

/**
 * Runs the rescoring check with the kNN request expressed as a kNN retriever set on the search source.
 */
public void testKnnRetriever() {
    testKnnRescore((params, searchRequest) -> {
        KnnRetrieverBuilder retriever = new KnnRetrieverBuilder(
            VECTOR_FIELD,
            params.queryVector,
            null,
            params.k,
            params.numCands,
            params.rescoreVectorBuilder,
            null
        );
        SearchSourceBuilder source = new SearchSourceBuilder().retriever(retriever);
        return searchRequest.setSource(source);
    });
}

/**
 * Shared driver for the rescoring tests. Indexes {@code testParams.numDocs} random vectors, builds a
 * search request through {@code searchRequestGenerator}, and checks the scores in the response
 * against an exact script-score search via {@link #compareWithExactSearch}.
 *
 * @param searchRequestGenerator decorates a prepared search request with the kNN clause under test,
 *                               using the generated {@link TestParams} (which already carry the
 *                               rescore configuration)
 */
private void testKnnRescore(BiFunction<TestParams, SearchRequestBuilder, SearchRequestBuilder> searchRequestGenerator) {
    TestParams testParams = TestParams.generate();

    int numDocs = testParams.numDocs;
    IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs];

    for (int i = 0; i < numDocs; i++) {
        docs[i] = prepareIndex(INDEX_NAME).setId("" + i).setSource(VECTOR_FIELD, randomVector(testParams.numDims));
    }
    indexRandom(true, docs);

    float[] queryVector = testParams.queryVector;

    // The rescore settings come from testParams.rescoreVectorBuilder; no separate builder is needed here.
    // Size is numDocs, so the response is not cut short by the default page size.
    SearchRequestBuilder requestBuilder = searchRequestGenerator.apply(
        testParams,
        prepareSearch(INDEX_NAME).setSize(numDocs).setTrackTotalHits(randomBoolean())
    );

    assertNoFailuresAndResponse(requestBuilder, knnResponse -> compareWithExactSearch(knnResponse, queryVector, numDocs));
}

/**
 * Runs a match-all script-score search that scores every document with the mock vector script,
 * then checks that each hit of {@code knnResponse} appears in that exact result and carries the same score.
 * Hits are matched by id with a single forward-moving cursor over the exact hits.
 *
 * @param knnResponse response of the kNN request under test
 * @param queryVector query vector handed to the scoring script
 * @param docCount    size of the exact search and the hit count it is expected to return
 */
private static void compareWithExactSearch(SearchResponse knnResponse, float[] queryVector, int docCount) {
    Map<String, Object> scriptParams = Map.of(QUERY_VECTOR_PARAM, queryVector);
    Script exactScoreScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, VECTOR_SCORE_SCRIPT, scriptParams);
    ScriptScoreQueryBuilder exactQuery = QueryBuilders.scriptScoreQuery(new MatchAllQueryBuilder(), exactScoreScript);
    SearchRequestBuilder exactRequest = prepareSearch(INDEX_NAME).setQuery(exactQuery).setSize(docCount);

    assertNoFailuresAndResponse(exactRequest, exactResponse -> {
        assertHitCount(exactResponse, docCount);

        SearchHit[] exactHits = exactResponse.getHits().getHits();
        int cursor = 0;
        for (SearchHit knnHit : knnResponse.getHits().getHits()) {
            // Advance to the exact hit with the same id; the cursor is never rewound.
            while (cursor < exactHits.length) {
                if (exactHits[cursor].getId().equals(knnHit.getId())) {
                    break;
                }
                cursor++;
            }
            if (cursor >= exactHits.length) {
                fail("Knn doc not found in exact search");
            }
            float exactScore = exactHits[cursor].getScore();
            assertThat("Real score is not the same as rescored score", knnHit.getScore(), equalTo(exactScore));
        }
    });
}

/**
 * Builds a vector of the requested size whose components are each drawn from
 * {@code randomFloatBetween(0, 1, true)}.
 *
 * @param numDimensions number of components in the returned vector
 * @return a newly allocated array of {@code numDimensions} random floats
 */
private static float[] randomVector(int numDimensions) {
    float[] components = new float[numDimensions];
    int idx = 0;
    while (idx < numDimensions) {
        components[idx] = randomFloatBetween(0, 1, true);
        idx++;
    }
    return components;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1228,7 +1228,7 @@ public final int hashCode() {
}
}

private enum VectorIndexType {
public enum VectorIndexType {
HNSW("hnsw", false) {
@Override
public IndexOptions parseIndexOptions(String fieldName, Map<String, ?> indexOptionsMap) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,23 @@
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;

import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Objects;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

/**
* A query that matches the provided docs with their scores.
*
* Note: this query was adapted from Lucene's DocAndScoreQuery from the class
* Note: this query was originally adapted from Lucene's DocAndScoreQuery from the class
* {@link org.apache.lucene.search.KnnFloatVectorQuery}, which is package-private.
* There are no changes to the behavior, just some renames.
*/
public class KnnScoreDocQuery extends Query {
private final int[] docs;
Expand All @@ -49,13 +50,18 @@ public class KnnScoreDocQuery extends Query {
/**
* Creates a query.
*
* @param docs the global doc IDs of documents that match, in ascending order
* @param scores the scores of the matching documents
* @param scoreDocs an array of ScoreDocs to use for the query
* @param reader IndexReader
*/
KnnScoreDocQuery(int[] docs, float[] scores, IndexReader reader) {
this.docs = docs;
this.scores = scores;
KnnScoreDocQuery(ScoreDoc[] scoreDocs, IndexReader reader) {
// Ensure that the docs are sorted by docId, as they are later searched using binary search
Arrays.sort(scoreDocs, Comparator.comparingInt(scoreDoc -> scoreDoc.doc));
this.docs = new int[scoreDocs.length];
this.scores = new float[scoreDocs.length];
for (int i = 0; i < scoreDocs.length; i++) {
docs[i] = scoreDocs[i].doc;
scores[i] = scoreDocs[i].score;
}
this.segmentStarts = findSegmentStarts(reader, docs);
this.contextIdentity = reader.getContext().id();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,15 +141,7 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep

@Override
protected Query doToQuery(SearchExecutionContext context) throws IOException {
int numDocs = scoreDocs.length;
int[] docs = new int[numDocs];
float[] scores = new float[numDocs];
for (int i = 0; i < numDocs; i++) {
docs[i] = scoreDocs[i].doc;
scores[i] = scoreDocs[i].score;
}

return new KnnScoreDocQuery(docs, scores, context.getIndexReader());
return new KnnScoreDocQuery(scoreDocs, context.getIndexReader());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,12 @@
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource;
import org.elasticsearch.search.profile.query.QueryProfiler;

import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Objects;

/**
Expand Down Expand Up @@ -60,16 +58,7 @@ public Query rewrite(IndexSearcher searcher) throws IOException {
// Retrieve top k documents from the rescored query
TopDocs topDocs = searcher.search(query, k);
vectorOperations = topDocs.totalHits.value;
ScoreDoc[] scoreDocs = topDocs.scoreDocs;
Arrays.sort(scoreDocs, Comparator.comparingInt(scoreDoc -> scoreDoc.doc));
int[] docIds = new int[scoreDocs.length];
float[] scores = new float[scoreDocs.length];
for (int i = 0; i < scoreDocs.length; i++) {
docIds[i] = scoreDocs[i].doc;
scores[i] = scoreDocs[i].score;
}

return new KnnScoreDocQuery(docIds, scores, searcher.getIndexReader());
return new KnnScoreDocQuery(topDocs.scoreDocs, searcher.getIndexReader());
}

public Query innerQuery() {
Expand Down
Loading