diff --git a/CHANGELOG.md b/CHANGELOG.md index 7e418797b7..59cb943a0f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Enhancements * Add KnnCircuitBreakerException and modify exception message [#1688](https://github.com/opensearch-project/k-NN/pull/1688) * Add stats for radial search [#1684](https://github.com/opensearch-project/k-NN/pull/1684) +* Support script score when doc value is disabled and fix misusing DISI [#1696](https://github.com/opensearch-project/k-NN/pull/1696) ### Bug Fixes * Block commas in model description [#1692](https://github.com/opensearch-project/k-NN/pull/1692) ### Infrastructure diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java b/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java index 0976441090..7b621718c6 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java @@ -5,9 +5,10 @@ package org.opensearch.knn.index; -import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.DocValues; +import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.LeafReader; +import org.apache.lucene.search.DocIdSetIterator; import org.opensearch.index.fielddata.LeafFieldData; import org.opensearch.index.fielddata.ScriptDocValues; import org.opensearch.index.fielddata.SortedBinaryDocValues; @@ -41,10 +42,29 @@ public long ramBytesUsed() { @Override public ScriptDocValues getScriptValues() { try { - BinaryDocValues values = DocValues.getBinary(reader, fieldName); - return new KNNVectorScriptDocValues(values, fieldName, vectorDataType); + FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(fieldName); + if (fieldInfo == null) { + return KNNVectorScriptDocValues.emptyValues(fieldName, vectorDataType); + } + + DocIdSetIterator values; + if (fieldInfo.hasVectorValues()) { + switch (fieldInfo.getVectorEncoding()) { + case FLOAT32: + values = reader.getFloatVectorValues(fieldName); + break; + case BYTE: + values = reader.getByteVectorValues(fieldName); + break; + default: + throw new IllegalStateException("Unsupported Lucene vector encoding: " + fieldInfo.getVectorEncoding()); + } + } else { + values = DocValues.getBinary(reader, fieldName); + } + return KNNVectorScriptDocValues.create(values, fieldName, vectorDataType); } catch (IOException e) { - throw new IllegalStateException("Cannot load doc values for knn vector field: " + fieldName, e); + throw new IllegalStateException("Cannot load values for knn vector field: " + fieldName, e); } } diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java index 349988c939..55ff655167 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java @@ -6,30 +6,40 @@ package org.opensearch.knn.index; import java.io.IOException; +import java.util.Objects; +import lombok.AccessLevel; import lombok.Getter; import lombok.RequiredArgsConstructor; import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.search.DocIdSetIterator; import org.opensearch.ExceptionsHelper; import org.opensearch.index.fielddata.ScriptDocValues; -import java.io.IOException; - -@RequiredArgsConstructor -public final class KNNVectorScriptDocValues extends ScriptDocValues { +@RequiredArgsConstructor(access = AccessLevel.PRIVATE) +public abstract class KNNVectorScriptDocValues extends ScriptDocValues { - private final BinaryDocValues binaryDocValues; + private final DocIdSetIterator vectorValues; private final String fieldName; @Getter private final VectorDataType vectorDataType; private boolean docExists = false; + private int lastDocID = -1; @Override public void setNextDocId(int docId) throws IOException { - if (binaryDocValues.advanceExact(docId)) { - docExists = true; - return; + if (docId < lastDocID) { + throw new IllegalArgumentException("docs were sent out-of-order: lastDocID=" + lastDocID + " vs docID=" + docId); + } + + lastDocID = docId; + + int curDocID = vectorValues.docID(); + if (lastDocID > curDocID) { + curDocID = vectorValues.advance(docId); } - docExists = false; + docExists = lastDocID == curDocID; } public float[] getValue() { @@ -44,12 +54,14 @@ public float[] getValue() { throw new IllegalStateException(errorMessage); } try { - return vectorDataType.getVectorFromBytesRef(binaryDocValues.binaryValue()); + return doGetValue(); } catch (IOException e) { throw ExceptionsHelper.convertToOpenSearchException(e); } } + protected abstract float[] doGetValue() throws IOException; + @Override public int size() { return docExists ? 1 : 0; @@ -59,4 +71,89 @@ public int size() { public float[] get(int i) { throw new UnsupportedOperationException("knn vector does not support this operation"); } + + /** + * Creates a KNNVectorScriptDocValues object based on the provided parameters. + * + * @param values The DocIdSetIterator representing the vector values. + * @param fieldName The name of the field. + * @param vectorDataType The data type of the vector. + * @return A KNNVectorScriptDocValues object based on the type of the values. + * @throws IllegalArgumentException If the type of values is unsupported. + */ + public static KNNVectorScriptDocValues create(DocIdSetIterator values, String fieldName, VectorDataType vectorDataType) { + Objects.requireNonNull(values, "values must not be null"); + if (values instanceof ByteVectorValues) { + return new KNNByteVectorScriptDocValues((ByteVectorValues) values, fieldName, vectorDataType); + } else if (values instanceof FloatVectorValues) { + return new KNNFloatVectorScriptDocValues((FloatVectorValues) values, fieldName, vectorDataType); + } else if (values instanceof BinaryDocValues) { + return new KNNNativeVectorScriptDocValues((BinaryDocValues) values, fieldName, vectorDataType); + } else { + throw new IllegalArgumentException("Unsupported values type: " + values.getClass()); + } + } + + private static final class KNNByteVectorScriptDocValues extends KNNVectorScriptDocValues { + private final ByteVectorValues values; + + KNNByteVectorScriptDocValues(ByteVectorValues values, String field, VectorDataType type) { + super(values, field, type); + this.values = values; + } + + @Override + protected float[] doGetValue() throws IOException { + byte[] bytes = values.vectorValue(); + float[] value = new float[bytes.length]; + for (int i = 0; i < bytes.length; i++) { + value[i] = (float) bytes[i]; + } + return value; + } + } + + private static final class KNNFloatVectorScriptDocValues extends KNNVectorScriptDocValues { + private final FloatVectorValues values; + + KNNFloatVectorScriptDocValues(FloatVectorValues values, String field, VectorDataType type) { + super(values, field, type); + this.values = values; + } + + @Override + protected float[] doGetValue() throws IOException { + return values.vectorValue(); + } + } + + private static final class KNNNativeVectorScriptDocValues extends KNNVectorScriptDocValues { + private final BinaryDocValues values; + + KNNNativeVectorScriptDocValues(BinaryDocValues values, String field, VectorDataType type) { + super(values, field, type); + this.values = values; + } + + @Override + protected float[] doGetValue() throws IOException { + return getVectorDataType().getVectorFromBytesRef(values.binaryValue()); + } + } + + /** + * Creates an empty KNNVectorScriptDocValues object based on the provided field name and vector data type. + * + * @param fieldName The name of the field. + * @param type The data type of the vector. + * @return An empty KNNVectorScriptDocValues object. + */ + public static KNNVectorScriptDocValues emptyValues(String fieldName, VectorDataType type) { + return new KNNVectorScriptDocValues(DocIdSetIterator.empty(), fieldName, type) { + @Override + protected float[] doGetValue() throws IOException { + throw new UnsupportedOperationException("empty values"); + } + }; + } } diff --git a/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java b/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java index 3f98a9136f..66e2893c0e 100644 --- a/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java @@ -5,7 +5,15 @@ package org.opensearch.knn.index; -import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.KnnByteVectorField; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.DocValues; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.search.DocIdSetIterator; import org.opensearch.knn.KNNTestCase; import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.document.BinaryDocValuesField; @@ -33,26 +41,39 @@ public class KNNVectorScriptDocValuesTests extends KNNTestCase { public void setUp() throws Exception { super.setUp(); directory = newDirectory(); - createKNNVectorDocument(directory); + Class valuesClass = randomFrom(BinaryDocValues.class, ByteVectorValues.class, FloatVectorValues.class); + createKNNVectorDocument(directory, valuesClass); reader = DirectoryReader.open(directory); - LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - scriptDocValues = new KNNVectorScriptDocValues( - leafReaderContext.reader().getBinaryDocValues(MOCK_INDEX_FIELD_NAME), - MOCK_INDEX_FIELD_NAME, - VectorDataType.FLOAT - ); + LeafReader leafReader = reader.getContext().leaves().get(0).reader(); + DocIdSetIterator vectorValues; + if (BinaryDocValues.class.equals(valuesClass)) { + vectorValues = DocValues.getBinary(leafReader, MOCK_INDEX_FIELD_NAME); + } else if (ByteVectorValues.class.equals(valuesClass)) { + vectorValues = leafReader.getByteVectorValues(MOCK_INDEX_FIELD_NAME); + } else { + vectorValues = leafReader.getFloatVectorValues(MOCK_INDEX_FIELD_NAME); + } + + scriptDocValues = KNNVectorScriptDocValues.create(vectorValues, MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT); } - private void createKNNVectorDocument(Directory directory) throws IOException { + private void createKNNVectorDocument(Directory directory, Class valuesClass) throws IOException { IndexWriterConfig conf = newIndexWriterConfig(new MockAnalyzer(random())); IndexWriter writer = new IndexWriter(directory, conf); Document knnDocument = new Document(); - knnDocument.add( - new BinaryDocValuesField( + Field field; + if (BinaryDocValues.class.equals(valuesClass)) { + field = new BinaryDocValuesField( MOCK_INDEX_FIELD_NAME, new VectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_VECTOR_DATA, new FieldType()).binaryValue() - ) - ); + ); + } else if (ByteVectorValues.class.equals(valuesClass)) { + field = new KnnByteVectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_BYTE_VECTOR_DATA); + } else { + field = new KnnFloatVectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_VECTOR_DATA); + } + + knnDocument.add(field); writer.addDocument(knnDocument); writer.commit(); writer.close(); @@ -84,4 +105,18 @@ public void testSize() throws IOException { public void testGet() throws IOException { expectThrows(UnsupportedOperationException.class, () -> scriptDocValues.get(0)); } + + public void testUnsupportedValues() throws IOException { + expectThrows( + IllegalArgumentException.class, + () -> KNNVectorScriptDocValues.create(DocValues.emptyNumeric(), MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT) + ); + } + + public void testEmptyValues() throws IOException { + KNNVectorScriptDocValues values = KNNVectorScriptDocValues.emptyValues(MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT); + assertEquals(0, values.size()); + scriptDocValues.setNextDocId(0); + assertEquals(0, values.size()); + } } diff --git a/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java b/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java index 4423c85d8f..19270717d1 100644 --- a/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java +++ b/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java @@ -57,7 +57,7 @@ private KNNVectorScriptDocValues getKNNFloatVectorScriptDocValues() { createKNNFloatVectorDocument(directory); reader = DirectoryReader.open(directory); LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - return new KNNVectorScriptDocValues( + return KNNVectorScriptDocValues.create( leafReaderContext.reader().getBinaryDocValues(VectorDataTypeTests.MOCK_FLOAT_INDEX_FIELD_NAME), VectorDataTypeTests.MOCK_FLOAT_INDEX_FIELD_NAME, VectorDataType.FLOAT @@ -70,7 +70,7 @@ private KNNVectorScriptDocValues getKNNByteVectorScriptDocValues() { createKNNByteVectorDocument(directory); reader = DirectoryReader.open(directory); LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - return new KNNVectorScriptDocValues( + return KNNVectorScriptDocValues.create( leafReaderContext.reader().getBinaryDocValues(VectorDataTypeTests.MOCK_BYTE_INDEX_FIELD_NAME), VectorDataTypeTests.MOCK_BYTE_INDEX_FIELD_NAME, VectorDataType.BYTE diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java index 8c43a4acf3..22110accd0 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java @@ -280,7 +280,7 @@ public KNNVectorScriptDocValues getScriptDocValues(String fieldName) throws IOEx if (scriptDocValues == null) { reader = DirectoryReader.open(directory); LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - scriptDocValues = new KNNVectorScriptDocValues( + scriptDocValues = KNNVectorScriptDocValues.create( leafReaderContext.reader().getBinaryDocValues(fieldName), fieldName, VectorDataType.FLOAT diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java index 901511a68b..5a83891d9d 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java @@ -7,6 +7,7 @@ import java.util.function.BiFunction; import java.util.function.Function; +import org.opensearch.ExceptionsHelper; import org.opensearch.knn.KNNRestTestCase; import org.opensearch.knn.KNNResult; import org.opensearch.knn.common.KNNConstants; @@ -193,7 +194,7 @@ public void testUnequalDimensions() throws Exception { } @SuppressWarnings("unchecked") - public void testKNNScoreforNonVectorDocument() throws Exception { + public void testKNNScoreForNonVectorDocument() throws Exception { /* * Create knn index and populate data */ @@ -599,7 +600,7 @@ public void testKNNScriptScoreOnModelBasedIndex() throws Exception { if (spaceType != SpaceType.HAMMING_BIT) { final float[] queryVector = randomVector(dimensions); final BiFunction scoreFunction = getScoreFunction(spaceType, queryVector); - createIndexAndAssertScriptScore(testMapping, spaceType, scoreFunction, dimensions, queryVector); + createIndexAndAssertScriptScore(testMapping, spaceType, scoreFunction, dimensions, queryVector, true); } } } @@ -612,7 +613,16 @@ private List createMappers(int dimensions) throws Exception { dimensions, KNNConstants.METHOD_HNSW, KNNEngine.LUCENE.getName(), - SpaceType.DEFAULT.getValue() + SpaceType.DEFAULT.getValue(), + true + ), + createKnnIndexMapping( + FIELD_NAME, + dimensions, + KNNConstants.METHOD_HNSW, + KNNEngine.LUCENE.getName(), + SpaceType.DEFAULT.getValue(), + false ) ); } @@ -625,12 +635,22 @@ private float[] randomVector(int dimensions) { return vector; } - private Map createDataset(Function scoreFunction, int dimensions, int numDocs) { - final Map dataset = new HashMap<>(numDocs); - for (int i = 0; i < numDocs; i++) { + private Map createDataset( + Function scoreFunction, + int dimensions, + int numDocsWithField, + boolean dense + ) { + final Map dataset = new HashMap<>(dense ? numDocsWithField : numDocsWithField * 3); + int id = 0; + for (int i = 0; i < numDocsWithField; i++) { + final int dummyDocs = dense ? 0 : randomIntBetween(2, 5); + for (int j = 0; j < dummyDocs; j++) { + dataset.put(Integer.toString(id++), null); + } final float[] vector = randomVector(dimensions); final float score = scoreFunction.apply(vector); - dataset.put(Integer.toString(i), new KNNResult(Integer.toString(i), vector, score)); + dataset.put(Integer.toString(id), new KNNResult(Integer.toString(id++), vector, score)); } return dataset; } @@ -669,7 +689,8 @@ private void testKNNScriptScore(SpaceType spaceType) throws Exception { final float[] queryVector = randomVector(dims); final BiFunction scoreFunction = getScoreFunction(spaceType, queryVector); for (String mapper : createMappers(dims)) { - createIndexAndAssertScriptScore(mapper, spaceType, scoreFunction, dims, queryVector); + createIndexAndAssertScriptScore(mapper, spaceType, scoreFunction, dims, queryVector, true); + createIndexAndAssertScriptScore(mapper, spaceType, scoreFunction, dims, queryVector, false); } } @@ -678,16 +699,20 @@ private void createIndexAndAssertScriptScore( SpaceType spaceType, BiFunction scoreFunction, int dimensions, - float[] queryVector + float[] queryVector, + boolean dense ) throws Exception { /* * Create knn index and populate data */ createKnnIndex(INDEX_NAME, mapper); - Map dataset = createDataset(v -> scoreFunction.apply(queryVector, v), dimensions, randomIntBetween(4, 10)); - for (Map.Entry entry : dataset.entrySet()) { - addKnnDoc(INDEX_NAME, entry.getKey(), FIELD_NAME, entry.getValue().getVector()); - } + final int numDocsWithField = randomIntBetween(4, 10); + Map dataset = createDataset(v -> scoreFunction.apply(queryVector, v), dimensions, numDocsWithField, dense); + final float[] dummyVector = new float[1]; + dataset.forEach((k, v) -> { + final float[] vector = (v != null) ? v.getVector() : dummyVector; + ExceptionsHelper.catchAsRuntimeException(() -> addKnnDoc(INDEX_NAME, k, (v != null) ? FIELD_NAME : "dummy", vector)); + }); /** * Construct Search Request @@ -703,7 +728,7 @@ private void createIndexAndAssertScriptScore( params.put("field", FIELD_NAME); params.put("query_value", queryVector); params.put("space_type", spaceType.getValue()); - Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params); + Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params, numDocsWithField); Response response = client().performRequest(request); assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); diff --git a/src/test/java/org/opensearch/knn/plugin/script/PainlessScriptIT.java b/src/test/java/org/opensearch/knn/plugin/script/PainlessScriptIT.java index 0315c47c52..5325d12053 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/PainlessScriptIT.java +++ b/src/test/java/org/opensearch/knn/plugin/script/PainlessScriptIT.java @@ -563,7 +563,9 @@ public void testL2ScriptingWithLuceneBackedIndex() throws Exception { new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()) ); properties.add( - new MappingProperty(FIELD_NAME, KNNVectorFieldMapper.CONTENT_TYPE).dimension("2").knnMethodContext(knnMethodContext) + new MappingProperty(FIELD_NAME, KNNVectorFieldMapper.CONTENT_TYPE).dimension("2") + .knnMethodContext(knnMethodContext) + .docValues(randomBoolean()) ); String source = String.format("1/(1 + l2Squared([1.0f, 1.0f], doc['%s']))", FIELD_NAME); diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index c8baa6fe41..e665a9f1e2 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -39,6 +39,7 @@ import org.opensearch.index.query.functionscore.ScriptScoreQueryBuilder; import org.opensearch.core.rest.RestStatus; import org.opensearch.script.Script; +import org.opensearch.search.SearchService; import org.opensearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder; import javax.management.MBeanServerInvocationHandler; @@ -974,9 +975,16 @@ protected Request constructScriptScoreContextSearchRequest( } protected Request constructKNNScriptQueryRequest(String indexName, QueryBuilder qb, Map params) throws Exception { + return constructKNNScriptQueryRequest(indexName, qb, params, SearchService.DEFAULT_SIZE); + } + + protected Request constructKNNScriptQueryRequest(String indexName, QueryBuilder qb, Map params, int size) + throws Exception { Script script = new Script(Script.DEFAULT_SCRIPT_TYPE, KNNScoringScriptEngine.NAME, KNNScoringScriptEngine.SCRIPT_SOURCE, params); ScriptScoreQueryBuilder sc = new ScriptScoreQueryBuilder(qb, script); - XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject("query"); + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + builder.field("size", size); + builder.startObject("query"); builder.startObject("script_score"); builder.field("query"); sc.query().toXContent(builder, ToXContent.EMPTY_PARAMS);