From e9fedcb3d893718b42471924280399cc84563b43 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Wed, 21 Aug 2024 09:16:53 -0700 Subject: [PATCH 01/41] Add quantization state reader and writer Signed-off-by: Ryan Bogan --- .../KNNQuantizationStateReader.java | 122 ++++++++++++++++++ .../KNNQuantizationStateWriter.java | 89 +++++++++++++ .../NativeEngines990KnnVectorsWriter.java | 17 ++- 3 files changed, 226 insertions(+), 2 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriter.java diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java new file mode 100644 index 0000000000..8869cccace --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java @@ -0,0 +1,122 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN990Codec; + +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Reads quantization states + */ +public class KNNQuantizationStateReader { + + /** + * Read quantization states and return list of fieldNames and bytes + * @param state the read state to read from + */ + public Map read(SegmentReadState state) { + String quantizationStateFileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, "qs"); + Map readQuantizationStateInfos = new HashMap<>(); + + try (IndexInput input = state.directory.openInput(quantizationStateFileName, IOContext.READ)) { + + long footerStart = input.length() - CodecUtil.footerLength(); + long markerAndIndexPosition = footerStart - Integer.BYTES - Long.BYTES; + input.seek(markerAndIndexPosition); + long indexStartPosition = input.readLong(); + input.readInt(); + input.seek(indexStartPosition); + int numFields = input.readInt(); + + List fieldNames = new ArrayList<>(); + List positions = new ArrayList<>(); + List lengths = new ArrayList<>(); + + // Read each field's metadata from the index section + for (int i = 0; i < numFields; i++) { + fieldNames.add(input.readString()); + int length = input.readInt(); + lengths.add(length); + long position = input.readVLong(); + positions.add(position); + } + // Read each field's bytes + for (int i = 0; i < numFields; i++) { + input.seek(positions.get(i)); + byte[] stateBytes = new byte[lengths.get(i)]; + input.readBytes(stateBytes, 0, lengths.get(i)); + readQuantizationStateInfos.put(fieldNames.get(i), stateBytes); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + return readQuantizationStateInfos; + } + + /** + * Reads an individual quantization state for a given field + * @param directory directory to open input + * @param segmentName segment name + * @param segmentSuffix segment suffix + * @param fieldInfo field information + * @return quantization state + */ + public QuantizationState read(Directory directory, String segmentName, String segmentSuffix, FieldInfo fieldInfo) { + String quantizationStateFileName = IndexFileNames.segmentFileName(segmentName, segmentSuffix, "qs"); + String fieldName = fieldInfo.getName(); + + try (IndexInput input = directory.openInput(quantizationStateFileName, IOContext.READ)) { + + long footerStart = input.length() - CodecUtil.footerLength(); + long markerAndIndexPosition = footerStart - Integer.BYTES - Long.BYTES; + input.seek(markerAndIndexPosition); + long indexStartPosition = input.readLong(); + input.readInt(); + input.seek(indexStartPosition); + int numFields = input.readInt(); + + long position = -1; + int length = 0; + + // Read each field's metadata from the index section + for (int i = 0; i < numFields; i++) { + String tempFieldName = input.readString(); + int tempLength = input.readInt(); + long tempPosition = input.readVLong(); + if (tempFieldName.equals(fieldName)) { + position = tempPosition; + length = tempLength; + break; + } + } + + if (position == -1 || length == 0) { + throw new IllegalArgumentException(String.format("Field %s not found", fieldName)); + } + + input.seek(position); + byte[] stateBytes = new byte[length]; + input.readBytes(stateBytes, 0, length); + // Deserialize the byte array to a quantization state object + // TODO: Get params from field info and deserialize + return null; + } catch (IOException e) { + throw new RuntimeException(e); + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriter.java new file mode 100644 index 0000000000..f4dc1fd7ca --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriter.java @@ -0,0 +1,89 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN990Codec; + +import lombok.AllArgsConstructor; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.store.IndexOutput; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +/** + * Writes quantization states to off heap memory + */ +public class KNNQuantizationStateWriter { + + private final IndexOutput output; + private List fieldQuantizationStates = new ArrayList<>(); + + /** + * Constructor + * @param segmentWriteState segment write state containing segment information + * @throws IOException exception could be thrown while creating the output + */ + public KNNQuantizationStateWriter(SegmentWriteState segmentWriteState) throws IOException { + String quantizationStateFileName = IndexFileNames.segmentFileName( + segmentWriteState.segmentInfo.name, + segmentWriteState.segmentSuffix, + "qs" + ); + + output = segmentWriteState.directory.createOutput(quantizationStateFileName, segmentWriteState.context); + } + + /** + * Writes an index header + * @param segmentWriteState state containing segment information + * @throws IOException exception could be thrown while writing header + */ + public void writeHeader(SegmentWriteState segmentWriteState) throws IOException { + CodecUtil.writeIndexHeader(output, "QuantizationCodec", 0, segmentWriteState.segmentInfo.getId(), segmentWriteState.segmentSuffix); + } + + /** + * Writes a quantization state as bytes + * @param fieldName field name + * @param quantizationState quantization state + * @throws IOException could be thrown while writing + */ + public void writeState(String fieldName, QuantizationState quantizationState) throws IOException { + byte[] stateBytes = quantizationState.toByteArray(); + long position = output.getFilePointer(); + output.writeBytes(stateBytes, stateBytes.length); + fieldQuantizationStates.add(new FieldQuantizationState(fieldName, stateBytes, position)); + } + + /** + * Writes index footer and other index information for parsing later + * @throws IOException could be thrown while writing + */ + public void writeFooter() throws IOException { + long indexStartPosition = output.getFilePointer(); + output.writeInt(fieldQuantizationStates.size()); + for (FieldQuantizationState fieldQuantizationState : fieldQuantizationStates) { + output.writeString(fieldQuantizationState.fieldName); + output.writeInt(fieldQuantizationState.stateBytes.length); + output.writeVLong(fieldQuantizationState.position); + } + output.writeLong(indexStartPosition); + output.writeInt(-1); + CodecUtil.writeFooter(output); + output.close(); + fieldQuantizationStates = new ArrayList<>(); + } + + @AllArgsConstructor + static class FieldQuantizationState { + final String fieldName; + final byte[] stateBytes; + final Long position; + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index 65736a63ef..876d8a1f73 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -11,7 +11,6 @@ package org.opensearch.knn.index.codec.KNN990Codec; -import lombok.RequiredArgsConstructor; import lombok.extern.log4j.Log4j2; import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.codecs.KnnVectorsWriter; @@ -39,14 +38,21 @@ * A KNNVectorsWriter class for writing the vector data strcutures and flat vectors for Native Engines. */ @Log4j2 -@RequiredArgsConstructor public class NativeEngines990KnnVectorsWriter extends KnnVectorsWriter { private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(NativeEngines990KnnVectorsWriter.class); private final SegmentWriteState segmentWriteState; private final FlatVectorsWriter flatVectorsWriter; + private final KNNQuantizationStateWriter quantizationStateWriter; private final List> fields = new ArrayList<>(); private boolean finished; + public NativeEngines990KnnVectorsWriter(SegmentWriteState segmentWriteState, FlatVectorsWriter flatVectorsWriter) throws IOException { + this.segmentWriteState = segmentWriteState; + this.flatVectorsWriter = flatVectorsWriter; + this.quantizationStateWriter = new KNNQuantizationStateWriter(segmentWriteState); + + } + /** * Add new field for indexing. * In Lucene, we use single file for all the vector fields so here we need to see how we are going to make things @@ -70,6 +76,9 @@ public KnnFieldVectorsWriter addField(final FieldInfo fieldInfo) throws IOExc public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { // simply write data in the flat file flatVectorsWriter.flush(maxDoc, sortMap); + + quantizationStateWriter.writeHeader(segmentWriteState); + for (final NativeEngineFieldVectorsWriter field : fields) { final VectorDataType vectorDataType = extractVectorDataType(field.getFieldInfo()); final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( @@ -78,8 +87,12 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { field.getVectors() ); + // TODO: Extract quantization state here, uncomment below line once implemented + // quantizationStateWriter.writeState(field.getFieldInfo().getName(), quantizationState); + NativeIndexWriter.getWriter(field.getFieldInfo(), segmentWriteState).flushIndex(knnVectorValues); } + quantizationStateWriter.writeFooter(); } @Override From 1f5c03003956fbd84e33a34d7db90247ec0bc791 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Wed, 21 Aug 2024 09:18:39 -0700 Subject: [PATCH 02/41] Make inner class private Signed-off-by: Ryan Bogan --- .../knn/index/codec/KNN990Codec/KNNQuantizationStateWriter.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriter.java index f4dc1fd7ca..45b764a262 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriter.java @@ -81,7 +81,7 @@ public void writeFooter() throws IOException { } @AllArgsConstructor - static class FieldQuantizationState { + private static class FieldQuantizationState { final String fieldName; final byte[] stateBytes; final Long position; From 627ac7e0675bc805a53c3eb07c470d02cff4dd47 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Wed, 21 Aug 2024 13:19:24 -0700 Subject: [PATCH 03/41] Address PR Feedback Signed-off-by: Ryan Bogan --- .../opensearch/knn/common/KNNConstants.java | 1 + .../KNNQuantizationStateReader.java | 55 +++++++++++++------ .../KNNQuantizationStateWriter.java | 3 +- 3 files changed, 42 insertions(+), 17 deletions(-) diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index 56f9ffaf89..76b0f61027 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -71,6 +71,7 @@ public class KNNConstants { public static final VectorDataType DEFAULT_VECTOR_DATA_TYPE_FIELD = VectorDataType.FLOAT; public static final String RADIAL_SEARCH_KEY = "radial_search"; + public static final String QUANTIZATION_STATE_FILE_SUFFIX = "qs"; // Lucene specific constants public static final String LUCENE_NAME = "lucene"; diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java index 8869cccace..f2687a206b 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java @@ -12,6 +12,7 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IndexInput; +import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import java.io.IOException; @@ -27,21 +28,35 @@ public class KNNQuantizationStateReader { /** * Read quantization states and return list of fieldNames and bytes + * + * File format: + * Header + * QS1 state bytes + * QS2 state bytes + * Number of quantization states + * QS1 field name + * QS1 state bytes length + * QS1 position of state bytes + * QS2 field name + * QS2 state bytes length + * QS2 position of state bytes + * Position of index section (where QS1 field name is located) + * -1 (marker) + * Footer + * * @param state the read state to read from */ public Map read(SegmentReadState state) { - String quantizationStateFileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, "qs"); + String quantizationStateFileName = IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + KNNConstants.QUANTIZATION_STATE_FILE_SUFFIX + ); Map readQuantizationStateInfos = new HashMap<>(); try (IndexInput input = state.directory.openInput(quantizationStateFileName, IOContext.READ)) { - long footerStart = input.length() - CodecUtil.footerLength(); - long markerAndIndexPosition = footerStart - Integer.BYTES - Long.BYTES; - input.seek(markerAndIndexPosition); - long indexStartPosition = input.readLong(); - input.readInt(); - input.seek(indexStartPosition); - int numFields = input.readInt(); + int numFields = getNumFields(input); List fieldNames = new ArrayList<>(); List positions = new ArrayList<>(); @@ -77,18 +92,16 @@ public Map read(SegmentReadState state) { * @return quantization state */ public QuantizationState read(Directory directory, String segmentName, String segmentSuffix, FieldInfo fieldInfo) { - String quantizationStateFileName = IndexFileNames.segmentFileName(segmentName, segmentSuffix, "qs"); + String quantizationStateFileName = IndexFileNames.segmentFileName( + segmentName, + segmentSuffix, + KNNConstants.QUANTIZATION_STATE_FILE_SUFFIX + ); String fieldName = fieldInfo.getName(); try (IndexInput input = directory.openInput(quantizationStateFileName, IOContext.READ)) { - long footerStart = input.length() - CodecUtil.footerLength(); - long markerAndIndexPosition = footerStart - Integer.BYTES - Long.BYTES; - input.seek(markerAndIndexPosition); - long indexStartPosition = input.readLong(); - input.readInt(); - input.seek(indexStartPosition); - int numFields = input.readInt(); + int numFields = getNumFields(input); long position = -1; int length = 0; @@ -119,4 +132,14 @@ public QuantizationState read(Directory directory, String segmentName, String se throw new RuntimeException(e); } } + + private int getNumFields(IndexInput input) throws IOException { + long footerStart = input.length() - CodecUtil.footerLength(); + long markerAndIndexPosition = footerStart - Integer.BYTES - Long.BYTES; + input.seek(markerAndIndexPosition); + long indexStartPosition = input.readLong(); + input.readInt(); + input.seek(indexStartPosition); + return input.readInt(); + } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriter.java index 45b764a262..088aade419 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriter.java @@ -10,6 +10,7 @@ import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.store.IndexOutput; +import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import java.io.IOException; @@ -33,7 +34,7 @@ public KNNQuantizationStateWriter(SegmentWriteState segmentWriteState) throws IO String quantizationStateFileName = IndexFileNames.segmentFileName( segmentWriteState.segmentInfo.name, segmentWriteState.segmentSuffix, - "qs" + KNNConstants.QUANTIZATION_STATE_FILE_SUFFIX ); output = segmentWriteState.directory.createOutput(quantizationStateFileName, segmentWriteState.context); From a81e99d643b6fab2721aee2991e794d4b5c96dec Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Wed, 21 Aug 2024 17:03:07 -0700 Subject: [PATCH 04/41] Fix tests Signed-off-by: Ryan Bogan --- ...NativeEngines990KnnVectorsFormatTests.java | 34 ++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java index 3810d46fd2..a6ef91cbf9 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java @@ -23,22 +23,28 @@ import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexOptions; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.NoMergePolicy; +import org.apache.lucene.index.SegmentInfo; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentReader; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.SerialMergeScheduler; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Sort; import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.store.BaseDirectoryWrapper; import org.apache.lucene.util.Bits; +import org.apache.lucene.util.InfoStream; +import org.apache.lucene.util.Version; import org.junit.After; import org.junit.Assert; import org.mockito.Mockito; @@ -53,6 +59,7 @@ import java.io.IOException; import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; @Log4j2 @@ -76,7 +83,32 @@ public void tearDown() throws Exception { @SneakyThrows public void testReaderAndWriter_whenValidInput_thenSuccess() { final Lucene99FlatVectorsFormat mockedFlatVectorsFormat = Mockito.mock(Lucene99FlatVectorsFormat.class); - final SegmentWriteState mockedSegmentWriteState = Mockito.mock(SegmentWriteState.class); + + final String segmentName = "test-segment-name"; + + final SegmentInfo mockedSegmentInfo = new SegmentInfo( + Mockito.mock(Directory.class), + Mockito.mock(Version.class), + Mockito.mock(Version.class), + segmentName, + 0, + false, + false, + Mockito.mock(Codec.class), + Mockito.mock(Map.class), + new byte[16], + Mockito.mock(Map.class), + Mockito.mock(Sort.class) + ); + + final SegmentWriteState mockedSegmentWriteState = new SegmentWriteState( + Mockito.mock(InfoStream.class), + Mockito.mock(Directory.class), + mockedSegmentInfo, + Mockito.mock(FieldInfos.class), + null, + Mockito.mock(IOContext.class) + ); final SegmentReadState mockedSegmentReadState = Mockito.mock(SegmentReadState.class); Mockito.when(mockedFlatVectorsFormat.fieldsReader(mockedSegmentReadState)).thenReturn(Mockito.mock(FlatVectorsReader.class)); From daa39edad11e7d65e1cd4fcca0e4c0c9b51be3de Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Wed, 21 Aug 2024 17:07:49 -0700 Subject: [PATCH 05/41] Address PR feedback Signed-off-by: Ryan Bogan --- .../KNNQuantizationStateReader.java | 58 +++++++++---------- 1 file changed, 28 insertions(+), 30 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java index f2687a206b..0fc8ea4964 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java @@ -91,7 +91,7 @@ public Map read(SegmentReadState state) { * @param fieldInfo field information * @return quantization state */ - public QuantizationState read(Directory directory, String segmentName, String segmentSuffix, FieldInfo fieldInfo) { + public QuantizationState read(Directory directory, String segmentName, String segmentSuffix, FieldInfo fieldInfo) throws IOException { String quantizationStateFileName = IndexFileNames.segmentFileName( segmentName, segmentSuffix, @@ -99,38 +99,36 @@ public QuantizationState read(Directory directory, String segmentName, String se ); String fieldName = fieldInfo.getName(); - try (IndexInput input = directory.openInput(quantizationStateFileName, IOContext.READ)) { - - int numFields = getNumFields(input); - - long position = -1; - int length = 0; - - // Read each field's metadata from the index section - for (int i = 0; i < numFields; i++) { - String tempFieldName = input.readString(); - int tempLength = input.readInt(); - long tempPosition = input.readVLong(); - if (tempFieldName.equals(fieldName)) { - position = tempPosition; - length = tempLength; - break; - } - } - - if (position == -1 || length == 0) { - throw new IllegalArgumentException(String.format("Field %s not found", fieldName)); + IndexInput input = directory.openInput(quantizationStateFileName, IOContext.READ); + CodecUtil.retrieveChecksum(input); + int numFields = getNumFields(input); + + long position = -1; + int length = 0; + + // Read each field's metadata from the index section + for (int i = 0; i < numFields; i++) { + String tempFieldName = input.readString(); + int tempLength = input.readInt(); + long tempPosition = input.readVLong(); + if (tempFieldName.equals(fieldName)) { + position = tempPosition; + length = tempLength; + break; } + } - input.seek(position); - byte[] stateBytes = new byte[length]; - input.readBytes(stateBytes, 0, length); - // Deserialize the byte array to a quantization state object - // TODO: Get params from field info and deserialize - return null; - } catch (IOException e) { - throw new RuntimeException(e); + if (position == -1 || length == 0) { + throw new IllegalArgumentException(String.format("Field %s not found", fieldName)); } + + input.seek(position); + byte[] stateBytes = new byte[length]; + input.readBytes(stateBytes, 0, length); + input.close(); + // Deserialize the byte array to a quantization state object + // TODO: Get params from field info and deserialize + return null; } private int getNumFields(IndexInput input) throws IOException { From 8abf3cd56707018564ceded7a5f42439894aabc0 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Thu, 22 Aug 2024 11:12:59 -0700 Subject: [PATCH 06/41] Add writer tests Signed-off-by: Ryan Bogan --- .../KNNQuantizationStateWriterTests.java | 191 ++++++++++++++++++ 1 file changed, 191 insertions(+) create mode 100644 src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriterTests.java diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriterTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriterTests.java new file mode 100644 index 0000000000..95d3992dfa --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriterTests.java @@ -0,0 +1,191 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN990Codec; + +import lombok.SneakyThrows; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.SegmentInfo; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.search.Sort; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.InfoStream; +import org.apache.lucene.util.Version; +import org.junit.Before; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.mockito.stubbing.Answer; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; + +import java.util.Map; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.times; + +public class KNNQuantizationStateWriterTests extends KNNTestCase { + + @SneakyThrows + public void testWriteHeader() { + final String segmentName = "test-segment-name"; + + final SegmentInfo segmentInfo = new SegmentInfo( + Mockito.mock(Directory.class), + Mockito.mock(Version.class), + Mockito.mock(Version.class), + segmentName, + 0, + false, + false, + Mockito.mock(Codec.class), + Mockito.mock(Map.class), + new byte[16], + Mockito.mock(Map.class), + Mockito.mock(Sort.class) + ); + + Directory directory = Mockito.mock(Directory.class); + IndexOutput output = Mockito.mock(IndexOutput.class); + Mockito.when(directory.createOutput(any(), any())).thenReturn(output); + + final SegmentWriteState segmentWriteState = new SegmentWriteState( + Mockito.mock(InfoStream.class), + directory, + segmentInfo, + Mockito.mock(FieldInfos.class), + null, + Mockito.mock(IOContext.class) + ); + KNNQuantizationStateWriter quantizationStateWriter = new KNNQuantizationStateWriter(segmentWriteState); + try (MockedStatic mockedStaticCodecUtil = Mockito.mockStatic(CodecUtil.class)) { + mockedStaticCodecUtil.when( + () -> CodecUtil.writeIndexHeader(any(IndexOutput.class), anyString(), anyInt(), any(byte[].class), anyString()) + ).thenAnswer((Answer) invocation -> null); + quantizationStateWriter.writeHeader(segmentWriteState); + mockedStaticCodecUtil.verify( + () -> CodecUtil.writeIndexHeader( + output, + "QuantizationCodec", + 0, + segmentWriteState.segmentInfo.getId(), + segmentWriteState.segmentSuffix + ) + ); + } + } + + @SneakyThrows + public void testWriteState() { + final String segmentName = "test-segment-name"; + + final SegmentInfo segmentInfo = new SegmentInfo( + Mockito.mock(Directory.class), + Mockito.mock(Version.class), + Mockito.mock(Version.class), + segmentName, + 0, + false, + false, + Mockito.mock(Codec.class), + Mockito.mock(Map.class), + new byte[16], + Mockito.mock(Map.class), + Mockito.mock(Sort.class) + ); + + Directory directory = Mockito.mock(Directory.class); + IndexOutput output = Mockito.mock(IndexOutput.class); + Mockito.when(directory.createOutput(any(), any())).thenReturn(output); + + final SegmentWriteState segmentWriteState = new SegmentWriteState( + Mockito.mock(InfoStream.class), + directory, + segmentInfo, + Mockito.mock(FieldInfos.class), + null, + Mockito.mock(IOContext.class) + ); + KNNQuantizationStateWriter quantizationStateWriter = new KNNQuantizationStateWriter(segmentWriteState); + + String fieldName = "test-field"; + QuantizationState quantizationState = new OneBitScalarQuantizationState( + new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), + new float[] { 1.2f, 2.3f, 3.4f, 4.5f } + ); + quantizationStateWriter.writeState(fieldName, quantizationState); + byte[] stateBytes = quantizationState.toByteArray(); + Mockito.verify(output, times(1)).writeBytes(stateBytes, stateBytes.length); + } + + @SneakyThrows + public void testWriteFooter() { + final String segmentName = "test-segment-name"; + + final SegmentInfo segmentInfo = new SegmentInfo( + Mockito.mock(Directory.class), + Mockito.mock(Version.class), + Mockito.mock(Version.class), + segmentName, + 0, + false, + false, + Mockito.mock(Codec.class), + Mockito.mock(Map.class), + new byte[16], + Mockito.mock(Map.class), + Mockito.mock(Sort.class) + ); + + Directory directory = Mockito.mock(Directory.class); + IndexOutput output = Mockito.mock(IndexOutput.class); + Mockito.when(directory.createOutput(any(), any())).thenReturn(output); + + final SegmentWriteState segmentWriteState = new SegmentWriteState( + Mockito.mock(InfoStream.class), + directory, + segmentInfo, + Mockito.mock(FieldInfos.class), + null, + Mockito.mock(IOContext.class) + ); + KNNQuantizationStateWriter quantizationStateWriter = new KNNQuantizationStateWriter(segmentWriteState); + + String fieldName1 = "test-field-1"; + String fieldName2 = "test-field-2"; + QuantizationState quantizationState1 = new OneBitScalarQuantizationState( + new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), + new float[] { 1.2f, 2.3f, 3.4f, 4.5f } + ); + QuantizationState quantizationState2 = new OneBitScalarQuantizationState( + new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), + new float[] { 2.3f, 3.4f, 4.5f, 5.6f } + ); + quantizationStateWriter.writeState(fieldName1, quantizationState1); + quantizationStateWriter.writeState(fieldName2, quantizationState2); + + try (MockedStatic mockedStaticCodecUtil = mockStatic(CodecUtil.class)) { + quantizationStateWriter.writeFooter(); + + Mockito.verify(output, times(4)).writeInt(anyInt()); + Mockito.verify(output, times(2)).writeString(anyString()); + Mockito.verify(output, times(2)).writeVLong(anyLong()); + Mockito.verify(output, times(1)).writeLong(anyLong()); + mockedStaticCodecUtil.verify(() -> CodecUtil.writeFooter(output)); + } + } +} From e7d5ac8c9b54c49f104923198c1876050518e926 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Thu, 22 Aug 2024 11:58:38 -0700 Subject: [PATCH 07/41] Add reader tests Signed-off-by: Ryan Bogan --- .../KNNQuantizationStateReader.java | 50 +++++----- .../KNNQuantizationStateReaderTests.java | 92 +++++++++++++++++++ .../KNNQuantizationStateWriterTests.java | 47 +++++----- 3 files changed, 139 insertions(+), 50 deletions(-) create mode 100644 src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReaderTests.java diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java index 0fc8ea4964..ba52a5239a 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index.codec.KNN990Codec; +import com.google.common.annotations.VisibleForTesting; import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.IndexFileNames; @@ -28,7 +29,6 @@ public class KNNQuantizationStateReader { /** * Read quantization states and return list of fieldNames and bytes - * * File format: * Header * QS1 state bytes @@ -46,7 +46,7 @@ public class KNNQuantizationStateReader { * * @param state the read state to read from */ - public Map read(SegmentReadState state) { + public Map read(SegmentReadState state) throws IOException { String quantizationStateFileName = IndexFileNames.segmentFileName( state.segmentInfo.name, state.segmentSuffix, @@ -54,32 +54,31 @@ public Map read(SegmentReadState state) { ); Map readQuantizationStateInfos = new HashMap<>(); - try (IndexInput input = state.directory.openInput(quantizationStateFileName, IOContext.READ)) { + IndexInput input = state.directory.openInput(quantizationStateFileName, IOContext.READ); + CodecUtil.retrieveChecksum(input); - int numFields = getNumFields(input); + int numFields = getNumFields(input); - List fieldNames = new ArrayList<>(); - List positions = new ArrayList<>(); - List lengths = new ArrayList<>(); + List fieldNames = new ArrayList<>(); + List positions = new ArrayList<>(); + List lengths = new ArrayList<>(); - // Read each field's metadata from the index section - for (int i = 0; i < numFields; i++) { - fieldNames.add(input.readString()); - int length = input.readInt(); - lengths.add(length); - long position = input.readVLong(); - positions.add(position); - } - // Read each field's bytes - for (int i = 0; i < numFields; i++) { - input.seek(positions.get(i)); - byte[] stateBytes = new byte[lengths.get(i)]; - input.readBytes(stateBytes, 0, lengths.get(i)); - readQuantizationStateInfos.put(fieldNames.get(i), stateBytes); - } - } catch (IOException e) { - throw new RuntimeException(e); + // Read each field's metadata from the index section + for (int i = 0; i < numFields; i++) { + fieldNames.add(input.readString()); + int length = input.readInt(); + lengths.add(length); + long position = input.readVLong(); + positions.add(position); } + // Read each field's bytes + for (int i = 0; i < numFields; i++) { + input.seek(positions.get(i)); + byte[] stateBytes = new byte[lengths.get(i)]; + input.readBytes(stateBytes, 0, lengths.get(i)); + readQuantizationStateInfos.put(fieldNames.get(i), stateBytes); + } + input.close(); return readQuantizationStateInfos; } @@ -131,7 +130,8 @@ public QuantizationState read(Directory directory, String segmentName, String se return null; } - private int getNumFields(IndexInput input) throws IOException { + @VisibleForTesting + int getNumFields(IndexInput input) throws IOException { long footerStart = input.length() - CodecUtil.footerLength(); long markerAndIndexPosition = footerStart - Integer.BYTES - Long.BYTES; input.seek(markerAndIndexPosition); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReaderTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReaderTests.java new file mode 100644 index 0000000000..167ae942ce --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReaderTests.java @@ -0,0 +1,92 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN990Codec; + +import lombok.SneakyThrows; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.SegmentInfo; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.search.Sort; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.Version; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.opensearch.knn.KNNTestCase; + +import java.util.Map; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.times; + +public class KNNQuantizationStateReaderTests extends KNNTestCase { + + @SneakyThrows + public void testReadFromSegmentReadState() { + final String segmentName = "test-segment-name"; + final String segmentSuffix = "test-segment-suffix"; + + final SegmentInfo segmentInfo = new SegmentInfo( + Mockito.mock(Directory.class), + Mockito.mock(Version.class), + Mockito.mock(Version.class), + segmentName, + 0, + false, + false, + Mockito.mock(Codec.class), + Mockito.mock(Map.class), + new byte[16], + Mockito.mock(Map.class), + Mockito.mock(Sort.class) + ); + + Directory directory = Mockito.mock(Directory.class); + IndexInput input = Mockito.mock(IndexInput.class); + Mockito.when(directory.openInput(any(), any())).thenReturn(input); + + final SegmentReadState segmentReadState = new SegmentReadState( + directory, + segmentInfo, + Mockito.mock(FieldInfos.class), + Mockito.mock(IOContext.class), + segmentSuffix + ); + + KNNQuantizationStateReader quantizationStateReader = Mockito.mock(KNNQuantizationStateReader.class); + Mockito.when(quantizationStateReader.getNumFields(input)).thenReturn(2); + Mockito.when(quantizationStateReader.read(segmentReadState)).thenCallRealMethod(); + + try (MockedStatic mockedStaticCodecUtil = mockStatic(CodecUtil.class)) { + quantizationStateReader.read(segmentReadState); + + mockedStaticCodecUtil.verify(() -> CodecUtil.retrieveChecksum(input)); + Mockito.verify(input, times(2)).readInt(); + Mockito.verify(input, times(2)).readString(); + Mockito.verify(input, times(2)).readVLong(); + Mockito.verify(input, times(2)).readBytes(any(byte[].class), anyInt(), anyInt()); + Mockito.verify(input, times(2)).seek(anyLong()); + } + } + + @SneakyThrows + public void testGetNumFields() { + IndexInput input = Mockito.mock(IndexInput.class); + KNNQuantizationStateReader quantizationStateReader = new KNNQuantizationStateReader(); + quantizationStateReader.getNumFields(input); + + Mockito.verify(input, times(2)).readInt(); + Mockito.verify(input, times(1)).readLong(); + Mockito.verify(input, times(2)).seek(anyLong()); + Mockito.verify(input, times(1)).length(); + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriterTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriterTests.java index 95d3992dfa..0f1414a8a5 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriterTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriterTests.java @@ -10,7 +10,6 @@ import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.SegmentInfo; -import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.search.Sort; import org.apache.lucene.store.Directory; @@ -18,14 +17,12 @@ import org.apache.lucene.store.IndexOutput; import org.apache.lucene.util.InfoStream; import org.apache.lucene.util.Version; -import org.junit.Before; import org.mockito.MockedStatic; import org.mockito.Mockito; import org.mockito.stubbing.Answer; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; -import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; @@ -137,18 +134,18 @@ public void testWriteFooter() { final String segmentName = "test-segment-name"; final SegmentInfo segmentInfo = new SegmentInfo( - Mockito.mock(Directory.class), - Mockito.mock(Version.class), - Mockito.mock(Version.class), - segmentName, - 0, - false, - false, - Mockito.mock(Codec.class), - Mockito.mock(Map.class), - new byte[16], - Mockito.mock(Map.class), - Mockito.mock(Sort.class) + Mockito.mock(Directory.class), + Mockito.mock(Version.class), + Mockito.mock(Version.class), + segmentName, + 0, + false, + false, + Mockito.mock(Codec.class), + Mockito.mock(Map.class), + new byte[16], + Mockito.mock(Map.class), + Mockito.mock(Sort.class) ); Directory directory = Mockito.mock(Directory.class); @@ -156,24 +153,24 @@ public void testWriteFooter() { Mockito.when(directory.createOutput(any(), any())).thenReturn(output); final SegmentWriteState segmentWriteState = new SegmentWriteState( - Mockito.mock(InfoStream.class), - directory, - segmentInfo, - Mockito.mock(FieldInfos.class), - null, - Mockito.mock(IOContext.class) + Mockito.mock(InfoStream.class), + directory, + segmentInfo, + Mockito.mock(FieldInfos.class), + null, + Mockito.mock(IOContext.class) ); KNNQuantizationStateWriter quantizationStateWriter = new KNNQuantizationStateWriter(segmentWriteState); String fieldName1 = "test-field-1"; String fieldName2 = "test-field-2"; QuantizationState quantizationState1 = new OneBitScalarQuantizationState( - new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), - new float[] { 1.2f, 2.3f, 3.4f, 4.5f } + new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), + new float[] { 1.2f, 2.3f, 3.4f, 4.5f } ); QuantizationState quantizationState2 = new OneBitScalarQuantizationState( - new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), - new float[] { 2.3f, 3.4f, 4.5f, 5.6f } + new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), + new float[] { 2.3f, 3.4f, 4.5f, 5.6f } ); quantizationStateWriter.writeState(fieldName1, quantizationState1); quantizationStateWriter.writeState(fieldName2, quantizationState2); From c804652a1d8caf94e045a101df45078ffd3a696e Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Thu, 22 Aug 2024 12:00:06 -0700 Subject: [PATCH 08/41] Add changelog entry Signed-off-by: Ryan Bogan --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index a0c61ae55c..2797979111 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -41,3 +41,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Added Quantization Framework and implemented 1Bit and multibit quantizer[#1889](https://github.com/opensearch-project/k-NN/issues/1889) * Encapsulate dimension, vector data type validation/processing inside Library [#1957](https://github.com/opensearch-project/k-NN/pull/1957) * Add quantization state cache [#1960](https://github.com/opensearch-project/k-NN/pull/1960) +* Add quantization state reader and writer [#1997](https://github.com/opensearch-project/k-NN/pull/1997) From f711e39ebff1035108869f600f49790249b9b18a Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Thu, 22 Aug 2024 12:16:02 -0700 Subject: [PATCH 09/41] Remove extra line Signed-off-by: Ryan Bogan --- .../codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java | 1 - 1 file changed, 1 deletion(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index 876d8a1f73..c1ea3d3951 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -50,7 +50,6 @@ public NativeEngines990KnnVectorsWriter(SegmentWriteState segmentWriteState, Fla this.segmentWriteState = segmentWriteState; this.flatVectorsWriter = flatVectorsWriter; this.quantizationStateWriter = new KNNQuantizationStateWriter(segmentWriteState); - } /** From 425b920c5e2c589604bb3e163b5ff28add8b989a Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Thu, 22 Aug 2024 13:11:17 -0700 Subject: [PATCH 10/41] Address PR Feedback Signed-off-by: Ryan Bogan --- .../KNNQuantizationStateReader.java | 15 ++++++------- .../KNNQuantizationStateWriter.java | 22 +++++++++++++++++-- .../NativeEngines990KnnVectorsWriter.java | 1 + .../QuantizationStateReadConfig.java | 20 +++++++++++++++++ 4 files changed, 48 insertions(+), 10 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateReadConfig.java diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java index ba52a5239a..26a5bc51bc 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java @@ -7,14 +7,13 @@ import com.google.common.annotations.VisibleForTesting; import org.apache.lucene.codecs.CodecUtil; -import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.SegmentReadState; -import org.apache.lucene.store.Directory; import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IndexInput; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateReadConfig; import java.io.IOException; import java.util.ArrayList; @@ -25,7 +24,7 @@ /** * Reads quantization states */ -public class KNNQuantizationStateReader { +public final class KNNQuantizationStateReader { /** * Read quantization states and return list of fieldNames and bytes @@ -90,15 +89,15 @@ public Map read(SegmentReadState state) throws IOException { * @param fieldInfo field information * @return quantization state */ - public QuantizationState read(Directory directory, String segmentName, String segmentSuffix, FieldInfo fieldInfo) throws IOException { + public QuantizationState read(QuantizationStateReadConfig readConfig) throws IOException { String quantizationStateFileName = IndexFileNames.segmentFileName( - segmentName, - segmentSuffix, + readConfig.getSegmentName(), + readConfig.getSegmentSuffix(), KNNConstants.QUANTIZATION_STATE_FILE_SUFFIX ); - String fieldName = fieldInfo.getName(); + String fieldName = readConfig.getFieldInfo().getName(); - IndexInput input = directory.openInput(quantizationStateFileName, IOContext.READ); + IndexInput input = readConfig.getDirectory().openInput(quantizationStateFileName, IOContext.READ); CodecUtil.retrieveChecksum(input); int numFields = getNumFields(input); diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriter.java index 088aade419..4b8c0ca098 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriter.java @@ -20,13 +20,27 @@ /** * Writes quantization states to off heap memory */ -public class KNNQuantizationStateWriter { +public final class KNNQuantizationStateWriter { private final IndexOutput output; private List fieldQuantizationStates = new ArrayList<>(); /** * Constructor + * Overall file format for writer: + * Header + * QS1 state bytes + * QS2 state bytes + * Number of quantization states + * QS1 field name + * QS1 state bytes length + * QS1 position of state bytes + * QS2 field name + * QS2 state bytes length + * QS2 position of state bytes + * Position of index section (where QS1 field name is located) + * -1 (marker) + * Footer * @param segmentWriteState segment write state containing segment information * @throws IOException exception could be thrown while creating the output */ @@ -51,6 +65,7 @@ public void writeHeader(SegmentWriteState segmentWriteState) throws IOException /** * Writes a quantization state as bytes + * * @param fieldName field name * @param quantizationState quantization state * @throws IOException could be thrown while writing @@ -77,7 +92,6 @@ public void writeFooter() throws IOException { output.writeLong(indexStartPosition); output.writeInt(-1); CodecUtil.writeFooter(output); - output.close(); fieldQuantizationStates = new ArrayList<>(); } @@ -87,4 +101,8 @@ private static class FieldQuantizationState { final byte[] stateBytes; final Long position; } + + public void closeOutput() throws IOException { + output.close(); + } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index c1ea3d3951..c6ece1f6e8 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -145,6 +145,7 @@ public void finish() throws IOException { */ @Override public void close() throws IOException { + quantizationStateWriter.closeOutput(); IOUtils.close(flatVectorsWriter); } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateReadConfig.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateReadConfig.java new file mode 100644 index 0000000000..96a126a1e7 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateReadConfig.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.quantizationState; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.store.Directory; + +@Getter +@AllArgsConstructor +public class QuantizationStateReadConfig { + private Directory directory; + private String segmentName; + private String segmentSuffix; + private FieldInfo fieldInfo; +} From 2ea5371ec041bdaf1bfe8b92ae68f5aac609e881 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Thu, 22 Aug 2024 13:23:36 -0700 Subject: [PATCH 11/41] Fix javadocs Signed-off-by: Ryan Bogan --- .../index/codec/KNN990Codec/KNNQuantizationStateReader.java | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java index 26a5bc51bc..98c6781c11 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java @@ -83,10 +83,7 @@ public Map read(SegmentReadState state) throws IOException { /** * Reads an individual quantization state for a given field - * @param directory directory to open input - * @param segmentName segment name - * @param segmentSuffix segment suffix - * @param fieldInfo field information + * @param readConfig a config class that contains necessary information for reading the state * @return quantization state */ public QuantizationState read(QuantizationStateReadConfig readConfig) throws IOException { From 89e45de1a3e4fc2bc6cd3e115f197fff34c3df9a Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Thu, 22 Aug 2024 16:29:23 -0700 Subject: [PATCH 12/41] Make reader methods static Signed-off-by: Ryan Bogan --- .../KNNQuantizationStateReader.java | 96 +++++++++---------- .../KNNQuantizationStateReaderTests.java | 27 +++--- 2 files changed, 61 insertions(+), 62 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java index 98c6781c11..727b5836d9 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java @@ -45,7 +45,7 @@ public final class KNNQuantizationStateReader { * * @param state the read state to read from */ - public Map read(SegmentReadState state) throws IOException { + public static Map read(SegmentReadState state) throws IOException { String quantizationStateFileName = IndexFileNames.segmentFileName( state.segmentInfo.name, state.segmentSuffix, @@ -53,31 +53,31 @@ public Map read(SegmentReadState state) throws IOException { ); Map readQuantizationStateInfos = new HashMap<>(); - IndexInput input = state.directory.openInput(quantizationStateFileName, IOContext.READ); - CodecUtil.retrieveChecksum(input); + try (IndexInput input = state.directory.openInput(quantizationStateFileName, IOContext.READ)) { + CodecUtil.retrieveChecksum(input); - int numFields = getNumFields(input); + int numFields = getNumFields(input); - List fieldNames = new ArrayList<>(); - List positions = new ArrayList<>(); - List lengths = new ArrayList<>(); + List fieldNames = new ArrayList<>(); + List positions = new ArrayList<>(); + List lengths = new ArrayList<>(); - // Read each field's metadata from the index section - for (int i = 0; i < numFields; i++) { - fieldNames.add(input.readString()); - int length = input.readInt(); - lengths.add(length); - long position = input.readVLong(); - positions.add(position); - } - // Read each field's bytes - for (int i = 0; i < numFields; i++) { - input.seek(positions.get(i)); - byte[] stateBytes = new byte[lengths.get(i)]; - input.readBytes(stateBytes, 0, lengths.get(i)); - readQuantizationStateInfos.put(fieldNames.get(i), stateBytes); + // Read each field's metadata from the index section + for (int i = 0; i < numFields; i++) { + fieldNames.add(input.readString()); + int length = input.readInt(); + lengths.add(length); + long position = input.readVLong(); + positions.add(position); + } + // Read each field's bytes + for (int i = 0; i < numFields; i++) { + input.seek(positions.get(i)); + byte[] stateBytes = new byte[lengths.get(i)]; + input.readBytes(stateBytes, 0, lengths.get(i)); + readQuantizationStateInfos.put(fieldNames.get(i), stateBytes); + } } - input.close(); return readQuantizationStateInfos; } @@ -86,7 +86,7 @@ public Map read(SegmentReadState state) throws IOException { * @param readConfig a config class that contains necessary information for reading the state * @return quantization state */ - public QuantizationState read(QuantizationStateReadConfig readConfig) throws IOException { + public static QuantizationState read(QuantizationStateReadConfig readConfig) throws IOException { String quantizationStateFileName = IndexFileNames.segmentFileName( readConfig.getSegmentName(), readConfig.getSegmentSuffix(), @@ -94,40 +94,40 @@ public QuantizationState read(QuantizationStateReadConfig readConfig) throws IOE ); String fieldName = readConfig.getFieldInfo().getName(); - IndexInput input = readConfig.getDirectory().openInput(quantizationStateFileName, IOContext.READ); - CodecUtil.retrieveChecksum(input); - int numFields = getNumFields(input); - - long position = -1; - int length = 0; - - // Read each field's metadata from the index section - for (int i = 0; i < numFields; i++) { - String tempFieldName = input.readString(); - int tempLength = input.readInt(); - long tempPosition = input.readVLong(); - if (tempFieldName.equals(fieldName)) { - position = tempPosition; - length = tempLength; - break; + try (IndexInput input = readConfig.getDirectory().openInput(quantizationStateFileName, IOContext.READ)) { + CodecUtil.retrieveChecksum(input); + int numFields = getNumFields(input); + + long position = -1; + int length = 0; + + // Read each field's metadata from the index section, break when correct field is found + for (int i = 0; i < numFields; i++) { + String tempFieldName = input.readString(); + int tempLength = input.readInt(); + long tempPosition = input.readVLong(); + if (tempFieldName.equals(fieldName)) { + position = tempPosition; + length = tempLength; + break; + } } - } - if (position == -1 || length == 0) { - throw new IllegalArgumentException(String.format("Field %s not found", fieldName)); - } + if (position == -1 || length == 0) { + throw new IllegalArgumentException(String.format("Field %s not found", fieldName)); + } - input.seek(position); - byte[] stateBytes = new byte[length]; - input.readBytes(stateBytes, 0, length); - input.close(); + input.seek(position); + byte[] stateBytes = new byte[length]; + input.readBytes(stateBytes, 0, length); + } // Deserialize the byte array to a quantization state object // TODO: Get params from field info and deserialize return null; } @VisibleForTesting - int getNumFields(IndexInput input) throws IOException { + static int getNumFields(IndexInput input) throws IOException { long footerStart = input.length() - CodecUtil.footerLength(); long markerAndIndexPosition = footerStart - Integer.BYTES - Long.BYTES; input.seek(markerAndIndexPosition); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReaderTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReaderTests.java index 167ae942ce..d6d9ff5409 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReaderTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReaderTests.java @@ -62,27 +62,26 @@ public void testReadFromSegmentReadState() { segmentSuffix ); - KNNQuantizationStateReader quantizationStateReader = Mockito.mock(KNNQuantizationStateReader.class); - Mockito.when(quantizationStateReader.getNumFields(input)).thenReturn(2); - Mockito.when(quantizationStateReader.read(segmentReadState)).thenCallRealMethod(); + try (MockedStatic mockedStaticReader = Mockito.mockStatic(KNNQuantizationStateReader.class)) { + mockedStaticReader.when(() -> KNNQuantizationStateReader.getNumFields(input)).thenReturn(2); + mockedStaticReader.when(() -> KNNQuantizationStateReader.read(segmentReadState)).thenCallRealMethod(); + try (MockedStatic mockedStaticCodecUtil = mockStatic(CodecUtil.class)) { + KNNQuantizationStateReader.read(segmentReadState); - try (MockedStatic mockedStaticCodecUtil = mockStatic(CodecUtil.class)) { - quantizationStateReader.read(segmentReadState); - - mockedStaticCodecUtil.verify(() -> CodecUtil.retrieveChecksum(input)); - Mockito.verify(input, times(2)).readInt(); - Mockito.verify(input, times(2)).readString(); - Mockito.verify(input, times(2)).readVLong(); - Mockito.verify(input, times(2)).readBytes(any(byte[].class), anyInt(), anyInt()); - Mockito.verify(input, times(2)).seek(anyLong()); + mockedStaticCodecUtil.verify(() -> CodecUtil.retrieveChecksum(input)); + Mockito.verify(input, times(2)).readInt(); + Mockito.verify(input, times(2)).readString(); + Mockito.verify(input, times(2)).readVLong(); + Mockito.verify(input, times(2)).readBytes(any(byte[].class), anyInt(), anyInt()); + Mockito.verify(input, times(2)).seek(anyLong()); + } } } @SneakyThrows public void testGetNumFields() { IndexInput input = Mockito.mock(IndexInput.class); - KNNQuantizationStateReader quantizationStateReader = new KNNQuantizationStateReader(); - quantizationStateReader.getNumFields(input); + KNNQuantizationStateReader.getNumFields(input); Mockito.verify(input, times(2)).readInt(); Mockito.verify(input, times(1)).readLong(); From 8cd2ee3a0c065d02cff789958478317ec914291e Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Thu, 22 Aug 2024 16:42:13 -0700 Subject: [PATCH 13/41] Integrate with merge Signed-off-by: Ryan Bogan --- .../KNNQuantizationStateWriter.java | 16 +++++- .../NativeEngines990KnnVectorsWriter.java | 5 ++ .../KNNQuantizationStateWriterTests.java | 56 +++++++++++++++++++ 3 files changed, 75 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriter.java index 4b8c0ca098..ca1fd2e4e6 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriter.java @@ -6,6 +6,7 @@ package org.opensearch.knn.index.codec.KNN990Codec; import lombok.AllArgsConstructor; +import lombok.Setter; import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.SegmentWriteState; @@ -92,14 +93,25 @@ public void writeFooter() throws IOException { output.writeLong(indexStartPosition); output.writeInt(-1); CodecUtil.writeFooter(output); - fieldQuantizationStates = new ArrayList<>(); + } + + /** + * Writes the bytes of existing quantization states. Used during merge to rewrite file. + * @throws IOException exception could be thrown during write + */ + public void writeExistingStates() throws IOException { + for (FieldQuantizationState fieldQuantizationState : fieldQuantizationStates) { + fieldQuantizationState.setPosition(output.getFilePointer()); + output.writeBytes(fieldQuantizationState.stateBytes, fieldQuantizationState.stateBytes.length); + } } @AllArgsConstructor private static class FieldQuantizationState { final String fieldName; final byte[] stateBytes; - final Long position; + @Setter + Long position; } public void closeOutput() throws IOException { diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index c6ece1f6e8..fc4d284ab3 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -115,6 +115,11 @@ public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]"); } + quantizationStateWriter.writeHeader(segmentWriteState); + quantizationStateWriter.writeExistingStates(); + // quantizationStateWriter.writeState(fieldInfo.getName(), quantizationState); + quantizationStateWriter.writeFooter(); + NativeIndexWriter.getWriter(fieldInfo, segmentWriteState).mergeIndex(knnVectorValues); } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriterTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriterTests.java index 0f1414a8a5..e079269340 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriterTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriterTests.java @@ -185,4 +185,60 @@ public void testWriteFooter() { mockedStaticCodecUtil.verify(() -> CodecUtil.writeFooter(output)); } } + + @SneakyThrows + public void testWriteExistingStates() { + final String segmentName = "test-segment-name"; + + final SegmentInfo segmentInfo = new SegmentInfo( + Mockito.mock(Directory.class), + Mockito.mock(Version.class), + Mockito.mock(Version.class), + segmentName, + 0, + false, + false, + Mockito.mock(Codec.class), + Mockito.mock(Map.class), + new byte[16], + Mockito.mock(Map.class), + Mockito.mock(Sort.class) + ); + + Directory directory = Mockito.mock(Directory.class); + IndexOutput output = Mockito.mock(IndexOutput.class); + Mockito.when(directory.createOutput(any(), any())).thenReturn(output); + + final SegmentWriteState segmentWriteState = new SegmentWriteState( + Mockito.mock(InfoStream.class), + directory, + segmentInfo, + Mockito.mock(FieldInfos.class), + null, + Mockito.mock(IOContext.class) + ); + KNNQuantizationStateWriter quantizationStateWriter = new KNNQuantizationStateWriter(segmentWriteState); + + quantizationStateWriter.writeExistingStates(); + + Mockito.verify(output, times(0)).writeBytes(any(byte[].class), anyInt()); + + String fieldName1 = "test-field-1"; + String fieldName2 = "test-field-2"; + QuantizationState quantizationState1 = new OneBitScalarQuantizationState( + new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), + new float[] { 1.2f, 2.3f, 3.4f, 4.5f } + ); + QuantizationState quantizationState2 = new OneBitScalarQuantizationState( + new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), + new float[] { 2.3f, 3.4f, 4.5f, 5.6f } + ); + quantizationStateWriter.writeState(fieldName1, quantizationState1); + quantizationStateWriter.writeState(fieldName2, quantizationState2); + + quantizationStateWriter.writeExistingStates(); + + // Should be called once in each write state method and once in the second call of writeExistingStates + Mockito.verify(output, times(4)).writeBytes(any(byte[].class), anyInt()); + } } From d644c9b2d5f8915fbcdd5a99308c33754afa0be6 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Fri, 23 Aug 2024 11:46:30 -0700 Subject: [PATCH 14/41] Change field name writing to internal field number and change file suffix Signed-off-by: Ryan Bogan --- .../opensearch/knn/common/KNNConstants.java | 2 +- .../KNNQuantizationStateReader.java | 12 +++++------ .../KNNQuantizationStateWriter.java | 14 ++++++------- .../NativeEngines990KnnVectorsWriter.java | 4 ++-- .../KNNQuantizationStateReaderTests.java | 3 +-- .../KNNQuantizationStateWriterTests.java | 20 +++++++++---------- 6 files changed, 27 insertions(+), 28 deletions(-) diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index 76b0f61027..550b6188d6 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -71,7 +71,7 @@ public class KNNConstants { public static final VectorDataType DEFAULT_VECTOR_DATA_TYPE_FIELD = VectorDataType.FLOAT; public static final String RADIAL_SEARCH_KEY = "radial_search"; - public static final String QUANTIZATION_STATE_FILE_SUFFIX = "qs"; + public static final String QUANTIZATION_STATE_FILE_SUFFIX = "qstate"; // Lucene specific constants public static final String LUCENE_NAME = "lucene"; diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java index 727b5836d9..99a39e2489 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java @@ -33,10 +33,10 @@ public final class KNNQuantizationStateReader { * QS1 state bytes * QS2 state bytes * Number of quantization states - * QS1 field name + * QS1 field number * QS1 state bytes length * QS1 position of state bytes - * QS2 field name + * QS2 field number * QS2 state bytes length * QS2 position of state bytes * Position of index section (where QS1 field name is located) @@ -92,7 +92,7 @@ public static QuantizationState read(QuantizationStateReadConfig readConfig) thr readConfig.getSegmentSuffix(), KNNConstants.QUANTIZATION_STATE_FILE_SUFFIX ); - String fieldName = readConfig.getFieldInfo().getName(); + int fieldNumber = readConfig.getFieldInfo().getFieldNumber(); try (IndexInput input = readConfig.getDirectory().openInput(quantizationStateFileName, IOContext.READ)) { CodecUtil.retrieveChecksum(input); @@ -103,10 +103,10 @@ public static QuantizationState read(QuantizationStateReadConfig readConfig) thr // Read each field's metadata from the index section, break when correct field is found for (int i = 0; i < numFields; i++) { - String tempFieldName = input.readString(); + int tempFieldNumber = input.readInt(); int tempLength = input.readInt(); long tempPosition = input.readVLong(); - if (tempFieldName.equals(fieldName)) { + if (tempFieldNumber == fieldNumber) { position = tempPosition; length = tempLength; break; @@ -114,7 +114,7 @@ public static QuantizationState read(QuantizationStateReadConfig readConfig) thr } if (position == -1 || length == 0) { - throw new IllegalArgumentException(String.format("Field %s not found", fieldName)); + throw new IllegalArgumentException(String.format("Field %s not found", readConfig.getFieldInfo().getName())); } input.seek(position); diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriter.java index ca1fd2e4e6..5d1e2976e5 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriter.java @@ -33,10 +33,10 @@ public final class KNNQuantizationStateWriter { * QS1 state bytes * QS2 state bytes * Number of quantization states - * QS1 field name + * QS1 field number * QS1 state bytes length * QS1 position of state bytes - * QS2 field name + * QS2 field number * QS2 state bytes length * QS2 position of state bytes * Position of index section (where QS1 field name is located) @@ -67,15 +67,15 @@ public void writeHeader(SegmentWriteState segmentWriteState) throws IOException /** * Writes a quantization state as bytes * - * @param fieldName field name + * @param fieldNumber field number * @param quantizationState quantization state * @throws IOException could be thrown while writing */ - public void writeState(String fieldName, QuantizationState quantizationState) throws IOException { + public void writeState(int fieldNumber, QuantizationState quantizationState) throws IOException { byte[] stateBytes = quantizationState.toByteArray(); long position = output.getFilePointer(); output.writeBytes(stateBytes, stateBytes.length); - fieldQuantizationStates.add(new FieldQuantizationState(fieldName, stateBytes, position)); + fieldQuantizationStates.add(new FieldQuantizationState(fieldNumber, stateBytes, position)); } /** @@ -86,7 +86,7 @@ public void writeFooter() throws IOException { long indexStartPosition = output.getFilePointer(); output.writeInt(fieldQuantizationStates.size()); for (FieldQuantizationState fieldQuantizationState : fieldQuantizationStates) { - output.writeString(fieldQuantizationState.fieldName); + output.writeInt(fieldQuantizationState.fieldNumber); output.writeInt(fieldQuantizationState.stateBytes.length); output.writeVLong(fieldQuantizationState.position); } @@ -108,7 +108,7 @@ public void writeExistingStates() throws IOException { @AllArgsConstructor private static class FieldQuantizationState { - final String fieldName; + final int fieldNumber; final byte[] stateBytes; @Setter Long position; diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index fc4d284ab3..45e1972736 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -87,7 +87,7 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { ); // TODO: Extract quantization state here, uncomment below line once implemented - // quantizationStateWriter.writeState(field.getFieldInfo().getName(), quantizationState); + // quantizationStateWriter.writeState(field.getFieldInfo().getFieldNumber(), quantizationState); NativeIndexWriter.getWriter(field.getFieldInfo(), segmentWriteState).flushIndex(knnVectorValues); } @@ -117,7 +117,7 @@ public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState quantizationStateWriter.writeHeader(segmentWriteState); quantizationStateWriter.writeExistingStates(); - // quantizationStateWriter.writeState(fieldInfo.getName(), quantizationState); + // quantizationStateWriter.writeState(fieldInfo.getFieldNumber(), quantizationState); quantizationStateWriter.writeFooter(); NativeIndexWriter.getWriter(fieldInfo, segmentWriteState).mergeIndex(knnVectorValues); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReaderTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReaderTests.java index d6d9ff5409..cfe88b8d00 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReaderTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReaderTests.java @@ -69,8 +69,7 @@ public void testReadFromSegmentReadState() { KNNQuantizationStateReader.read(segmentReadState); mockedStaticCodecUtil.verify(() -> CodecUtil.retrieveChecksum(input)); - Mockito.verify(input, times(2)).readInt(); - Mockito.verify(input, times(2)).readString(); + Mockito.verify(input, times(4)).readInt(); Mockito.verify(input, times(2)).readVLong(); Mockito.verify(input, times(2)).readBytes(any(byte[].class), anyInt(), anyInt()); Mockito.verify(input, times(2)).seek(anyLong()); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriterTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriterTests.java index e079269340..e52938e612 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriterTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriterTests.java @@ -119,12 +119,12 @@ public void testWriteState() { ); KNNQuantizationStateWriter quantizationStateWriter = new KNNQuantizationStateWriter(segmentWriteState); - String fieldName = "test-field"; + int fieldNumber = 0; QuantizationState quantizationState = new OneBitScalarQuantizationState( new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), new float[] { 1.2f, 2.3f, 3.4f, 4.5f } ); - quantizationStateWriter.writeState(fieldName, quantizationState); + quantizationStateWriter.writeState(fieldNumber, quantizationState); byte[] stateBytes = quantizationState.toByteArray(); Mockito.verify(output, times(1)).writeBytes(stateBytes, stateBytes.length); } @@ -162,8 +162,8 @@ public void testWriteFooter() { ); KNNQuantizationStateWriter quantizationStateWriter = new KNNQuantizationStateWriter(segmentWriteState); - String fieldName1 = "test-field-1"; - String fieldName2 = "test-field-2"; + int fieldNumber1 = 1; + int fieldNumber2 = 2; QuantizationState quantizationState1 = new OneBitScalarQuantizationState( new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), new float[] { 1.2f, 2.3f, 3.4f, 4.5f } @@ -172,8 +172,8 @@ public void testWriteFooter() { new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), new float[] { 2.3f, 3.4f, 4.5f, 5.6f } ); - quantizationStateWriter.writeState(fieldName1, quantizationState1); - quantizationStateWriter.writeState(fieldName2, quantizationState2); + quantizationStateWriter.writeState(fieldNumber1, quantizationState1); + quantizationStateWriter.writeState(fieldNumber2, quantizationState2); try (MockedStatic mockedStaticCodecUtil = mockStatic(CodecUtil.class)) { quantizationStateWriter.writeFooter(); @@ -223,8 +223,8 @@ public void testWriteExistingStates() { Mockito.verify(output, times(0)).writeBytes(any(byte[].class), anyInt()); - String fieldName1 = "test-field-1"; - String fieldName2 = "test-field-2"; + int fieldNumber1 = 1; + int fieldNumber2 = 2; QuantizationState quantizationState1 = new OneBitScalarQuantizationState( new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), new float[] { 1.2f, 2.3f, 3.4f, 4.5f } @@ -233,8 +233,8 @@ public void testWriteExistingStates() { new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), new float[] { 2.3f, 3.4f, 4.5f, 5.6f } ); - quantizationStateWriter.writeState(fieldName1, quantizationState1); - quantizationStateWriter.writeState(fieldName2, quantizationState2); + quantizationStateWriter.writeState(fieldNumber1, quantizationState1); + quantizationStateWriter.writeState(fieldNumber2, quantizationState2); quantizationStateWriter.writeExistingStates(); From 92bb5392c268695e591a3a79321367e3730a46f9 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Mon, 26 Aug 2024 11:44:55 -0700 Subject: [PATCH 15/41] Change integration with native engine writer Signed-off-by: Ryan Bogan --- .../KNNQuantizationStateWriter.java | 11 ---- .../NativeEngines990KnnVectorsWriter.java | 8 +-- .../KNNQuantizationStateWriterTests.java | 56 ------------------- 3 files changed, 2 insertions(+), 73 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriter.java index 5d1e2976e5..03e392d7ba 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriter.java @@ -95,17 +95,6 @@ public void writeFooter() throws IOException { CodecUtil.writeFooter(output); } - /** - * Writes the bytes of existing quantization states. Used during merge to rewrite file. - * @throws IOException exception could be thrown during write - */ - public void writeExistingStates() throws IOException { - for (FieldQuantizationState fieldQuantizationState : fieldQuantizationStates) { - fieldQuantizationState.setPosition(output.getFilePointer()); - output.writeBytes(fieldQuantizationState.stateBytes, fieldQuantizationState.stateBytes.length); - } - } - @AllArgsConstructor private static class FieldQuantizationState { final int fieldNumber; diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index 45bab758fa..7875121de5 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -54,6 +54,7 @@ public NativeEngines990KnnVectorsWriter(SegmentWriteState segmentWriteState, Fla this.segmentWriteState = segmentWriteState; this.flatVectorsWriter = flatVectorsWriter; this.quantizationStateWriter = new KNNQuantizationStateWriter(segmentWriteState); + quantizationStateWriter.writeHeader(segmentWriteState); } /** @@ -79,8 +80,6 @@ public KnnFieldVectorsWriter addField(final FieldInfo fieldInfo) throws IOExc public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { flatVectorsWriter.flush(maxDoc, sortMap); - quantizationStateWriter.writeHeader(segmentWriteState); - for (final NativeEngineFieldVectorsWriter field : fields) { trainAndIndex( field.getFieldInfo(), @@ -96,12 +95,8 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState) throws IOException { // This will ensure that we are merging the FlatIndex during force merge. flatVectorsWriter.mergeOneField(fieldInfo, mergeState); - - quantizationStateWriter.writeHeader(segmentWriteState); - quantizationStateWriter.writeExistingStates(); // For merge, pick values from flat vector and reindex again. This will use the flush operation to create graphs trainAndIndex(fieldInfo, this::getKNNVectorValuesForMerge, NativeIndexWriter::mergeIndex, mergeState); - quantizationStateWriter.writeFooter(); } /** @@ -113,6 +108,7 @@ public void finish() throws IOException { throw new IllegalStateException("NativeEnginesKNNVectorsWriter is already finished"); } finished = true; + quantizationStateWriter.writeFooter(); flatVectorsWriter.finish(); } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriterTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriterTests.java index e52938e612..f708a30e6c 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriterTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriterTests.java @@ -185,60 +185,4 @@ public void testWriteFooter() { mockedStaticCodecUtil.verify(() -> CodecUtil.writeFooter(output)); } } - - @SneakyThrows - public void testWriteExistingStates() { - final String segmentName = "test-segment-name"; - - final SegmentInfo segmentInfo = new SegmentInfo( - Mockito.mock(Directory.class), - Mockito.mock(Version.class), - Mockito.mock(Version.class), - segmentName, - 0, - false, - false, - Mockito.mock(Codec.class), - Mockito.mock(Map.class), - new byte[16], - Mockito.mock(Map.class), - Mockito.mock(Sort.class) - ); - - Directory directory = Mockito.mock(Directory.class); - IndexOutput output = Mockito.mock(IndexOutput.class); - Mockito.when(directory.createOutput(any(), any())).thenReturn(output); - - final SegmentWriteState segmentWriteState = new SegmentWriteState( - Mockito.mock(InfoStream.class), - directory, - segmentInfo, - Mockito.mock(FieldInfos.class), - null, - Mockito.mock(IOContext.class) - ); - KNNQuantizationStateWriter quantizationStateWriter = new KNNQuantizationStateWriter(segmentWriteState); - - quantizationStateWriter.writeExistingStates(); - - Mockito.verify(output, times(0)).writeBytes(any(byte[].class), anyInt()); - - int fieldNumber1 = 1; - int fieldNumber2 = 2; - QuantizationState quantizationState1 = new OneBitScalarQuantizationState( - new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), - new float[] { 1.2f, 2.3f, 3.4f, 4.5f } - ); - QuantizationState quantizationState2 = new OneBitScalarQuantizationState( - new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), - new float[] { 2.3f, 3.4f, 4.5f, 5.6f } - ); - quantizationStateWriter.writeState(fieldNumber1, quantizationState1); - quantizationStateWriter.writeState(fieldNumber2, quantizationState2); - - quantizationStateWriter.writeExistingStates(); - - // Should be called once in each write state method and once in the second call of writeExistingStates - Mockito.verify(output, times(4)).writeBytes(any(byte[].class), anyInt()); - } } From 5b03f30ace23b2d4741a4489afe8174df6da0c00 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Mon, 26 Aug 2024 12:06:00 -0700 Subject: [PATCH 16/41] Fix tests Signed-off-by: Ryan Bogan --- .../codec/KNN990Codec/KNNQuantizationStateReader.java | 7 ++++--- .../KNN990Codec/KNNQuantizationStateReaderTests.java | 9 ++++++++- .../KNN990Codec/KNNQuantizationStateWriterTests.java | 3 +-- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java index 99a39e2489..8e2044a563 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java @@ -58,13 +58,13 @@ public static Map read(SegmentReadState state) throws IOExceptio int numFields = getNumFields(input); - List fieldNames = new ArrayList<>(); + List fieldNumbers = new ArrayList<>(); List positions = new ArrayList<>(); List lengths = new ArrayList<>(); // Read each field's metadata from the index section for (int i = 0; i < numFields; i++) { - fieldNames.add(input.readString()); + fieldNumbers.add(input.readInt()); int length = input.readInt(); lengths.add(length); long position = input.readVLong(); @@ -75,7 +75,8 @@ public static Map read(SegmentReadState state) throws IOExceptio input.seek(positions.get(i)); byte[] stateBytes = new byte[lengths.get(i)]; input.readBytes(stateBytes, 0, lengths.get(i)); - readQuantizationStateInfos.put(fieldNames.get(i), stateBytes); + String fieldName = state.fieldInfos.fieldInfo(fieldNumbers.get(i)).getName(); + readQuantizationStateInfos.put(fieldName, stateBytes); } } return readQuantizationStateInfos; diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReaderTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReaderTests.java index cfe88b8d00..4b2ea6c33f 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReaderTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReaderTests.java @@ -8,6 +8,7 @@ import lombok.SneakyThrows; import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.SegmentInfo; import org.apache.lucene.index.SegmentReadState; @@ -54,10 +55,16 @@ public void testReadFromSegmentReadState() { IndexInput input = Mockito.mock(IndexInput.class); Mockito.when(directory.openInput(any(), any())).thenReturn(input); + String fieldName = "test-field"; + FieldInfos fieldInfos = Mockito.mock(FieldInfos.class); + FieldInfo fieldInfo = Mockito.mock(FieldInfo.class); + Mockito.when(fieldInfo.getName()).thenReturn(fieldName); + Mockito.when(fieldInfos.fieldInfo(anyInt())).thenReturn(fieldInfo); + final SegmentReadState segmentReadState = new SegmentReadState( directory, segmentInfo, - Mockito.mock(FieldInfos.class), + fieldInfos, Mockito.mock(IOContext.class), segmentSuffix ); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriterTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriterTests.java index f708a30e6c..4f018d1add 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriterTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriterTests.java @@ -178,8 +178,7 @@ public void testWriteFooter() { try (MockedStatic mockedStaticCodecUtil = mockStatic(CodecUtil.class)) { quantizationStateWriter.writeFooter(); - Mockito.verify(output, times(4)).writeInt(anyInt()); - Mockito.verify(output, times(2)).writeString(anyString()); + Mockito.verify(output, times(6)).writeInt(anyInt()); Mockito.verify(output, times(2)).writeVLong(anyLong()); Mockito.verify(output, times(1)).writeLong(anyLong()); mockedStaticCodecUtil.verify(() -> CodecUtil.writeFooter(output)); From 9b486d8bc3ceb3a061546979dab3f42e8c4b677d Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Mon, 26 Aug 2024 15:06:35 -0700 Subject: [PATCH 17/41] Integrate with query flow Signed-off-by: Ryan Bogan --- .../KNNQuantizationStateReader.java | 16 +++++++-- .../opensearch/knn/index/query/KNNWeight.java | 36 +++++++++++++++++++ .../QuantizationStateReadConfig.java | 2 ++ 3 files changed, 51 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java index 8e2044a563..c30ada74b7 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java @@ -12,6 +12,9 @@ import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IndexInput; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; +import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateReadConfig; @@ -121,10 +124,17 @@ public static QuantizationState read(QuantizationStateReadConfig readConfig) thr input.seek(position); byte[] stateBytes = new byte[length]; input.readBytes(stateBytes, 0, length); + // Deserialize the byte array to a quantization state object + ScalarQuantizationType scalarQuantizationType = readConfig.getScalarQuantizationType(); + if (scalarQuantizationType == ScalarQuantizationType.ONE_BIT) { + return OneBitScalarQuantizationState.fromByteArray(stateBytes); + } else if (scalarQuantizationType == ScalarQuantizationType.TWO_BIT + || scalarQuantizationType == ScalarQuantizationType.FOUR_BIT) { + return MultiBitScalarQuantizationState.fromByteArray(stateBytes); + } else { + throw new IllegalArgumentException(String.format("Unexpected scalar quantization type: %s", scalarQuantizationType)); + } } - // Deserialize the byte array to a quantization state object - // TODO: Get params from field info and deserialize - return null; } @VisibleForTesting diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 6cc1108390..027425356a 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -7,6 +7,7 @@ import com.google.common.annotations.VisibleForTesting; import lombok.extern.log4j.Log4j2; +import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.SegmentReader; @@ -28,17 +29,24 @@ import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.KNN990Codec.KNNQuantizationStateReader; import org.opensearch.knn.index.engine.qframe.QuantizationConfig; import org.opensearch.knn.index.memory.NativeMemoryAllocation; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.memory.NativeMemoryEntryContext; import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.quantizationService.QuantizationService; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelUtil; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.plugin.stats.KNNCounter; +import org.opensearch.knn.quantization.factory.QuantizerFactory; +import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateReadConfig; +import org.opensearch.knn.quantization.quantizer.Quantizer; import java.io.IOException; import java.nio.file.Path; @@ -230,6 +238,16 @@ private Map doANNSearch( // TODO: Use this to get quantization config QuantizationConfig quantizationConfig = FieldInfoExtractor.extractQuantizationConfig(fieldInfo); + QuantizationState quantizationState = KNNQuantizationStateReader.read( + new QuantizationStateReadConfig( + reader.directory(), + knnQuery.getField(), + Long.toString(reader.getSegmentInfo().getFieldInfosGen(), Character.MAX_RADIX), + fieldInfo, + quantizationConfig.getQuantizationType() + ) + ); + KNNEngine knnEngine; SpaceType spaceType; VectorDataType vectorDataType; @@ -256,6 +274,20 @@ private Map doANNSearch( ); } + if (quantizationState != null) { + Quantizer quantizer = QuantizerFactory.getQuantizer(quantizationState.getQuantizationParams()); + QuantizationOutput quantizationOutput = QuantizationService.getInstance() + .createQuantizationOutput(quantizationState.getQuantizationParams()); + KnnVectorsReader vectorReader = reader.getVectorReader(); + if (vectorDataType == VectorDataType.FLOAT) { + quantizer.quantize(vectorReader.getFloatVectorValues(knnQuery.getField()), quantizationState, quantizationOutput); + } else if (vectorDataType == VectorDataType.BINARY || vectorDataType == VectorDataType.BYTE) { + quantizer.quantize(vectorReader.getByteVectorValues(knnQuery.getField()), quantizationState, quantizationOutput); + } else { + throw new IllegalArgumentException(String.format("Unexpected vector data type: %s", vectorDataType)); + } + } + List engineFiles = getEngineFiles(reader, knnEngine.getExtension()); if (engineFiles.isEmpty()) { log.debug("[KNN] No engine index found for field {} for segment {}", knnQuery.getField(), reader.getSegmentName()); @@ -354,6 +386,10 @@ private Map doANNSearch( .collect(Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), spaceType))); } + private void quantize(DocIdSetIterator vectorValues) { + + } + @VisibleForTesting List getEngineFiles(SegmentReader reader, String extension) throws IOException { /* diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateReadConfig.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateReadConfig.java index 96a126a1e7..973e833ab1 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateReadConfig.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateReadConfig.java @@ -9,6 +9,7 @@ import lombok.Getter; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.store.Directory; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; @Getter @AllArgsConstructor @@ -17,4 +18,5 @@ public class QuantizationStateReadConfig { private String segmentName; private String segmentSuffix; private FieldInfo fieldInfo; + private ScalarQuantizationType scalarQuantizationType; } From 0a7d80e10637dbc6c3189ff0a6988d060db484a1 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Mon, 26 Aug 2024 15:15:26 -0700 Subject: [PATCH 18/41] Remove duplicate writeFooter Signed-off-by: Ryan Bogan --- .../codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java | 1 - 1 file changed, 1 deletion(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index 7875121de5..f19439b3a3 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -88,7 +88,6 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { field ); } - quantizationStateWriter.writeFooter(); } @Override From de89987f65652af342127cdb6d35fb77325f586a Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Mon, 26 Aug 2024 15:19:59 -0700 Subject: [PATCH 19/41] Integrate with cache Signed-off-by: Ryan Bogan --- .../opensearch/knn/index/query/KNNWeight.java | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 027425356a..097677322d 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -45,6 +45,7 @@ import org.opensearch.knn.quantization.factory.QuantizerFactory; import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCache; import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateReadConfig; import org.opensearch.knn.quantization.quantizer.Quantizer; @@ -238,15 +239,20 @@ private Map doANNSearch( // TODO: Use this to get quantization config QuantizationConfig quantizationConfig = FieldInfoExtractor.extractQuantizationConfig(fieldInfo); - QuantizationState quantizationState = KNNQuantizationStateReader.read( - new QuantizationStateReadConfig( - reader.directory(), - knnQuery.getField(), - Long.toString(reader.getSegmentInfo().getFieldInfosGen(), Character.MAX_RADIX), - fieldInfo, - quantizationConfig.getQuantizationType() - ) - ); + QuantizationState quantizationState = QuantizationStateCache.getInstance().getQuantizationState(knnQuery.getField()); + + if (quantizationState == null) { + quantizationState = KNNQuantizationStateReader.read( + new QuantizationStateReadConfig( + reader.directory(), + knnQuery.getField(), + Long.toString(reader.getSegmentInfo().getFieldInfosGen(), Character.MAX_RADIX), + fieldInfo, + quantizationConfig.getQuantizationType() + ) + ); + QuantizationStateCache.getInstance().addQuantizationState(knnQuery.getField(), quantizationState); + } KNNEngine knnEngine; SpaceType spaceType; @@ -386,10 +392,6 @@ private Map doANNSearch( .collect(Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), spaceType))); } - private void quantize(DocIdSetIterator vectorValues) { - - } - @VisibleForTesting List getEngineFiles(SegmentReader reader, String extension) throws IOException { /* From 59be504591cc53fc75bd5be279933d4e84e9a5fb Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Tue, 27 Aug 2024 12:30:59 -0700 Subject: [PATCH 20/41] Change implementation and fix tests Signed-off-by: Ryan Bogan --- .../org/opensearch/knn/index/KNNSettings.java | 8 +- .../KNNQuantizationStateReader.java | 4 +- .../opensearch/knn/index/query/KNNWeight.java | 62 ++++++-------- .../QuantizationStateCache.java | 14 ++-- .../QuantizationStateCacheManager.java | 81 +++++++++++++++++++ .../QuantizationStateReadConfig.java | 3 +- ...NativeEngines990KnnVectorsFormatTests.java | 19 ++++- 7 files changed, 136 insertions(+), 55 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index 73f43d3d1e..aa3581912c 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -24,7 +24,7 @@ import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.memory.NativeMemoryCacheManagerDto; import org.opensearch.knn.index.util.IndexHyperParametersUtil; -import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCache; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCacheManager; import org.opensearch.monitor.jvm.JvmInfo; import org.opensearch.monitor.os.OsProbe; @@ -396,11 +396,11 @@ private void setSettingsUpdateConsumers() { NativeMemoryCacheManager.getInstance().rebuildCache(builder.build()); }, Stream.concat(dynamicCacheSettings.values().stream(), FEATURE_FLAGS.values().stream()).collect(Collectors.toUnmodifiableList())); clusterService.getClusterSettings().addSettingsUpdateConsumer(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, it -> { - QuantizationStateCache.getInstance().setMaxCacheSizeInKB(it.getKb()); - QuantizationStateCache.getInstance().rebuildCache(); + QuantizationStateCacheManager.getInstance().setMaxCacheSizeInKB(it.getKb()); + QuantizationStateCacheManager.getInstance().rebuildCache(); }); clusterService.getClusterSettings().addSettingsUpdateConsumer(QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING, it -> { - QuantizationStateCache.getInstance().rebuildCache(); + QuantizationStateCacheManager.getInstance().rebuildCache(); }); } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java index c30ada74b7..bad3c23574 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java @@ -125,7 +125,9 @@ public static QuantizationState read(QuantizationStateReadConfig readConfig) thr byte[] stateBytes = new byte[length]; input.readBytes(stateBytes, 0, length); // Deserialize the byte array to a quantization state object - ScalarQuantizationType scalarQuantizationType = readConfig.getScalarQuantizationType(); + ScalarQuantizationType scalarQuantizationType = ScalarQuantizationType.fromId( + Integer.parseInt(readConfig.getScalarQuantizationTypeId()) + ); if (scalarQuantizationType == ScalarQuantizationType.ONE_BIT) { return OneBitScalarQuantizationState.fromByteArray(stateBytes); } else if (scalarQuantizationType == ScalarQuantizationType.TWO_BIT diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 097677322d..0da0503443 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -7,7 +7,6 @@ import com.google.common.annotations.VisibleForTesting; import lombok.extern.log4j.Log4j2; -import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.SegmentReader; @@ -24,13 +23,10 @@ import org.apache.lucene.util.FixedBitSet; import org.opensearch.common.io.PathUtils; import org.opensearch.common.lucene.Lucene; -import org.opensearch.knn.common.FieldInfoExtractor; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.codec.KNN990Codec.KNNQuantizationStateReader; -import org.opensearch.knn.index.engine.qframe.QuantizationConfig; import org.opensearch.knn.index.memory.NativeMemoryAllocation; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.memory.NativeMemoryEntryContext; @@ -42,12 +38,11 @@ import org.opensearch.knn.indices.ModelUtil; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.plugin.stats.KNNCounter; -import org.opensearch.knn.quantization.factory.QuantizerFactory; import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; -import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCache; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCacheManager; import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateReadConfig; -import org.opensearch.knn.quantization.quantizer.Quantizer; import java.io.IOException; import java.nio.file.Path; @@ -81,6 +76,7 @@ public class KNNWeight extends Weight { private final ExactSearcher exactSearcher; private static ExactSearcher DEFAULT_EXACT_SEARCHER; + private final QuantizationService quantizationService = QuantizationService.getInstance(); public KNNWeight(KNNQuery query, float boost) { super(query); @@ -236,24 +232,6 @@ private Map doANNSearch( return null; } - // TODO: Use this to get quantization config - QuantizationConfig quantizationConfig = FieldInfoExtractor.extractQuantizationConfig(fieldInfo); - - QuantizationState quantizationState = QuantizationStateCache.getInstance().getQuantizationState(knnQuery.getField()); - - if (quantizationState == null) { - quantizationState = KNNQuantizationStateReader.read( - new QuantizationStateReadConfig( - reader.directory(), - knnQuery.getField(), - Long.toString(reader.getSegmentInfo().getFieldInfosGen(), Character.MAX_RADIX), - fieldInfo, - quantizationConfig.getQuantizationType() - ) - ); - QuantizationStateCache.getInstance().addQuantizationState(knnQuery.getField(), quantizationState); - } - KNNEngine knnEngine; SpaceType spaceType; VectorDataType vectorDataType; @@ -280,18 +258,23 @@ private Map doANNSearch( ); } - if (quantizationState != null) { - Quantizer quantizer = QuantizerFactory.getQuantizer(quantizationState.getQuantizationParams()); - QuantizationOutput quantizationOutput = QuantizationService.getInstance() - .createQuantizationOutput(quantizationState.getQuantizationParams()); - KnnVectorsReader vectorReader = reader.getVectorReader(); - if (vectorDataType == VectorDataType.FLOAT) { - quantizer.quantize(vectorReader.getFloatVectorValues(knnQuery.getField()), quantizationState, quantizationOutput); - } else if (vectorDataType == VectorDataType.BINARY || vectorDataType == VectorDataType.BYTE) { - quantizer.quantize(vectorReader.getByteVectorValues(knnQuery.getField()), quantizationState, quantizationOutput); - } else { - throw new IllegalArgumentException(String.format("Unexpected vector data type: %s", vectorDataType)); - } + QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo); + + byte[] quantizedVector = null; + + if (quantizationParams != null) { + QuantizationState quantizationState = QuantizationStateCacheManager.getInstance() + .getQuantizationState( + new QuantizationStateReadConfig( + reader.directory(), + reader.getSegmentName(), + Long.toString(reader.getSegmentInfo().getFieldInfosGen(), Character.MAX_RADIX), + fieldInfo, + quantizationParams.getTypeIdentifier() + ) + ); + QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(quantizationParams); + quantizedVector = (byte[]) quantizationService.quantize(quantizationState, knnQuery.getQueryVector(), quantizationOutput); } List engineFiles = getEngineFiles(reader, knnEngine.getExtension()); @@ -334,10 +317,11 @@ private Map doANNSearch( } int[] parentIds = getParentIdsArray(context); if (k > 0) { - if (knnQuery.getVectorDataType() == VectorDataType.BINARY) { + if (knnQuery.getVectorDataType() == VectorDataType.BINARY + || quantizationParams != null && quantizationService.getVectorDataTypeForTransfer(fieldInfo) == VectorDataType.BINARY) { results = JNIService.queryBinaryIndex( indexAllocation.getMemoryAddress(), - knnQuery.getByteQueryVector(), + quantizationParams == null ? knnQuery.getByteQueryVector() : quantizedVector, k, knnQuery.getMethodParameters(), knnEngine, diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCache.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCache.java index ba26d517d0..cc5a34bcd8 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCache.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCache.java @@ -11,7 +11,6 @@ import com.google.common.cache.RemovalCause; import com.google.common.cache.RemovalNotification; import lombok.Getter; -import lombok.Setter; import lombok.extern.log4j.Log4j2; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.common.unit.ByteSizeValue; @@ -33,7 +32,6 @@ public class QuantizationStateCache { private static volatile QuantizationStateCache instance; private Cache cache; @Getter - @Setter private long maxCacheSizeInKB; @Getter private Instant evictedDueToSizeAt; @@ -48,7 +46,7 @@ public class QuantizationStateCache { * Gets the singleton instance of the cache. * @return QuantizationStateCache */ - public static QuantizationStateCache getInstance() { + static QuantizationStateCache getInstance() { if (instance == null) { synchronized (QuantizationStateCache.class) { if (instance == null) { @@ -75,7 +73,7 @@ private void buildCache() { .build(); } - public synchronized void rebuildCache() { + synchronized void rebuildCache() { clear(); buildCache(); } @@ -85,7 +83,7 @@ public synchronized void rebuildCache() { * @param fieldName The name of the field. * @return The associated QuantizationState, or null if not present. */ - public QuantizationState getQuantizationState(String fieldName) { + QuantizationState getQuantizationState(String fieldName) { return cache.getIfPresent(fieldName); } @@ -94,7 +92,7 @@ public QuantizationState getQuantizationState(String fieldName) { * @param fieldName The name of the field. * @param quantizationState The quantization state to store. */ - public void addQuantizationState(String fieldName, QuantizationState quantizationState) { + void addQuantizationState(String fieldName, QuantizationState quantizationState) { cache.put(fieldName, quantizationState); } @@ -112,6 +110,10 @@ private void onRemoval(RemovalNotification removalNot } } + void setMaxCacheSizeInKB(long maxCacheSizeInKB) { + this.maxCacheSizeInKB = maxCacheSizeInKB; + } + private void updateEvictedDueToSizeAt() { evictedDueToSizeAt = Instant.now(); } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java new file mode 100644 index 0000000000..f3f440e452 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java @@ -0,0 +1,81 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.quantizationState; + +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import org.opensearch.knn.index.codec.KNN990Codec.KNNQuantizationStateReader; + +import java.io.IOException; + +@NoArgsConstructor(access = AccessLevel.PRIVATE) +public final class QuantizationStateCacheManager { + + private static volatile QuantizationStateCacheManager instance; + + /** + * Gets the singleton instance of the cache. + * @return QuantizationStateCache + */ + public static QuantizationStateCacheManager getInstance() { + if (instance == null) { + synchronized (QuantizationStateCacheManager.class) { + if (instance == null) { + instance = new QuantizationStateCacheManager(); + } + } + } + return instance; + } + + public synchronized void rebuildCache() { + QuantizationStateCache.getInstance().rebuildCache(); + } + + /** + * Retrieves the quantization state associated with a given field name. Reads from cache first, then from disk if necessary. + * @param quantizationStateReadConfig information required from reading from off-heap if necessary + * @return The associated QuantizationState + */ + public QuantizationState getQuantizationState(QuantizationStateReadConfig quantizationStateReadConfig) throws IOException { + String fieldName = quantizationStateReadConfig.getFieldInfo().getName(); + QuantizationState quantizationState = QuantizationStateCache.getInstance().getQuantizationState(fieldName); + if (quantizationState == null) { + quantizationState = KNNQuantizationStateReader.read(quantizationStateReadConfig); + addQuantizationState(fieldName, quantizationState); + } + + return quantizationState; + } + + /** + * Adds or updates a quantization state in the cache. + * @param fieldName The name of the field. + * @param quantizationState The quantization state to store. + */ + private void addQuantizationState(String fieldName, QuantizationState quantizationState) { + QuantizationStateCache.getInstance().addQuantizationState(fieldName, quantizationState); + } + + /** + * Removes the quantization state associated with a given field name. + * @param fieldName The name of the field. + */ + public void evict(String fieldName) { + QuantizationStateCache.getInstance().evict(fieldName); + } + + public void setMaxCacheSizeInKB(long maxCacheSizeInKB) { + QuantizationStateCache.getInstance().setMaxCacheSizeInKB(maxCacheSizeInKB); + } + + /** + * Clears all entries from the cache. + */ + public void clear() { + QuantizationStateCache.getInstance().clear(); + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateReadConfig.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateReadConfig.java index 973e833ab1..5e0c729ff7 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateReadConfig.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateReadConfig.java @@ -9,7 +9,6 @@ import lombok.Getter; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.store.Directory; -import org.opensearch.knn.quantization.enums.ScalarQuantizationType; @Getter @AllArgsConstructor @@ -18,5 +17,5 @@ public class QuantizationStateReadConfig { private String segmentName; private String segmentSuffix; private FieldInfo fieldInfo; - private ScalarQuantizationType scalarQuantizationType; + private String scalarQuantizationTypeId; } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java index 2b933a1486..16404fed09 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java @@ -14,6 +14,7 @@ import lombok.SneakyThrows; import lombok.extern.log4j.Log4j2; import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; @@ -40,6 +41,7 @@ import org.apache.lucene.search.Sort; import org.apache.lucene.store.Directory; import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexOutput; import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.store.BaseDirectoryWrapper; import org.apache.lucene.util.Bits; @@ -47,7 +49,9 @@ import org.apache.lucene.util.Version; import org.junit.After; import org.junit.Assert; +import org.mockito.MockedStatic; import org.mockito.Mockito; +import org.mockito.stubbing.Answer; import org.opensearch.common.lucene.Lucene; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.common.KNNConstants; @@ -65,6 +69,10 @@ import java.util.Map; import java.util.stream.Collectors; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyString; + @Log4j2 public class NativeEngines990KnnVectorsFormatTests extends KNNTestCase { private static final Codec TESTING_CODEC = new UnitTestCodec(); @@ -124,9 +132,14 @@ public void testReaderAndWriter_whenValidInput_thenSuccess() { Assert.assertTrue( nativeEngines990KnnVectorsFormat.fieldsReader(mockedSegmentReadState) instanceof NativeEngines990KnnVectorsReader ); - Assert.assertTrue( - nativeEngines990KnnVectorsFormat.fieldsWriter(mockedSegmentWriteState) instanceof NativeEngines990KnnVectorsWriter - ); + try (MockedStatic mockedStaticCodecUtil = Mockito.mockStatic(CodecUtil.class)) { + mockedStaticCodecUtil.when( + () -> CodecUtil.writeIndexHeader(any(IndexOutput.class), anyString(), anyInt(), any(byte[].class), anyString()) + ).thenAnswer((Answer) invocation -> null); + Assert.assertTrue( + nativeEngines990KnnVectorsFormat.fieldsWriter(mockedSegmentWriteState) instanceof NativeEngines990KnnVectorsWriter + ); + } } @SneakyThrows From 366072f4d9f4a19d1b613f7f7c04f74a108d622b Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Tue, 27 Aug 2024 13:06:16 -0700 Subject: [PATCH 21/41] Add test for reading from QuantizationStateReadConfig Signed-off-by: Ryan Bogan --- .../KNNQuantizationStateReaderTests.java | 70 +++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReaderTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReaderTests.java index 4b2ea6c33f..7aeb0b7b44 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReaderTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReaderTests.java @@ -20,6 +20,10 @@ import org.mockito.MockedStatic; import org.mockito.Mockito; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateReadConfig; import java.util.Map; @@ -84,6 +88,72 @@ public void testReadFromSegmentReadState() { } } + @SneakyThrows + public void testReadFromQuantizationStateReadConfig() { + Directory directory = Mockito.mock(Directory.class); + IndexInput input = Mockito.mock(IndexInput.class); + Mockito.when(directory.openInput(any(), any())).thenReturn(input); + + int fieldNumber = 4; + FieldInfos fieldInfos = Mockito.mock(FieldInfos.class); + FieldInfo fieldInfo = Mockito.mock(FieldInfo.class); + Mockito.when(fieldInfo.getFieldNumber()).thenReturn(fieldNumber); + Mockito.when(fieldInfos.fieldInfo(anyInt())).thenReturn(fieldInfo); + + String segmentName = "test-segment-name"; + String segmentSuffix = "test-segment-suffix"; + String scalarQuantizationTypeId1 = "1"; + String scalarQuantizationTypeId2 = "2"; + String scalarQuantizationTypeId4 = "4"; + String scalarQuantizationTypeIdIncorrect = "-1"; + QuantizationStateReadConfig quantizationStateReadConfig = Mockito.mock(QuantizationStateReadConfig.class); + Mockito.when(quantizationStateReadConfig.getSegmentName()).thenReturn(segmentName); + Mockito.when(quantizationStateReadConfig.getSegmentSuffix()).thenReturn(segmentSuffix); + Mockito.when(quantizationStateReadConfig.getFieldInfo()).thenReturn(fieldInfo); + Mockito.when(quantizationStateReadConfig.getDirectory()).thenReturn(directory); + Mockito.when(quantizationStateReadConfig.getScalarQuantizationTypeId()).thenReturn(scalarQuantizationTypeId1); + + try (MockedStatic mockedStaticReader = Mockito.mockStatic(KNNQuantizationStateReader.class)) { + mockedStaticReader.when(() -> KNNQuantizationStateReader.getNumFields(input)).thenReturn(2); + mockedStaticReader.when(() -> KNNQuantizationStateReader.read(quantizationStateReadConfig)).thenCallRealMethod(); + try (MockedStatic mockedStaticCodecUtil = mockStatic(CodecUtil.class)) { + assertThrows(IllegalArgumentException.class, () -> KNNQuantizationStateReader.read(quantizationStateReadConfig)); + + mockedStaticCodecUtil.verify(() -> CodecUtil.retrieveChecksum(input)); + Mockito.verify(input, times(4)).readInt(); + Mockito.verify(input, times(2)).readVLong(); + Mockito.verify(input, times(0)).readBytes(any(byte[].class), anyInt(), anyInt()); + Mockito.verify(input, times(0)).seek(anyLong()); + + Mockito.when(input.readInt()).thenReturn(fieldNumber); + + try (MockedStatic mockedStaticOneBit = mockStatic(OneBitScalarQuantizationState.class)) { + OneBitScalarQuantizationState oneBitScalarQuantizationState = Mockito.mock(OneBitScalarQuantizationState.class); + mockedStaticOneBit.when(() -> OneBitScalarQuantizationState.fromByteArray(any(byte[].class))) + .thenReturn(oneBitScalarQuantizationState); + QuantizationState quantizationState = KNNQuantizationStateReader.read(quantizationStateReadConfig); + assertTrue(quantizationState instanceof OneBitScalarQuantizationState); + } + + try (MockedStatic mockedStaticOneBit = mockStatic(MultiBitScalarQuantizationState.class)) { + MultiBitScalarQuantizationState multiBitScalarQuantizationState = Mockito.mock(MultiBitScalarQuantizationState.class); + mockedStaticOneBit.when(() -> MultiBitScalarQuantizationState.fromByteArray(any(byte[].class))) + .thenReturn(multiBitScalarQuantizationState); + + Mockito.when(quantizationStateReadConfig.getScalarQuantizationTypeId()).thenReturn(scalarQuantizationTypeId2); + QuantizationState quantizationState = KNNQuantizationStateReader.read(quantizationStateReadConfig); + assertTrue(quantizationState instanceof MultiBitScalarQuantizationState); + + Mockito.when(quantizationStateReadConfig.getScalarQuantizationTypeId()).thenReturn(scalarQuantizationTypeId4); + quantizationState = KNNQuantizationStateReader.read(quantizationStateReadConfig); + assertTrue(quantizationState instanceof MultiBitScalarQuantizationState); + } + Mockito.when(quantizationStateReadConfig.getScalarQuantizationTypeId()).thenReturn(scalarQuantizationTypeIdIncorrect); + assertThrows(IllegalArgumentException.class, () -> KNNQuantizationStateReader.read(quantizationStateReadConfig)); + } + } + } + @SneakyThrows public void testGetNumFields() { IndexInput input = Mockito.mock(IndexInput.class); From 077a1b6bbdae87e4db2a3c83b560685e1d2b89e7 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Wed, 28 Aug 2024 11:27:05 -0700 Subject: [PATCH 22/41] Add cache manager tests Signed-off-by: Ryan Bogan --- .../QuantizationStateCacheManager.java | 2 +- .../QuantizationStateCacheManagerTests.java | 100 ++++++++++++++++++ 2 files changed, 101 insertions(+), 1 deletion(-) create mode 100644 src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManagerTests.java diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java index f3f440e452..2b0d72dd86 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java @@ -56,7 +56,7 @@ public QuantizationState getQuantizationState(QuantizationStateReadConfig quanti * @param fieldName The name of the field. * @param quantizationState The quantization state to store. */ - private void addQuantizationState(String fieldName, QuantizationState quantizationState) { + public void addQuantizationState(String fieldName, QuantizationState quantizationState) { QuantizationStateCache.getInstance().addQuantizationState(fieldName, quantizationState); } diff --git a/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManagerTests.java b/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManagerTests.java new file mode 100644 index 0000000000..084baeae46 --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManagerTests.java @@ -0,0 +1,100 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.quantizationState; + +import lombok.SneakyThrows; +import org.apache.lucene.index.FieldInfo; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.codec.KNN990Codec.KNNQuantizationStateReader; + +import static org.mockito.Mockito.times; + +public class QuantizationStateCacheManagerTests extends KNNTestCase { + + @SneakyThrows + public void testRebuildCache() { + try (MockedStatic mockedStaticCache = Mockito.mockStatic(QuantizationStateCache.class)) { + QuantizationStateCache quantizationStateCache = Mockito.mock(QuantizationStateCache.class); + mockedStaticCache.when(QuantizationStateCache::getInstance).thenReturn(quantizationStateCache); + Mockito.doNothing().when(quantizationStateCache).rebuildCache(); + QuantizationStateCacheManager.getInstance().rebuildCache(); + Mockito.verify(quantizationStateCache, times(1)).rebuildCache(); + } + } + + @SneakyThrows + public void testGetQuantizationState() { + try (MockedStatic mockedStaticCache = Mockito.mockStatic(QuantizationStateCache.class)) { + QuantizationStateReadConfig quantizationStateReadConfig = Mockito.mock(QuantizationStateReadConfig.class); + FieldInfo fieldInfo = Mockito.mock(FieldInfo.class); + String fieldName = "test-field"; + Mockito.when(fieldInfo.getName()).thenReturn(fieldName); + Mockito.when(quantizationStateReadConfig.getFieldInfo()).thenReturn(fieldInfo); + QuantizationState quantizationState = Mockito.mock(QuantizationState.class); + QuantizationStateCache quantizationStateCache = Mockito.mock(QuantizationStateCache.class); + mockedStaticCache.when(QuantizationStateCache::getInstance).thenReturn(quantizationStateCache); + Mockito.doNothing().when(quantizationStateCache).addQuantizationState(fieldName, quantizationState); + try (MockedStatic mockedStaticReader = Mockito.mockStatic(KNNQuantizationStateReader.class)) { + mockedStaticReader.when(() -> KNNQuantizationStateReader.read(quantizationStateReadConfig)).thenReturn(quantizationState); + QuantizationStateCacheManager.getInstance().getQuantizationState(quantizationStateReadConfig); + Mockito.verify(quantizationStateCache, times(1)).addQuantizationState(fieldName, quantizationState); + } + Mockito.when(quantizationStateCache.getQuantizationState(fieldName)).thenReturn(quantizationState); + QuantizationStateCacheManager.getInstance().getQuantizationState(quantizationStateReadConfig); + Mockito.verify(quantizationStateCache, times(1)).addQuantizationState(fieldName, quantizationState); + } + } + + @SneakyThrows + public void testEvict() { + try (MockedStatic mockedStaticCache = Mockito.mockStatic(QuantizationStateCache.class)) { + String field = "test-field"; + QuantizationStateCache quantizationStateCache = Mockito.mock(QuantizationStateCache.class); + mockedStaticCache.when(QuantizationStateCache::getInstance).thenReturn(quantizationStateCache); + Mockito.doNothing().when(quantizationStateCache).evict(field); + QuantizationStateCacheManager.getInstance().evict(field); + Mockito.verify(quantizationStateCache, times(1)).evict(field); + } + } + + @SneakyThrows + public void testAddQuantizationState() { + try (MockedStatic mockedStaticCache = Mockito.mockStatic(QuantizationStateCache.class)) { + String field = "test-field"; + QuantizationState quantizationState = Mockito.mock(QuantizationState.class); + QuantizationStateCache quantizationStateCache = Mockito.mock(QuantizationStateCache.class); + mockedStaticCache.when(QuantizationStateCache::getInstance).thenReturn(quantizationStateCache); + Mockito.doNothing().when(quantizationStateCache).addQuantizationState(field, quantizationState); + QuantizationStateCacheManager.getInstance().addQuantizationState(field, quantizationState); + Mockito.verify(quantizationStateCache, times(1)).addQuantizationState(field, quantizationState); + } + } + + @SneakyThrows + public void testSetMaxCacheSizeInKB() { + try (MockedStatic mockedStaticCache = Mockito.mockStatic(QuantizationStateCache.class)) { + long maxCacheSizeInKB = 1024; + QuantizationStateCache quantizationStateCache = Mockito.mock(QuantizationStateCache.class); + mockedStaticCache.when(QuantizationStateCache::getInstance).thenReturn(quantizationStateCache); + Mockito.doNothing().when(quantizationStateCache).setMaxCacheSizeInKB(maxCacheSizeInKB); + QuantizationStateCacheManager.getInstance().setMaxCacheSizeInKB(1024); + Mockito.verify(quantizationStateCache, times(1)).setMaxCacheSizeInKB(1024); + } + } + + @SneakyThrows + public void testClear() { + try (MockedStatic mockedStaticCache = Mockito.mockStatic(QuantizationStateCache.class)) { + QuantizationStateCache quantizationStateCache = Mockito.mock(QuantizationStateCache.class); + mockedStaticCache.when(QuantizationStateCache::getInstance).thenReturn(quantizationStateCache); + Mockito.doNothing().when(quantizationStateCache).clear(); + QuantizationStateCacheManager.getInstance().clear(); + Mockito.verify(quantizationStateCache, times(1)).clear(); + } + } +} From 3dbbad962c6b8434831a22d9571ac54c078041ef Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Wed, 28 Aug 2024 13:48:53 -0700 Subject: [PATCH 23/41] Port changes from feature branch to fix end to end flow Signed-off-by: Ryan Bogan --- .../KNNQuantizationStateReader.java | 17 ++--- .../NativeEngines990KnnVectorsReader.java | 6 ++ .../QuantizationConfigKNNCollector.java | 59 +++++++++++++++++ .../opensearch/knn/index/query/KNNQuery.java | 8 ++- .../opensearch/knn/index/query/KNNWeight.java | 40 +++++++++--- .../QuantizationStateCacheManager.java | 6 +- .../QuantizationStateReadConfig.java | 13 ++-- .../KNNQuantizationStateReaderTests.java | 63 +++++++++++++------ .../knn/index/query/KNNWeightTests.java | 60 +++++++++++++----- .../QuantizationStateCacheManagerTests.java | 15 ++--- 10 files changed, 214 insertions(+), 73 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/codec/KNN990Codec/QuantizationConfigKNNCollector.java diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java index bad3c23574..90afc217b1 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java @@ -13,6 +13,7 @@ import org.apache.lucene.store.IndexInput; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; @@ -91,14 +92,16 @@ public static Map read(SegmentReadState state) throws IOExceptio * @return quantization state */ public static QuantizationState read(QuantizationStateReadConfig readConfig) throws IOException { + SegmentReadState segmentReadState = readConfig.getSegmentReadState(); + String field = readConfig.getField(); String quantizationStateFileName = IndexFileNames.segmentFileName( - readConfig.getSegmentName(), - readConfig.getSegmentSuffix(), + segmentReadState.segmentInfo.name, + segmentReadState.segmentSuffix, KNNConstants.QUANTIZATION_STATE_FILE_SUFFIX ); - int fieldNumber = readConfig.getFieldInfo().getFieldNumber(); + int fieldNumber = segmentReadState.fieldInfos.fieldInfo(field).getFieldNumber(); - try (IndexInput input = readConfig.getDirectory().openInput(quantizationStateFileName, IOContext.READ)) { + try (IndexInput input = segmentReadState.directory.openInput(quantizationStateFileName, IOContext.READ)) { CodecUtil.retrieveChecksum(input); int numFields = getNumFields(input); @@ -118,16 +121,14 @@ public static QuantizationState read(QuantizationStateReadConfig readConfig) thr } if (position == -1 || length == 0) { - throw new IllegalArgumentException(String.format("Field %s not found", readConfig.getFieldInfo().getName())); + throw new IllegalArgumentException(String.format("Field %s not found", field)); } input.seek(position); byte[] stateBytes = new byte[length]; input.readBytes(stateBytes, 0, length); // Deserialize the byte array to a quantization state object - ScalarQuantizationType scalarQuantizationType = ScalarQuantizationType.fromId( - Integer.parseInt(readConfig.getScalarQuantizationTypeId()) - ); + ScalarQuantizationType scalarQuantizationType = ((ScalarQuantizationParams) readConfig.getQuantizationParams()).getSqType(); if (scalarQuantizationType == ScalarQuantizationType.ONE_BIT) { return OneBitScalarQuantizationState.fromByteArray(stateBytes); } else if (scalarQuantizationType == ScalarQuantizationType.TWO_BIT diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java index 74b158fa5a..2349f6ef4c 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java @@ -33,8 +33,10 @@ public class NativeEngines990KnnVectorsReader extends KnnVectorsReader { private final FlatVectorsReader flatVectorsReader; + private final SegmentReadState segmentReadState; public NativeEngines990KnnVectorsReader(final SegmentReadState state, final FlatVectorsReader flatVectorsReader) { + this.segmentReadState = state; this.flatVectorsReader = flatVectorsReader; } @@ -101,6 +103,10 @@ public ByteVectorValues getByteVectorValues(final String field) throws IOExcepti */ @Override public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + if (knnCollector instanceof QuantizationConfigKNNCollector) { + ((QuantizationConfigKNNCollector) knnCollector).setSegmentReadState(segmentReadState); + return; + } throw new UnsupportedOperationException("Search functionality using codec is not supported with Native Engine Reader"); } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/QuantizationConfigKNNCollector.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/QuantizationConfigKNNCollector.java new file mode 100644 index 0000000000..ae7b9a74ab --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/QuantizationConfigKNNCollector.java @@ -0,0 +1,59 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN990Codec; + +import lombok.Getter; +import lombok.Setter; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.TopDocs; + +@Setter +@Getter +public class QuantizationConfigKNNCollector implements KnnCollector { + + private SegmentReadState segmentReadState; + + @Override + public boolean earlyTerminated() { + throw new UnsupportedOperationException("Search functionality using codec is not supported with Native Engine Reader"); + } + + @Override + public void incVisitedCount(int i) { + throw new UnsupportedOperationException("Search functionality using codec is not supported with Native Engine Reader"); + } + + @Override + public long visitedCount() { + throw new UnsupportedOperationException("Search functionality using codec is not supported with Native Engine Reader"); + } + + @Override + public long visitLimit() { + throw new UnsupportedOperationException("Search functionality using codec is not supported with Native Engine Reader"); + } + + @Override + public int k() { + throw new UnsupportedOperationException("Search functionality using codec is not supported with Native Engine Reader"); + } + + @Override + public boolean collect(int i, float v) { + throw new UnsupportedOperationException("Search functionality using codec is not supported with Native Engine Reader"); + } + + @Override + public float minCompetitiveSimilarity() { + throw new UnsupportedOperationException("Search functionality using codec is not supported with Native Engine Reader"); + } + + @Override + public TopDocs topDocs() { + throw new UnsupportedOperationException("Search functionality using codec is not supported with Native Engine Reader"); + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java index 04a10143cd..3d1a83f5c6 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java @@ -45,6 +45,8 @@ public class KNNQuery extends Query { private final String indexName; private final VectorDataType vectorDataType; private final RescoreContext rescoreContext; + private final String indexUUID; + private final int shardId; @Setter private Query filterQuery; @@ -107,6 +109,8 @@ private KNNQuery( this.parentsFilter = parentsFilter; this.vectorDataType = vectorDataType; this.rescoreContext = rescoreContext; + this.indexUUID = null; + this.shardId = -1; } /** @@ -169,9 +173,9 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo } final Weight filterWeight = getFilterWeight(searcher); if (filterWeight != null) { - return new KNNWeight(this, boost, filterWeight); + return new KNNWeight(this, boost, filterWeight, indexUUID, shardId); } - return new KNNWeight(this, boost); + return new KNNWeight(this, boost, indexUUID, shardId); } private Weight getFilterWeight(IndexSearcher searcher) throws IOException { diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 0da0503443..5f9cf4e5a2 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -8,6 +8,7 @@ import com.google.common.annotations.VisibleForTesting; import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.SegmentReader; import org.apache.lucene.search.DocIdSetIterator; @@ -27,6 +28,7 @@ import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.KNN990Codec.QuantizationConfigKNNCollector; import org.opensearch.knn.index.memory.NativeMemoryAllocation; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.memory.NativeMemoryEntryContext; @@ -77,23 +79,29 @@ public class KNNWeight extends Weight { private static ExactSearcher DEFAULT_EXACT_SEARCHER; private final QuantizationService quantizationService = QuantizationService.getInstance(); + private final String indexUUID; + private final int shardId; - public KNNWeight(KNNQuery query, float boost) { + public KNNWeight(KNNQuery query, float boost, String indexUUID, int shardId) { super(query); this.knnQuery = query; this.boost = boost; this.nativeMemoryCacheManager = NativeMemoryCacheManager.getInstance(); this.filterWeight = null; this.exactSearcher = DEFAULT_EXACT_SEARCHER; + this.indexUUID = indexUUID; + this.shardId = shardId; } - public KNNWeight(KNNQuery query, float boost, Weight filterWeight) { + public KNNWeight(KNNQuery query, float boost, Weight filterWeight, String indexUUID, int shardId) { super(query); this.knnQuery = query; this.boost = boost; this.nativeMemoryCacheManager = NativeMemoryCacheManager.getInstance(); this.filterWeight = filterWeight; this.exactSearcher = DEFAULT_EXACT_SEARCHER; + this.indexUUID = indexUUID; + this.shardId = shardId; } public static void initialize(ModelDao modelDao) { @@ -216,13 +224,18 @@ private int[] bitSetToIntArray(final BitSet bitSet) { return intArray; } + private String createQCacheKey(String segmentName) { + return String.format("%s_%s_%s_%s", indexUUID, shardId, segmentName, knnQuery.getField()); + } + private Map doANNSearch( final LeafReaderContext context, final BitSet filterIdsBitSet, final int cardinality, final int k ) throws IOException { - final SegmentReader reader = Lucene.segmentReader(context.reader()); + LeafReader reader2 = context.reader(); + final SegmentReader reader = Lucene.segmentReader(reader2); String directory = ((FSDirectory) FilterDirectory.unwrap(reader.directory())).getDirectory().toString(); FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); @@ -263,14 +276,18 @@ private Map doANNSearch( byte[] quantizedVector = null; if (quantizationParams != null) { + QuantizationConfigKNNCollector tempCollector = new QuantizationConfigKNNCollector(); + reader.searchNearestVectors(knnQuery.getField(), new float[0], tempCollector, null); + if (tempCollector.getSegmentReadState() == null) { + throw new IllegalStateException("No quantization state for file"); + } QuantizationState quantizationState = QuantizationStateCacheManager.getInstance() .getQuantizationState( new QuantizationStateReadConfig( - reader.directory(), - reader.getSegmentName(), - Long.toString(reader.getSegmentInfo().getFieldInfosGen(), Character.MAX_RADIX), - fieldInfo, - quantizationParams.getTypeIdentifier() + tempCollector.getSegmentReadState(), + quantizationParams, + knnQuery.getField(), + createQCacheKey(reader.getSegmentName()) ) ); QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(quantizationParams); @@ -294,7 +311,12 @@ private Map doANNSearch( new NativeMemoryEntryContext.IndexEntryContext( indexPath.toString(), NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(), - getParametersAtLoading(spaceType, knnEngine, knnQuery.getIndexName(), vectorDataType), + getParametersAtLoading( + spaceType, + knnEngine, + knnQuery.getIndexName(), + quantizationParams == null ? vectorDataType : VectorDataType.BINARY + ), knnQuery.getIndexName(), modelId ), diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java index 2b0d72dd86..5873153134 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java @@ -41,11 +41,11 @@ public synchronized void rebuildCache() { * @return The associated QuantizationState */ public QuantizationState getQuantizationState(QuantizationStateReadConfig quantizationStateReadConfig) throws IOException { - String fieldName = quantizationStateReadConfig.getFieldInfo().getName(); - QuantizationState quantizationState = QuantizationStateCache.getInstance().getQuantizationState(fieldName); + QuantizationState quantizationState = QuantizationStateCache.getInstance() + .getQuantizationState(quantizationStateReadConfig.getCacheKey()); if (quantizationState == null) { quantizationState = KNNQuantizationStateReader.read(quantizationStateReadConfig); - addQuantizationState(fieldName, quantizationState); + addQuantizationState(quantizationStateReadConfig.getCacheKey(), quantizationState); } return quantizationState; diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateReadConfig.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateReadConfig.java index 5e0c729ff7..d13e4f3f52 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateReadConfig.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateReadConfig.java @@ -7,15 +7,14 @@ import lombok.AllArgsConstructor; import lombok.Getter; -import org.apache.lucene.index.FieldInfo; -import org.apache.lucene.store.Directory; +import org.apache.lucene.index.SegmentReadState; +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; @Getter @AllArgsConstructor public class QuantizationStateReadConfig { - private Directory directory; - private String segmentName; - private String segmentSuffix; - private FieldInfo fieldInfo; - private String scalarQuantizationTypeId; + private SegmentReadState segmentReadState; + private QuantizationParams quantizationParams; + private String field; + private String cacheKey; } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReaderTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReaderTests.java index 7aeb0b7b44..d3c2f75d1f 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReaderTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReaderTests.java @@ -20,6 +20,8 @@ import org.mockito.MockedStatic; import org.mockito.Mockito; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; @@ -90,28 +92,48 @@ public void testReadFromSegmentReadState() { @SneakyThrows public void testReadFromQuantizationStateReadConfig() { - Directory directory = Mockito.mock(Directory.class); - IndexInput input = Mockito.mock(IndexInput.class); - Mockito.when(directory.openInput(any(), any())).thenReturn(input); - + String fieldName = "test-field"; int fieldNumber = 4; FieldInfos fieldInfos = Mockito.mock(FieldInfos.class); FieldInfo fieldInfo = Mockito.mock(FieldInfo.class); Mockito.when(fieldInfo.getFieldNumber()).thenReturn(fieldNumber); - Mockito.when(fieldInfos.fieldInfo(anyInt())).thenReturn(fieldInfo); + Mockito.when(fieldInfos.fieldInfo(fieldName)).thenReturn(fieldInfo); + + final String segmentName = "test-segment-name"; + final String segmentSuffix = "test-segment-suffix"; + + final SegmentInfo segmentInfo = new SegmentInfo( + Mockito.mock(Directory.class), + Mockito.mock(Version.class), + Mockito.mock(Version.class), + segmentName, + 0, + false, + false, + Mockito.mock(Codec.class), + Mockito.mock(Map.class), + new byte[16], + Mockito.mock(Map.class), + Mockito.mock(Sort.class) + ); + + Directory directory = Mockito.mock(Directory.class); + IndexInput input = Mockito.mock(IndexInput.class); + Mockito.when(directory.openInput(any(), any())).thenReturn(input); - String segmentName = "test-segment-name"; - String segmentSuffix = "test-segment-suffix"; - String scalarQuantizationTypeId1 = "1"; - String scalarQuantizationTypeId2 = "2"; - String scalarQuantizationTypeId4 = "4"; - String scalarQuantizationTypeIdIncorrect = "-1"; + final SegmentReadState segmentReadState = new SegmentReadState( + directory, + segmentInfo, + fieldInfos, + Mockito.mock(IOContext.class), + segmentSuffix + ); + ScalarQuantizationParams scalarQuantizationParams1 = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + ScalarQuantizationParams scalarQuantizationParams2 = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); + ScalarQuantizationParams scalarQuantizationParams4 = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); QuantizationStateReadConfig quantizationStateReadConfig = Mockito.mock(QuantizationStateReadConfig.class); - Mockito.when(quantizationStateReadConfig.getSegmentName()).thenReturn(segmentName); - Mockito.when(quantizationStateReadConfig.getSegmentSuffix()).thenReturn(segmentSuffix); - Mockito.when(quantizationStateReadConfig.getFieldInfo()).thenReturn(fieldInfo); - Mockito.when(quantizationStateReadConfig.getDirectory()).thenReturn(directory); - Mockito.when(quantizationStateReadConfig.getScalarQuantizationTypeId()).thenReturn(scalarQuantizationTypeId1); + Mockito.when(quantizationStateReadConfig.getSegmentReadState()).thenReturn(segmentReadState); + Mockito.when(quantizationStateReadConfig.getField()).thenReturn(fieldName); try (MockedStatic mockedStaticReader = Mockito.mockStatic(KNNQuantizationStateReader.class)) { mockedStaticReader.when(() -> KNNQuantizationStateReader.getNumFields(input)).thenReturn(2); @@ -128,6 +150,7 @@ public void testReadFromQuantizationStateReadConfig() { Mockito.when(input.readInt()).thenReturn(fieldNumber); try (MockedStatic mockedStaticOneBit = mockStatic(OneBitScalarQuantizationState.class)) { + Mockito.when(quantizationStateReadConfig.getQuantizationParams()).thenReturn(scalarQuantizationParams1); OneBitScalarQuantizationState oneBitScalarQuantizationState = Mockito.mock(OneBitScalarQuantizationState.class); mockedStaticOneBit.when(() -> OneBitScalarQuantizationState.fromByteArray(any(byte[].class))) .thenReturn(oneBitScalarQuantizationState); @@ -140,16 +163,16 @@ public void testReadFromQuantizationStateReadConfig() { mockedStaticOneBit.when(() -> MultiBitScalarQuantizationState.fromByteArray(any(byte[].class))) .thenReturn(multiBitScalarQuantizationState); - Mockito.when(quantizationStateReadConfig.getScalarQuantizationTypeId()).thenReturn(scalarQuantizationTypeId2); + Mockito.when(quantizationStateReadConfig.getQuantizationParams()).thenReturn(scalarQuantizationParams2); + Mockito.when(quantizationStateReadConfig.getQuantizationParams()).thenReturn(scalarQuantizationParams2); QuantizationState quantizationState = KNNQuantizationStateReader.read(quantizationStateReadConfig); assertTrue(quantizationState instanceof MultiBitScalarQuantizationState); - Mockito.when(quantizationStateReadConfig.getScalarQuantizationTypeId()).thenReturn(scalarQuantizationTypeId4); + Mockito.when(quantizationStateReadConfig.getQuantizationParams()).thenReturn(scalarQuantizationParams4); + Mockito.when(quantizationStateReadConfig.getQuantizationParams()).thenReturn(scalarQuantizationParams4); quantizationState = KNNQuantizationStateReader.read(quantizationStateReadConfig); assertTrue(quantizationState instanceof MultiBitScalarQuantizationState); } - Mockito.when(quantizationStateReadConfig.getScalarQuantizationTypeId()).thenReturn(scalarQuantizationTypeIdIncorrect); - assertThrows(IllegalArgumentException.class, () -> KNNQuantizationStateReader.read(quantizationStateReadConfig)); } } } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index 249ae04cec..63fbe67a68 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -230,7 +230,9 @@ public void testQueryScoreForFaissWithModel() { KNNWeight.initialize(modelDao); final float boost = (float) randomDoubleBetween(0, 10, true); - final KNNWeight knnWeight = new KNNWeight(query, boost); + final String indexUUID = null; + final int shardId = -1; + final KNNWeight knnWeight = new KNNWeight(query, boost, indexUUID, shardId); final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); final SegmentReader reader = mock(SegmentReader.class); @@ -294,7 +296,9 @@ public void testQueryScoreForFaissWithNonExistingModel() throws IOException { when(modelMetadata.getSpaceType()).thenReturn(spaceType); KNNWeight.initialize(modelDao); - final KNNWeight knnWeight = new KNNWeight(query, 0.0f); + final String indexUUID = null; + final int shardId = -1; + final KNNWeight knnWeight = new KNNWeight(query, 0.0f, indexUUID, shardId); final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); final SegmentReader reader = mock(SegmentReader.class); @@ -319,7 +323,9 @@ public void testQueryScoreForFaissWithNonExistingModel() throws IOException { @SneakyThrows public void testShardWithoutFiles() { final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, (BitSetProducer) null); - final KNNWeight knnWeight = new KNNWeight(query, 0.0f); + final String indexUUID = null; + final int shardId = -1; + final KNNWeight knnWeight = new KNNWeight(query, 0.0f, indexUUID, shardId); final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); final SegmentReader reader = mock(SegmentReader.class); @@ -364,7 +370,9 @@ public void testEmptyQueryResults() { .thenReturn(knnQueryResults); final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, (BitSetProducer) null); - final KNNWeight knnWeight = new KNNWeight(query, 0.0f); + final String indexUUID = null; + final int shardId = -1; + final KNNWeight knnWeight = new KNNWeight(query, 0.0f, indexUUID, shardId); final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); final SegmentReader reader = mock(SegmentReader.class); @@ -455,7 +463,9 @@ private void validateScorer_whenNoFilter_thenSuccess(final boolean isBinary) thr .build(); final float boost = (float) randomDoubleBetween(0, 10, true); - final KNNWeight knnWeight = new KNNWeight(query, boost); + final String indexUUID = null; + final int shardId = -1; + final KNNWeight knnWeight = new KNNWeight(query, boost, indexUUID, shardId); final FieldInfos fieldInfos = mock(FieldInfos.class); final FieldInfo fieldInfo = mock(FieldInfo.class); final Map attributesMap = ImmutableMap.of( @@ -581,7 +591,9 @@ public void validateANNWithFilterQuery_whenDoingANN_thenSuccess(final boolean is when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(filterDocIds.length + 1)); final float boost = (float) randomDoubleBetween(0, 10, true); - final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight); + final String indexUUID = null; + final int shardId = -1; + final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight, indexUUID, shardId); final FieldInfos fieldInfos = mock(FieldInfos.class); final FieldInfo fieldInfo = mock(FieldInfo.class); @@ -699,7 +711,9 @@ public void validateANNWithFilterQuery_whenExactSearch_thenSuccess(final boolean when(liveDocsBits.get(filterDocId)).thenReturn(true); final float boost = (float) randomDoubleBetween(0, 10, true); - final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight); + final String indexUUID = null; + final int shardId = -1; + final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight, indexUUID, shardId); final Map attributesMap = ImmutableMap.of( KNN_ENGINE, KNNEngine.FAISS.getName(), @@ -775,7 +789,9 @@ public void testANNWithFilterQuery_whenExactSearchAndThresholdComputations_thenS when(liveDocsBits.get(filterDocId)).thenReturn(true); final float boost = (float) randomDoubleBetween(0, 10, true); - final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight); + final String indexUUID = null; + final int shardId = -1; + final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight, indexUUID, shardId); final Map attributesMap = ImmutableMap.of( KNN_ENGINE, KNNEngine.FAISS.getName(), @@ -845,7 +861,9 @@ public void testANNWithFilterQuery_whenExactSearchViaThresholdSetting_thenSucces final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, k, INDEX_NAME, FILTER_QUERY, null, null); final float boost = (float) randomDoubleBetween(0, 10, true); - final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight); + final String indexUUID = null; + final int shardId = -1; + final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight, indexUUID, shardId); final Map attributesMap = ImmutableMap.of( KNN_ENGINE, KNNEngine.FAISS.getName(), @@ -924,7 +942,9 @@ public void testANNWithFilterQuery_whenExactSearchViaThresholdSettingOnBinaryInd ); final float boost = (float) randomDoubleBetween(0, 10, true); - final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight); + final String indexUUID = null; + final int shardId = -1; + final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight, indexUUID, shardId); final Map attributesMap = ImmutableMap.of( KNN_ENGINE, KNNEngine.FAISS.getName(), @@ -976,7 +996,9 @@ public void testANNWithFilterQuery_whenEmptyFilterIds_thenReturnEarly() { when(filterScorer.iterator()).thenReturn(DocIdSetIterator.empty()); final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY, null, null); - final KNNWeight knnWeight = new KNNWeight(query, 0.0f, filterQueryWeight); + final String indexUUID = null; + final int shardId = -1; + final KNNWeight knnWeight = new KNNWeight(query, 0.0f, filterQueryWeight, indexUUID, shardId); final FieldInfos fieldInfos = mock(FieldInfos.class); final FieldInfo fieldInfo = mock(FieldInfo.class); @@ -1025,7 +1047,9 @@ public void testANNWithParentsFilter_whenExactSearch_thenSuccess() { final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY, parentFilter, null); final float boost = (float) randomDoubleBetween(0, 10, true); - final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight); + final String indexUUID = null; + final int shardId = -1; + final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight, indexUUID, shardId); // Execute final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); @@ -1064,7 +1088,9 @@ public void testANNWithParentsFilter_whenDoingANN_thenBitSetIsPassedToJNI() { .parentsFilter(bitSetProducer) .build(); - final KNNWeight knnWeight = new KNNWeight(query, 0.0f, null); + final String indexUUID = null; + final int shardId = -1; + final KNNWeight knnWeight = new KNNWeight(query, 0.0f, indexUUID, shardId); jniServiceMockedStatic.when( () -> JNIService.queryIndex( @@ -1131,7 +1157,9 @@ public void testDoANNSearch_whenRadialIsDefined_thenCallJniRadiusQueryIndex() { .methodParameters(HNSW_METHOD_PARAMETERS) .build(); final float boost = (float) randomDoubleBetween(0, 10, true); - final KNNWeight knnWeight = new KNNWeight(query, boost); + final String indexUUID = null; + final int shardId = -1; + final KNNWeight knnWeight = new KNNWeight(query, boost, indexUUID, shardId); final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); final SegmentReader reader = mock(SegmentReader.class); @@ -1273,7 +1301,9 @@ private void testQueryScore( .methodParameters(HNSW_METHOD_PARAMETERS) .build(); final float boost = (float) randomDoubleBetween(0, 10, true); - final KNNWeight knnWeight = new KNNWeight(query, boost); + final String indexUUID = null; + final int shardId = -1; + final KNNWeight knnWeight = new KNNWeight(query, boost, indexUUID, shardId); final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); final SegmentReader reader = mock(SegmentReader.class); diff --git a/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManagerTests.java b/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManagerTests.java index 084baeae46..d1d3c329ad 100644 --- a/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManagerTests.java +++ b/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManagerTests.java @@ -6,7 +6,6 @@ package org.opensearch.knn.quantization.models.quantizationState; import lombok.SneakyThrows; -import org.apache.lucene.index.FieldInfo; import org.mockito.MockedStatic; import org.mockito.Mockito; import org.opensearch.knn.KNNTestCase; @@ -31,22 +30,20 @@ public void testRebuildCache() { public void testGetQuantizationState() { try (MockedStatic mockedStaticCache = Mockito.mockStatic(QuantizationStateCache.class)) { QuantizationStateReadConfig quantizationStateReadConfig = Mockito.mock(QuantizationStateReadConfig.class); - FieldInfo fieldInfo = Mockito.mock(FieldInfo.class); - String fieldName = "test-field"; - Mockito.when(fieldInfo.getName()).thenReturn(fieldName); - Mockito.when(quantizationStateReadConfig.getFieldInfo()).thenReturn(fieldInfo); + String cacheKey = "test-key"; + Mockito.when(quantizationStateReadConfig.getCacheKey()).thenReturn(cacheKey); QuantizationState quantizationState = Mockito.mock(QuantizationState.class); QuantizationStateCache quantizationStateCache = Mockito.mock(QuantizationStateCache.class); mockedStaticCache.when(QuantizationStateCache::getInstance).thenReturn(quantizationStateCache); - Mockito.doNothing().when(quantizationStateCache).addQuantizationState(fieldName, quantizationState); + Mockito.doNothing().when(quantizationStateCache).addQuantizationState(cacheKey, quantizationState); try (MockedStatic mockedStaticReader = Mockito.mockStatic(KNNQuantizationStateReader.class)) { mockedStaticReader.when(() -> KNNQuantizationStateReader.read(quantizationStateReadConfig)).thenReturn(quantizationState); QuantizationStateCacheManager.getInstance().getQuantizationState(quantizationStateReadConfig); - Mockito.verify(quantizationStateCache, times(1)).addQuantizationState(fieldName, quantizationState); + Mockito.verify(quantizationStateCache, times(1)).addQuantizationState(cacheKey, quantizationState); } - Mockito.when(quantizationStateCache.getQuantizationState(fieldName)).thenReturn(quantizationState); + Mockito.when(quantizationStateCache.getQuantizationState(cacheKey)).thenReturn(quantizationState); QuantizationStateCacheManager.getInstance().getQuantizationState(quantizationStateReadConfig); - Mockito.verify(quantizationStateCache, times(1)).addQuantizationState(fieldName, quantizationState); + Mockito.verify(quantizationStateCache, times(1)).addQuantizationState(cacheKey, quantizationState); } } From 45b6fbf773bde5a5d7241e9a2a3ef81db96a67bb Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Thu, 29 Aug 2024 14:48:42 -0700 Subject: [PATCH 24/41] Change integration with query flow Signed-off-by: Ryan Bogan --- .../NativeEngines990KnnVectorsReader.java | 54 ++++++++++++++++++- .../QuantizationConfigKNNCollector.java | 4 +- .../opensearch/knn/index/query/KNNWeight.java | 24 +++------ ...NativeEngines990KnnVectorsFormatTests.java | 33 ++++++++++-- 4 files changed, 88 insertions(+), 27 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java index 2349f6ef4c..fe4d39c2cb 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java @@ -23,8 +23,19 @@ import org.apache.lucene.search.TotalHits; import org.apache.lucene.util.Bits; import org.apache.lucene.util.IOUtils; +import org.opensearch.common.UUIDs; +import org.opensearch.knn.index.quantizationService.QuantizationService; +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCacheManager; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateReadConfig; import java.io.IOException; +import java.util.HashMap; +import java.util.Map; /** * Vectors reader class for reading the flat vectors for native engines. The class provides methods for iterating @@ -34,9 +45,34 @@ public class NativeEngines990KnnVectorsReader extends KnnVectorsReader { private final FlatVectorsReader flatVectorsReader; private final SegmentReadState segmentReadState; + private Map fieldToUniqueCacheId; - public NativeEngines990KnnVectorsReader(final SegmentReadState state, final FlatVectorsReader flatVectorsReader) { + public NativeEngines990KnnVectorsReader(final SegmentReadState state, final FlatVectorsReader flatVectorsReader) throws IOException { this.segmentReadState = state; + fieldToUniqueCacheId = new HashMap<>(); + Map stateMap = KNNQuantizationStateReader.read(segmentReadState); + for (Map.Entry entry : stateMap.entrySet()) { + FieldInfo fieldInfo = segmentReadState.fieldInfos.fieldInfo(entry.getKey()); + QuantizationParams quantizationParams = QuantizationService.getInstance().getQuantizationParams(fieldInfo); + if (quantizationParams instanceof ScalarQuantizationParams) { + QuantizationState quantizationState; + ScalarQuantizationParams scalarQuantizationParams = (ScalarQuantizationParams) quantizationParams; + switch (scalarQuantizationParams.getSqType()) { + case ONE_BIT: + quantizationState = OneBitScalarQuantizationState.fromByteArray(entry.getValue()); + break; + case TWO_BIT: + case FOUR_BIT: + quantizationState = MultiBitScalarQuantizationState.fromByteArray(entry.getValue()); + break; + default: + throw new IllegalArgumentException("Unknown Scalar Quantization Type"); + } + String cacheKey = UUIDs.base64UUID(); + fieldToUniqueCacheId.put(entry.getKey(), cacheKey); + QuantizationStateCacheManager.getInstance().addQuantizationState(cacheKey, quantizationState); + } + } this.flatVectorsReader = flatVectorsReader; } @@ -104,7 +140,18 @@ public ByteVectorValues getByteVectorValues(final String field) throws IOExcepti @Override public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { if (knnCollector instanceof QuantizationConfigKNNCollector) { - ((QuantizationConfigKNNCollector) knnCollector).setSegmentReadState(segmentReadState); + String cacheKey = fieldToUniqueCacheId.get(field); + FieldInfo fieldInfo = segmentReadState.fieldInfos.fieldInfo(field); + QuantizationState quantizationState = QuantizationStateCacheManager.getInstance() + .getQuantizationState( + new QuantizationStateReadConfig( + segmentReadState, + QuantizationService.getInstance().getQuantizationParams(fieldInfo), + field, + cacheKey + ) + ); + ((QuantizationConfigKNNCollector) knnCollector).setQuantizationState(quantizationState); return; } throw new UnsupportedOperationException("Search functionality using codec is not supported with Native Engine Reader"); @@ -156,6 +203,9 @@ public void search(String field, byte[] target, KnnCollector knnCollector, Bits @Override public void close() throws IOException { IOUtils.close(flatVectorsReader); + for (String cacheKey : fieldToUniqueCacheId.values()) { + QuantizationStateCacheManager.getInstance().evict(cacheKey); + } } /** diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/QuantizationConfigKNNCollector.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/QuantizationConfigKNNCollector.java index ae7b9a74ab..6dabac4a41 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/QuantizationConfigKNNCollector.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/QuantizationConfigKNNCollector.java @@ -7,15 +7,15 @@ import lombok.Getter; import lombok.Setter; -import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.TopDocs; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; @Setter @Getter public class QuantizationConfigKNNCollector implements KnnCollector { - private SegmentReadState segmentReadState; + private QuantizationState quantizationState; @Override public boolean earlyTerminated() { diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 5f9cf4e5a2..062b43bc70 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -42,9 +42,6 @@ import org.opensearch.knn.plugin.stats.KNNCounter; import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; -import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; -import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCacheManager; -import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateReadConfig; import java.io.IOException; import java.nio.file.Path; @@ -224,10 +221,6 @@ private int[] bitSetToIntArray(final BitSet bitSet) { return intArray; } - private String createQCacheKey(String segmentName) { - return String.format("%s_%s_%s_%s", indexUUID, shardId, segmentName, knnQuery.getField()); - } - private Map doANNSearch( final LeafReaderContext context, final BitSet filterIdsBitSet, @@ -278,20 +271,15 @@ private Map doANNSearch( if (quantizationParams != null) { QuantizationConfigKNNCollector tempCollector = new QuantizationConfigKNNCollector(); reader.searchNearestVectors(knnQuery.getField(), new float[0], tempCollector, null); - if (tempCollector.getSegmentReadState() == null) { + if (tempCollector.getQuantizationState() == null) { throw new IllegalStateException("No quantization state for file"); } - QuantizationState quantizationState = QuantizationStateCacheManager.getInstance() - .getQuantizationState( - new QuantizationStateReadConfig( - tempCollector.getSegmentReadState(), - quantizationParams, - knnQuery.getField(), - createQCacheKey(reader.getSegmentName()) - ) - ); QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(quantizationParams); - quantizedVector = (byte[]) quantizationService.quantize(quantizationState, knnQuery.getQueryVector(), quantizationOutput); + quantizedVector = (byte[]) quantizationService.quantize( + tempCollector.getQuantizationState(), + knnQuery.getQueryVector(), + quantizationOutput + ); } List engineFiles = getEngineFiles(reader, knnEngine.getExtension()); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java index 16404fed09..092535b8be 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java @@ -24,6 +24,7 @@ import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexOptions; @@ -41,6 +42,7 @@ import org.apache.lucene.search.Sort; import org.apache.lucene.store.Directory; import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexOutput; import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.store.BaseDirectoryWrapper; @@ -113,6 +115,26 @@ public void testReaderAndWriter_whenValidInput_thenSuccess() { Mockito.mock(Sort.class) ); + final String segmentSuffix = "test-segment-suffix"; + + Directory directory = Mockito.mock(Directory.class); + IndexInput input = Mockito.mock(IndexInput.class); + Mockito.when(directory.openInput(any(), any())).thenReturn(input); + + String fieldName = "test-field"; + FieldInfos fieldInfos = Mockito.mock(FieldInfos.class); + FieldInfo fieldInfo = Mockito.mock(FieldInfo.class); + Mockito.when(fieldInfo.getName()).thenReturn(fieldName); + Mockito.when(fieldInfos.fieldInfo(anyInt())).thenReturn(fieldInfo); + + final SegmentReadState mockedSegmentReadState = new SegmentReadState( + directory, + mockedSegmentInfo, + fieldInfos, + Mockito.mock(IOContext.class), + segmentSuffix + ); + final SegmentWriteState mockedSegmentWriteState = new SegmentWriteState( Mockito.mock(InfoStream.class), Mockito.mock(Directory.class), @@ -121,21 +143,22 @@ public void testReaderAndWriter_whenValidInput_thenSuccess() { null, Mockito.mock(IOContext.class) ); - final SegmentReadState mockedSegmentReadState = Mockito.mock(SegmentReadState.class); - Mockito.when(mockedFlatVectorsFormat.fieldsReader(mockedSegmentReadState)).thenReturn(Mockito.mock(FlatVectorsReader.class)); Mockito.when(mockedFlatVectorsFormat.fieldsWriter(mockedSegmentWriteState)).thenReturn(Mockito.mock(FlatVectorsWriter.class)); final NativeEngines990KnnVectorsFormat nativeEngines990KnnVectorsFormat = new NativeEngines990KnnVectorsFormat( mockedFlatVectorsFormat ); - Assert.assertTrue( - nativeEngines990KnnVectorsFormat.fieldsReader(mockedSegmentReadState) instanceof NativeEngines990KnnVectorsReader - ); try (MockedStatic mockedStaticCodecUtil = Mockito.mockStatic(CodecUtil.class)) { mockedStaticCodecUtil.when( () -> CodecUtil.writeIndexHeader(any(IndexOutput.class), anyString(), anyInt(), any(byte[].class), anyString()) ).thenAnswer((Answer) invocation -> null); + mockedStaticCodecUtil.when(() -> CodecUtil.retrieveChecksum(any(IndexInput.class))) + .thenAnswer((Answer) invocation -> null); + Assert.assertTrue( + nativeEngines990KnnVectorsFormat.fieldsReader(mockedSegmentReadState) instanceof NativeEngines990KnnVectorsReader + ); + Assert.assertTrue( nativeEngines990KnnVectorsFormat.fieldsWriter(mockedSegmentWriteState) instanceof NativeEngines990KnnVectorsWriter ); From 0e0eaf3c3edaf2d78f8920c448022464a9746fee Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Thu, 29 Aug 2024 14:57:29 -0700 Subject: [PATCH 25/41] Remove unnecessary changes in KNNWeight Signed-off-by: Ryan Bogan --- .../opensearch/knn/index/query/KNNQuery.java | 8 +-- .../opensearch/knn/index/query/KNNWeight.java | 10 +--- .../knn/index/query/KNNWeightTests.java | 60 +++++-------------- 3 files changed, 19 insertions(+), 59 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java index 3d1a83f5c6..04a10143cd 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java @@ -45,8 +45,6 @@ public class KNNQuery extends Query { private final String indexName; private final VectorDataType vectorDataType; private final RescoreContext rescoreContext; - private final String indexUUID; - private final int shardId; @Setter private Query filterQuery; @@ -109,8 +107,6 @@ private KNNQuery( this.parentsFilter = parentsFilter; this.vectorDataType = vectorDataType; this.rescoreContext = rescoreContext; - this.indexUUID = null; - this.shardId = -1; } /** @@ -173,9 +169,9 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo } final Weight filterWeight = getFilterWeight(searcher); if (filterWeight != null) { - return new KNNWeight(this, boost, filterWeight, indexUUID, shardId); + return new KNNWeight(this, boost, filterWeight); } - return new KNNWeight(this, boost, indexUUID, shardId); + return new KNNWeight(this, boost); } private Weight getFilterWeight(IndexSearcher searcher) throws IOException { diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 062b43bc70..b1caeed1e6 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -76,29 +76,23 @@ public class KNNWeight extends Weight { private static ExactSearcher DEFAULT_EXACT_SEARCHER; private final QuantizationService quantizationService = QuantizationService.getInstance(); - private final String indexUUID; - private final int shardId; - public KNNWeight(KNNQuery query, float boost, String indexUUID, int shardId) { + public KNNWeight(KNNQuery query, float boost) { super(query); this.knnQuery = query; this.boost = boost; this.nativeMemoryCacheManager = NativeMemoryCacheManager.getInstance(); this.filterWeight = null; this.exactSearcher = DEFAULT_EXACT_SEARCHER; - this.indexUUID = indexUUID; - this.shardId = shardId; } - public KNNWeight(KNNQuery query, float boost, Weight filterWeight, String indexUUID, int shardId) { + public KNNWeight(KNNQuery query, float boost, Weight filterWeight) { super(query); this.knnQuery = query; this.boost = boost; this.nativeMemoryCacheManager = NativeMemoryCacheManager.getInstance(); this.filterWeight = filterWeight; this.exactSearcher = DEFAULT_EXACT_SEARCHER; - this.indexUUID = indexUUID; - this.shardId = shardId; } public static void initialize(ModelDao modelDao) { diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index 63fbe67a68..aeb419d60d 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -230,9 +230,7 @@ public void testQueryScoreForFaissWithModel() { KNNWeight.initialize(modelDao); final float boost = (float) randomDoubleBetween(0, 10, true); - final String indexUUID = null; - final int shardId = -1; - final KNNWeight knnWeight = new KNNWeight(query, boost, indexUUID, shardId); + final KNNWeight knnWeight = new KNNWeight(query, boost); final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); final SegmentReader reader = mock(SegmentReader.class); @@ -296,9 +294,7 @@ public void testQueryScoreForFaissWithNonExistingModel() throws IOException { when(modelMetadata.getSpaceType()).thenReturn(spaceType); KNNWeight.initialize(modelDao); - final String indexUUID = null; - final int shardId = -1; - final KNNWeight knnWeight = new KNNWeight(query, 0.0f, indexUUID, shardId); + final KNNWeight knnWeight = new KNNWeight(query, 0.0f); final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); final SegmentReader reader = mock(SegmentReader.class); @@ -323,9 +319,7 @@ public void testQueryScoreForFaissWithNonExistingModel() throws IOException { @SneakyThrows public void testShardWithoutFiles() { final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, (BitSetProducer) null); - final String indexUUID = null; - final int shardId = -1; - final KNNWeight knnWeight = new KNNWeight(query, 0.0f, indexUUID, shardId); + final KNNWeight knnWeight = new KNNWeight(query, 0.0f); final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); final SegmentReader reader = mock(SegmentReader.class); @@ -370,9 +364,7 @@ public void testEmptyQueryResults() { .thenReturn(knnQueryResults); final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, (BitSetProducer) null); - final String indexUUID = null; - final int shardId = -1; - final KNNWeight knnWeight = new KNNWeight(query, 0.0f, indexUUID, shardId); + final KNNWeight knnWeight = new KNNWeight(query, 0.0f); final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); final SegmentReader reader = mock(SegmentReader.class); @@ -463,9 +455,7 @@ private void validateScorer_whenNoFilter_thenSuccess(final boolean isBinary) thr .build(); final float boost = (float) randomDoubleBetween(0, 10, true); - final String indexUUID = null; - final int shardId = -1; - final KNNWeight knnWeight = new KNNWeight(query, boost, indexUUID, shardId); + final KNNWeight knnWeight = new KNNWeight(query, boost); final FieldInfos fieldInfos = mock(FieldInfos.class); final FieldInfo fieldInfo = mock(FieldInfo.class); final Map attributesMap = ImmutableMap.of( @@ -591,9 +581,7 @@ public void validateANNWithFilterQuery_whenDoingANN_thenSuccess(final boolean is when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(filterDocIds.length + 1)); final float boost = (float) randomDoubleBetween(0, 10, true); - final String indexUUID = null; - final int shardId = -1; - final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight, indexUUID, shardId); + final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight); final FieldInfos fieldInfos = mock(FieldInfos.class); final FieldInfo fieldInfo = mock(FieldInfo.class); @@ -711,9 +699,7 @@ public void validateANNWithFilterQuery_whenExactSearch_thenSuccess(final boolean when(liveDocsBits.get(filterDocId)).thenReturn(true); final float boost = (float) randomDoubleBetween(0, 10, true); - final String indexUUID = null; - final int shardId = -1; - final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight, indexUUID, shardId); + final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight); final Map attributesMap = ImmutableMap.of( KNN_ENGINE, KNNEngine.FAISS.getName(), @@ -789,9 +775,7 @@ public void testANNWithFilterQuery_whenExactSearchAndThresholdComputations_thenS when(liveDocsBits.get(filterDocId)).thenReturn(true); final float boost = (float) randomDoubleBetween(0, 10, true); - final String indexUUID = null; - final int shardId = -1; - final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight, indexUUID, shardId); + final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight); final Map attributesMap = ImmutableMap.of( KNN_ENGINE, KNNEngine.FAISS.getName(), @@ -861,9 +845,7 @@ public void testANNWithFilterQuery_whenExactSearchViaThresholdSetting_thenSucces final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, k, INDEX_NAME, FILTER_QUERY, null, null); final float boost = (float) randomDoubleBetween(0, 10, true); - final String indexUUID = null; - final int shardId = -1; - final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight, indexUUID, shardId); + final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight); final Map attributesMap = ImmutableMap.of( KNN_ENGINE, KNNEngine.FAISS.getName(), @@ -942,9 +924,7 @@ public void testANNWithFilterQuery_whenExactSearchViaThresholdSettingOnBinaryInd ); final float boost = (float) randomDoubleBetween(0, 10, true); - final String indexUUID = null; - final int shardId = -1; - final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight, indexUUID, shardId); + final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight); final Map attributesMap = ImmutableMap.of( KNN_ENGINE, KNNEngine.FAISS.getName(), @@ -996,9 +976,7 @@ public void testANNWithFilterQuery_whenEmptyFilterIds_thenReturnEarly() { when(filterScorer.iterator()).thenReturn(DocIdSetIterator.empty()); final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY, null, null); - final String indexUUID = null; - final int shardId = -1; - final KNNWeight knnWeight = new KNNWeight(query, 0.0f, filterQueryWeight, indexUUID, shardId); + final KNNWeight knnWeight = new KNNWeight(query, 0.0f, filterQueryWeight); final FieldInfos fieldInfos = mock(FieldInfos.class); final FieldInfo fieldInfo = mock(FieldInfo.class); @@ -1047,9 +1025,7 @@ public void testANNWithParentsFilter_whenExactSearch_thenSuccess() { final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY, parentFilter, null); final float boost = (float) randomDoubleBetween(0, 10, true); - final String indexUUID = null; - final int shardId = -1; - final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight, indexUUID, shardId); + final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight); // Execute final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); @@ -1088,9 +1064,7 @@ public void testANNWithParentsFilter_whenDoingANN_thenBitSetIsPassedToJNI() { .parentsFilter(bitSetProducer) .build(); - final String indexUUID = null; - final int shardId = -1; - final KNNWeight knnWeight = new KNNWeight(query, 0.0f, indexUUID, shardId); + final KNNWeight knnWeight = new KNNWeight(query, 0.0f); jniServiceMockedStatic.when( () -> JNIService.queryIndex( @@ -1157,9 +1131,7 @@ public void testDoANNSearch_whenRadialIsDefined_thenCallJniRadiusQueryIndex() { .methodParameters(HNSW_METHOD_PARAMETERS) .build(); final float boost = (float) randomDoubleBetween(0, 10, true); - final String indexUUID = null; - final int shardId = -1; - final KNNWeight knnWeight = new KNNWeight(query, boost, indexUUID, shardId); + final KNNWeight knnWeight = new KNNWeight(query, boost); final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); final SegmentReader reader = mock(SegmentReader.class); @@ -1301,9 +1273,7 @@ private void testQueryScore( .methodParameters(HNSW_METHOD_PARAMETERS) .build(); final float boost = (float) randomDoubleBetween(0, 10, true); - final String indexUUID = null; - final int shardId = -1; - final KNNWeight knnWeight = new KNNWeight(query, boost, indexUUID, shardId); + final KNNWeight knnWeight = new KNNWeight(query, boost); final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); final SegmentReader reader = mock(SegmentReader.class); From 1f6710323533ab0c16d5f09ae79abbb23a029154 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Tue, 3 Sep 2024 10:37:29 -0700 Subject: [PATCH 26/41] Address PR Feedback and fix compile error from rebase Signed-off-by: Ryan Bogan --- .../org/opensearch/knn/index/KNNSettings.java | 7 +-- .../NativeEngines990KnnVectorsReader.java | 54 ++++++++++--------- .../opensearch/knn/index/query/KNNWeight.java | 2 +- 3 files changed, 34 insertions(+), 29 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index f2ce7ea2a7..3793db4b81 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -57,6 +57,7 @@ public class KNNSettings { private static final OsProbe osProbe = OsProbe.getInstance(); private static final int INDEX_THREAD_QTY_MAX = 32; + private static final QuantizationStateCacheManager quantizationStateCacheManager = QuantizationStateCacheManager.getInstance(); /** * Settings name @@ -390,11 +391,11 @@ private void setSettingsUpdateConsumers() { NativeMemoryCacheManager.getInstance().rebuildCache(builder.build()); }, dynamicCacheSettings.values().stream().collect(Collectors.toUnmodifiableList())); clusterService.getClusterSettings().addSettingsUpdateConsumer(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, it -> { - QuantizationStateCacheManager.getInstance().setMaxCacheSizeInKB(it.getKb()); - QuantizationStateCacheManager.getInstance().rebuildCache(); + quantizationStateCacheManager.setMaxCacheSizeInKB(it.getKb()); + quantizationStateCacheManager.rebuildCache(); }); clusterService.getClusterSettings().addSettingsUpdateConsumer(QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING, it -> { - QuantizationStateCacheManager.getInstance().rebuildCache(); + quantizationStateCacheManager.rebuildCache(); }); } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java index fe4d39c2cb..d795cb712e 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java @@ -24,7 +24,7 @@ import org.apache.lucene.util.Bits; import org.apache.lucene.util.IOUtils; import org.opensearch.common.UUIDs; -import org.opensearch.knn.index.quantizationService.QuantizationService; +import org.opensearch.knn.index.quantizationservice.QuantizationService; import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; @@ -49,30 +49,7 @@ public class NativeEngines990KnnVectorsReader extends KnnVectorsReader { public NativeEngines990KnnVectorsReader(final SegmentReadState state, final FlatVectorsReader flatVectorsReader) throws IOException { this.segmentReadState = state; - fieldToUniqueCacheId = new HashMap<>(); - Map stateMap = KNNQuantizationStateReader.read(segmentReadState); - for (Map.Entry entry : stateMap.entrySet()) { - FieldInfo fieldInfo = segmentReadState.fieldInfos.fieldInfo(entry.getKey()); - QuantizationParams quantizationParams = QuantizationService.getInstance().getQuantizationParams(fieldInfo); - if (quantizationParams instanceof ScalarQuantizationParams) { - QuantizationState quantizationState; - ScalarQuantizationParams scalarQuantizationParams = (ScalarQuantizationParams) quantizationParams; - switch (scalarQuantizationParams.getSqType()) { - case ONE_BIT: - quantizationState = OneBitScalarQuantizationState.fromByteArray(entry.getValue()); - break; - case TWO_BIT: - case FOUR_BIT: - quantizationState = MultiBitScalarQuantizationState.fromByteArray(entry.getValue()); - break; - default: - throw new IllegalArgumentException("Unknown Scalar Quantization Type"); - } - String cacheKey = UUIDs.base64UUID(); - fieldToUniqueCacheId.put(entry.getKey(), cacheKey); - QuantizationStateCacheManager.getInstance().addQuantizationState(cacheKey, quantizationState); - } - } + populateFieldMapAndQuantizationStateCache(); this.flatVectorsReader = flatVectorsReader; } @@ -215,4 +192,31 @@ public void close() throws IOException { public long ramBytesUsed() { return flatVectorsReader.ramBytesUsed(); } + + private void populateFieldMapAndQuantizationStateCache() throws IOException { + fieldToUniqueCacheId = new HashMap<>(); + Map stateMap = KNNQuantizationStateReader.read(segmentReadState); + for (Map.Entry entry : stateMap.entrySet()) { + FieldInfo fieldInfo = segmentReadState.fieldInfos.fieldInfo(entry.getKey()); + QuantizationParams quantizationParams = QuantizationService.getInstance().getQuantizationParams(fieldInfo); + if (quantizationParams instanceof ScalarQuantizationParams) { + QuantizationState quantizationState; + ScalarQuantizationParams scalarQuantizationParams = (ScalarQuantizationParams) quantizationParams; + switch (scalarQuantizationParams.getSqType()) { + case ONE_BIT: + quantizationState = OneBitScalarQuantizationState.fromByteArray(entry.getValue()); + break; + case TWO_BIT: + case FOUR_BIT: + quantizationState = MultiBitScalarQuantizationState.fromByteArray(entry.getValue()); + break; + default: + throw new IllegalArgumentException("Unknown Scalar Quantization Type"); + } + String cacheKey = UUIDs.base64UUID(); + fieldToUniqueCacheId.put(entry.getKey(), cacheKey); + QuantizationStateCacheManager.getInstance().addQuantizationState(cacheKey, quantizationState); + } + } + } } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index b1caeed1e6..9773e21492 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -34,7 +34,7 @@ import org.opensearch.knn.index.memory.NativeMemoryEntryContext; import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy; import org.opensearch.knn.index.engine.KNNEngine; -import org.opensearch.knn.index.quantizationService.QuantizationService; +import org.opensearch.knn.index.quantizationservice.QuantizationService; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelUtil; From 80e74e3535f00ab7bc2c18b0e2739a8a46d13c20 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Tue, 3 Sep 2024 11:30:48 -0700 Subject: [PATCH 27/41] Abstract common functionality between read methods Signed-off-by: Ryan Bogan --- .../KNNQuantizationStateReader.java | 34 ++++++------ .../KNNQuantizationStateReaderTests.java | 54 +++++++++++++++++-- 2 files changed, 68 insertions(+), 20 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java index 90afc217b1..b4013d11ff 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java @@ -50,11 +50,7 @@ public final class KNNQuantizationStateReader { * @param state the read state to read from */ public static Map read(SegmentReadState state) throws IOException { - String quantizationStateFileName = IndexFileNames.segmentFileName( - state.segmentInfo.name, - state.segmentSuffix, - KNNConstants.QUANTIZATION_STATE_FILE_SUFFIX - ); + String quantizationStateFileName = getQuantizationStateFileName(state); Map readQuantizationStateInfos = new HashMap<>(); try (IndexInput input = state.directory.openInput(quantizationStateFileName, IOContext.READ)) { @@ -76,9 +72,7 @@ public static Map read(SegmentReadState state) throws IOExceptio } // Read each field's bytes for (int i = 0; i < numFields; i++) { - input.seek(positions.get(i)); - byte[] stateBytes = new byte[lengths.get(i)]; - input.readBytes(stateBytes, 0, lengths.get(i)); + byte[] stateBytes = readStateBytes(input, positions.get(i), lengths.get(i)); String fieldName = state.fieldInfos.fieldInfo(fieldNumbers.get(i)).getName(); readQuantizationStateInfos.put(fieldName, stateBytes); } @@ -94,11 +88,7 @@ public static Map read(SegmentReadState state) throws IOExceptio public static QuantizationState read(QuantizationStateReadConfig readConfig) throws IOException { SegmentReadState segmentReadState = readConfig.getSegmentReadState(); String field = readConfig.getField(); - String quantizationStateFileName = IndexFileNames.segmentFileName( - segmentReadState.segmentInfo.name, - segmentReadState.segmentSuffix, - KNNConstants.QUANTIZATION_STATE_FILE_SUFFIX - ); + String quantizationStateFileName = getQuantizationStateFileName(segmentReadState); int fieldNumber = segmentReadState.fieldInfos.fieldInfo(field).getFieldNumber(); try (IndexInput input = segmentReadState.directory.openInput(quantizationStateFileName, IOContext.READ)) { @@ -124,9 +114,8 @@ public static QuantizationState read(QuantizationStateReadConfig readConfig) thr throw new IllegalArgumentException(String.format("Field %s not found", field)); } - input.seek(position); - byte[] stateBytes = new byte[length]; - input.readBytes(stateBytes, 0, length); + byte[] stateBytes = readStateBytes(input, position, length); + // Deserialize the byte array to a quantization state object ScalarQuantizationType scalarQuantizationType = ((ScalarQuantizationParams) readConfig.getQuantizationParams()).getSqType(); if (scalarQuantizationType == ScalarQuantizationType.ONE_BIT) { @@ -150,4 +139,17 @@ static int getNumFields(IndexInput input) throws IOException { input.seek(indexStartPosition); return input.readInt(); } + + @VisibleForTesting + static byte[] readStateBytes(IndexInput input, long position, int length) throws IOException { + input.seek(position); + byte[] stateBytes = new byte[length]; + input.readBytes(stateBytes, 0, length); + return stateBytes; + } + + @VisibleForTesting + static String getQuantizationStateFileName(SegmentReadState state) { + return IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, KNNConstants.QUANTIZATION_STATE_FILE_SUFFIX); + } } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReaderTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReaderTests.java index d3c2f75d1f..9a1426859d 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReaderTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReaderTests.java @@ -10,6 +10,7 @@ import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.SegmentInfo; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.search.Sort; @@ -20,6 +21,7 @@ import org.mockito.MockedStatic; import org.mockito.Mockito; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; @@ -84,8 +86,6 @@ public void testReadFromSegmentReadState() { mockedStaticCodecUtil.verify(() -> CodecUtil.retrieveChecksum(input)); Mockito.verify(input, times(4)).readInt(); Mockito.verify(input, times(2)).readVLong(); - Mockito.verify(input, times(2)).readBytes(any(byte[].class), anyInt(), anyInt()); - Mockito.verify(input, times(2)).seek(anyLong()); } } } @@ -138,14 +138,14 @@ public void testReadFromQuantizationStateReadConfig() { try (MockedStatic mockedStaticReader = Mockito.mockStatic(KNNQuantizationStateReader.class)) { mockedStaticReader.when(() -> KNNQuantizationStateReader.getNumFields(input)).thenReturn(2); mockedStaticReader.when(() -> KNNQuantizationStateReader.read(quantizationStateReadConfig)).thenCallRealMethod(); + mockedStaticReader.when(() -> KNNQuantizationStateReader.readStateBytes(any(IndexInput.class), anyLong(), anyInt())) + .thenReturn(new byte[8]); try (MockedStatic mockedStaticCodecUtil = mockStatic(CodecUtil.class)) { assertThrows(IllegalArgumentException.class, () -> KNNQuantizationStateReader.read(quantizationStateReadConfig)); mockedStaticCodecUtil.verify(() -> CodecUtil.retrieveChecksum(input)); Mockito.verify(input, times(4)).readInt(); Mockito.verify(input, times(2)).readVLong(); - Mockito.verify(input, times(0)).readBytes(any(byte[].class), anyInt(), anyInt()); - Mockito.verify(input, times(0)).seek(anyLong()); Mockito.when(input.readInt()).thenReturn(fieldNumber); @@ -187,4 +187,50 @@ public void testGetNumFields() { Mockito.verify(input, times(2)).seek(anyLong()); Mockito.verify(input, times(1)).length(); } + + @SneakyThrows + public void testReadStateBytes() { + IndexInput input = Mockito.mock(IndexInput.class); + long position = 1; + int length = 2; + byte[] stateBytes = new byte[length]; + KNNQuantizationStateReader.readStateBytes(input, position, length); + + Mockito.verify(input, times(1)).seek(position); + Mockito.verify(input, times(1)).readBytes(stateBytes, 0, length); + + } + + @SneakyThrows + public void testGetQuantizationStateFileName() { + String segmentName = "test-segment"; + String segmentSuffix = "test-suffix"; + String expectedName = IndexFileNames.segmentFileName(segmentName, segmentSuffix, KNNConstants.QUANTIZATION_STATE_FILE_SUFFIX); + + final SegmentInfo segmentInfo = new SegmentInfo( + Mockito.mock(Directory.class), + Mockito.mock(Version.class), + Mockito.mock(Version.class), + segmentName, + 0, + false, + false, + Mockito.mock(Codec.class), + Mockito.mock(Map.class), + new byte[16], + Mockito.mock(Map.class), + Mockito.mock(Sort.class) + ); + + final SegmentReadState segmentReadState = new SegmentReadState( + Mockito.mock(Directory.class), + segmentInfo, + Mockito.mock(FieldInfos.class), + Mockito.mock(IOContext.class), + segmentSuffix + ); + + assertEquals(expectedName, KNNQuantizationStateReader.getQuantizationStateFileName(segmentReadState)); + + } } From 60cd2fa80298054ccd36708dfda86986f87b4ab9 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Tue, 3 Sep 2024 11:44:26 -0700 Subject: [PATCH 28/41] Avoid repeat calls to quantization cache manager get instance Signed-off-by: Ryan Bogan --- .../codec/KNN990Codec/NativeEngines990KnnVectorsReader.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java index d795cb712e..d25c8db19d 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java @@ -45,6 +45,7 @@ public class NativeEngines990KnnVectorsReader extends KnnVectorsReader { private final FlatVectorsReader flatVectorsReader; private final SegmentReadState segmentReadState; + private final QuantizationStateCacheManager quantizationStateCacheManager = QuantizationStateCacheManager.getInstance(); private Map fieldToUniqueCacheId; public NativeEngines990KnnVectorsReader(final SegmentReadState state, final FlatVectorsReader flatVectorsReader) throws IOException { @@ -215,7 +216,7 @@ private void populateFieldMapAndQuantizationStateCache() throws IOException { } String cacheKey = UUIDs.base64UUID(); fieldToUniqueCacheId.put(entry.getKey(), cacheKey); - QuantizationStateCacheManager.getInstance().addQuantizationState(cacheKey, quantizationState); + quantizationStateCacheManager.addQuantizationState(cacheKey, quantizationState); } } } From c0b9e71ec78c5ec7adf5bcc1d4b3ded177675977 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Tue, 3 Sep 2024 16:00:17 -0700 Subject: [PATCH 29/41] Address PR feedback Signed-off-by: Ryan Bogan --- .../org/opensearch/knn/common/KNNConstants.java | 2 +- .../KNN990Codec/KNNQuantizationStateWriter.java | 2 +- .../NativeEngines990KnnVectorsReader.java | 1 + .../QuantizationConfigKNNCollector.java | 3 +++ .../org/opensearch/knn/index/query/KNNWeight.java | 14 +++++++++----- 5 files changed, 15 insertions(+), 7 deletions(-) diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index aa9ca01ca6..3e760b1829 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -73,7 +73,7 @@ public class KNNConstants { public static final VectorDataType DEFAULT_VECTOR_DATA_TYPE_FIELD = VectorDataType.FLOAT; public static final String RADIAL_SEARCH_KEY = "radial_search"; - public static final String QUANTIZATION_STATE_FILE_SUFFIX = "qstate"; + public static final String QUANTIZATION_STATE_FILE_SUFFIX = "osknnqstate"; // Lucene specific constants public static final String LUCENE_NAME = "lucene"; diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriter.java index 03e392d7ba..32f38e0cad 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriter.java @@ -61,7 +61,7 @@ public KNNQuantizationStateWriter(SegmentWriteState segmentWriteState) throws IO * @throws IOException exception could be thrown while writing header */ public void writeHeader(SegmentWriteState segmentWriteState) throws IOException { - CodecUtil.writeIndexHeader(output, "QuantizationCodec", 0, segmentWriteState.segmentInfo.getId(), segmentWriteState.segmentSuffix); + CodecUtil.writeIndexHeader(output, "KNN990Codec", 0, segmentWriteState.segmentInfo.getId(), segmentWriteState.segmentSuffix); } /** diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java index d25c8db19d..89f5aea972 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java @@ -117,6 +117,7 @@ public ByteVectorValues getByteVectorValues(final String field) throws IOExcepti */ @Override public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + // TODO: This is a temporary hack where we are using KNNCollector to initialize the quantization state. if (knnCollector instanceof QuantizationConfigKNNCollector) { String cacheKey = fieldToUniqueCacheId.get(field); FieldInfo fieldInfo = segmentReadState.fieldInfos.fieldInfo(field); diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/QuantizationConfigKNNCollector.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/QuantizationConfigKNNCollector.java index 6dabac4a41..315206e33e 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/QuantizationConfigKNNCollector.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/QuantizationConfigKNNCollector.java @@ -11,6 +11,9 @@ import org.apache.lucene.search.TopDocs; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +/** + * Collector used for passing the quantization state during query flow. + */ @Setter @Getter public class QuantizationConfigKNNCollector implements KnnCollector { diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 9773e21492..8c24df569a 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -8,7 +8,6 @@ import com.google.common.annotations.VisibleForTesting; import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.FieldInfo; -import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.SegmentReader; import org.apache.lucene.search.DocIdSetIterator; @@ -75,7 +74,7 @@ public class KNNWeight extends Weight { private final ExactSearcher exactSearcher; private static ExactSearcher DEFAULT_EXACT_SEARCHER; - private final QuantizationService quantizationService = QuantizationService.getInstance(); + private final QuantizationService quantizationService; public KNNWeight(KNNQuery query, float boost) { super(query); @@ -84,6 +83,7 @@ public KNNWeight(KNNQuery query, float boost) { this.nativeMemoryCacheManager = NativeMemoryCacheManager.getInstance(); this.filterWeight = null; this.exactSearcher = DEFAULT_EXACT_SEARCHER; + this.quantizationService = QuantizationService.getInstance(); } public KNNWeight(KNNQuery query, float boost, Weight filterWeight) { @@ -93,6 +93,7 @@ public KNNWeight(KNNQuery query, float boost, Weight filterWeight) { this.nativeMemoryCacheManager = NativeMemoryCacheManager.getInstance(); this.filterWeight = filterWeight; this.exactSearcher = DEFAULT_EXACT_SEARCHER; + this.quantizationService = QuantizationService.getInstance(); } public static void initialize(ModelDao modelDao) { @@ -221,8 +222,7 @@ private Map doANNSearch( final int cardinality, final int k ) throws IOException { - LeafReader reader2 = context.reader(); - final SegmentReader reader = Lucene.segmentReader(reader2); + final SegmentReader reader = Lucene.segmentReader(context.reader()); String directory = ((FSDirectory) FilterDirectory.unwrap(reader.directory())).getDirectory().toString(); FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); @@ -260,15 +260,17 @@ private Map doANNSearch( QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo); + // TODO: Change type of vector once more quantization methods are supported byte[] quantizedVector = null; if (quantizationParams != null) { QuantizationConfigKNNCollector tempCollector = new QuantizationConfigKNNCollector(); reader.searchNearestVectors(knnQuery.getField(), new float[0], tempCollector, null); if (tempCollector.getQuantizationState() == null) { - throw new IllegalStateException("No quantization state for file"); + throw new IllegalStateException(String.format("No quantization state found for field %s", fieldInfo.getName())); } QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(quantizationParams); + // TODO: In the future, byte array will not be the only output type from this method quantizedVector = (byte[]) quantizationService.quantize( tempCollector.getQuantizationState(), knnQuery.getQueryVector(), @@ -297,6 +299,7 @@ private Map doANNSearch( spaceType, knnEngine, knnQuery.getIndexName(), + // TODO: In the future, more vector data types will be supported with quantization quantizationParams == null ? vectorDataType : VectorDataType.BINARY ), knnQuery.getIndexName(), @@ -325,6 +328,7 @@ private Map doANNSearch( || quantizationParams != null && quantizationService.getVectorDataTypeForTransfer(fieldInfo) == VectorDataType.BINARY) { results = JNIService.queryBinaryIndex( indexAllocation.getMemoryAddress(), + // TODO: In the future, quantizedVector can have other data types than byte quantizationParams == null ? knnQuery.getByteQueryVector() : quantizedVector, k, knnQuery.getMethodParameters(), From 8c213043b6f68061ac1556e26c0cee9c771b16b6 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Tue, 3 Sep 2024 18:28:23 -0700 Subject: [PATCH 30/41] Add unit tests for KNNWeight Signed-off-by: Ryan Bogan --- .../knn/index/query/KNNWeightTests.java | 161 ++++++++++++++++++ 1 file changed, 161 insertions(+) diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index aeb419d60d..5e5c5fa11f 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -32,6 +32,7 @@ import org.junit.After; import org.junit.Before; import org.junit.BeforeClass; +import org.mockito.MockedConstruction; import org.mockito.MockedStatic; import org.mockito.Mockito; import org.opensearch.common.io.PathUtils; @@ -39,6 +40,7 @@ import org.opensearch.core.common.unit.ByteSizeValue; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.codec.KNN990Codec.QuantizationConfigKNNCollector; import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; @@ -47,6 +49,7 @@ import org.opensearch.knn.index.memory.NativeMemoryAllocation; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.quantizationservice.QuantizationService; import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; @@ -54,6 +57,11 @@ import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelState; import org.opensearch.knn.jni.JNIService; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import java.io.IOException; import java.nio.file.Path; @@ -1350,4 +1358,157 @@ private KNNQueryResult[] getFilteredKNNQueryResults() { .collect(Collectors.toList()) .toArray(new KNNQueryResult[0]); } + + @SneakyThrows + public void testANNWithQuantizationParams_whenStateNotFound_thenFail() { + try (MockedStatic quantizationServiceMockedStatic = Mockito.mockStatic(QuantizationService.class)) { + QuantizationService quantizationService = Mockito.mock(QuantizationService.class); + QuantizationParams quantizationParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + Mockito.when(quantizationService.getQuantizationParams(any(FieldInfo.class))).thenReturn(quantizationParams); + quantizationServiceMockedStatic.when(QuantizationService::getInstance).thenReturn(quantizationService); + + // Given + int k = 3; + jniServiceMockedStatic.when( + () -> JNIService.queryIndex(anyLong(), eq(QUERY_VECTOR), eq(k), eq(HNSW_METHOD_PARAMETERS), any(), any(), anyInt(), any()) + ).thenReturn(getFilteredKNNQueryResults()); + + jniServiceMockedStatic.when( + () -> JNIService.queryBinaryIndex( + anyLong(), + eq(BYTE_QUERY_VECTOR), + eq(k), + eq(HNSW_METHOD_PARAMETERS), + any(), + any(), + anyInt(), + any() + ) + ).thenReturn(getFilteredKNNQueryResults()); + final SegmentReader reader = mockSegmentReader(); + final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); + when(leafReaderContext.reader()).thenReturn(reader); + + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(QUERY_VECTOR) + .k(k) + .indexName(INDEX_NAME) + .filterQuery(FILTER_QUERY) + .methodParameters(HNSW_METHOD_PARAMETERS) + .vectorDataType(VectorDataType.FLOAT) + .build(); + + final float boost = (float) randomDoubleBetween(0, 10, true); + final KNNWeight knnWeight = new KNNWeight(query, boost); + final FieldInfos fieldInfos = mock(FieldInfos.class); + final FieldInfo fieldInfo = mock(FieldInfo.class); + final Map attributesMap = ImmutableMap.of( + KNN_ENGINE, + KNNEngine.FAISS.getName(), + PARAMETERS, + String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") + ); + + when(reader.getFieldInfos()).thenReturn(fieldInfos); + when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); + when(fieldInfo.attributes()).thenReturn(attributesMap); + + expectThrows(IllegalStateException.class, () -> knnWeight.scorer(leafReaderContext)); + } + } + + @SneakyThrows + public void testANN() { + try (MockedStatic quantizationServiceMockedStatic = Mockito.mockStatic(QuantizationService.class)) { + QuantizationService quantizationService = Mockito.mock(QuantizationService.class); + ScalarQuantizationParams quantizationParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + Mockito.when(quantizationService.getQuantizationParams(any(FieldInfo.class))).thenReturn(quantizationParams); + quantizationServiceMockedStatic.when(QuantizationService::getInstance).thenReturn(quantizationService); + + float[] meanThresholds = new float[] { 1.2f, 2.3f, 3.4f, 4.5f }; + QuantizationState quantizationState = new OneBitScalarQuantizationState(quantizationParams, meanThresholds); + + try ( + MockedConstruction quantizationCollectorMockedConstruction = Mockito.mockConstruction( + QuantizationConfigKNNCollector.class, + (mock, context) -> Mockito.when(mock.getQuantizationState()).thenReturn(quantizationState) + ) + ) { + + // Given + int k = 3; + jniServiceMockedStatic.when( + () -> JNIService.queryIndex( + anyLong(), + eq(QUERY_VECTOR), + eq(k), + eq(HNSW_METHOD_PARAMETERS), + any(), + any(), + anyInt(), + any() + ) + ).thenReturn(getFilteredKNNQueryResults()); + + jniServiceMockedStatic.when( + () -> JNIService.queryBinaryIndex( + anyLong(), + eq(BYTE_QUERY_VECTOR), + eq(k), + eq(HNSW_METHOD_PARAMETERS), + any(), + any(), + anyInt(), + any() + ) + ).thenReturn(getFilteredKNNQueryResults()); + final SegmentReader reader = mockSegmentReader(); + final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); + when(leafReaderContext.reader()).thenReturn(reader); + + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(QUERY_VECTOR) + .k(k) + .indexName(INDEX_NAME) + .filterQuery(FILTER_QUERY) + .methodParameters(HNSW_METHOD_PARAMETERS) + .vectorDataType(VectorDataType.FLOAT) + .build(); + + final float boost = (float) randomDoubleBetween(0, 10, true); + final KNNWeight knnWeight = new KNNWeight(query, boost); + final FieldInfos fieldInfos = mock(FieldInfos.class); + final FieldInfo fieldInfo = mock(FieldInfo.class); + final Map attributesMap = ImmutableMap.of( + KNN_ENGINE, + KNNEngine.FAISS.getName(), + PARAMETERS, + String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") + ); + + when(reader.getFieldInfos()).thenReturn(fieldInfos); + when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); + when(fieldInfo.attributes()).thenReturn(attributesMap); + + KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + + assertNotNull(knnScorer); + jniServiceMockedStatic.verify( + () -> JNIService.queryIndex( + anyLong(), + eq(QUERY_VECTOR), + eq(k), + eq(HNSW_METHOD_PARAMETERS), + any(), + any(), + anyInt(), + any() + ), + times(1) + ); + } + } + } } From b304e3c25a1cebbb013bd44c82064257a825813e Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Tue, 3 Sep 2024 22:11:21 -0700 Subject: [PATCH 31/41] Address PR Feedback Signed-off-by: Ryan Bogan --- ...ava => KNN990QuantizationStateReader.java} | 16 ++++----- ...ava => KNN990QuantizationStateWriter.java} | 12 +++++-- .../NativeEngines990KnnVectorsReader.java | 16 ++++----- .../NativeEngines990KnnVectorsWriter.java | 4 +-- .../opensearch/knn/index/query/KNNWeight.java | 36 ++++++++++--------- .../QuantizationStateCacheManager.java | 4 +-- ...> KNN990QuantizationStateReaderTests.java} | 34 +++++++++--------- ...> KNN990QuantizationStateWriterTests.java} | 10 +++--- ...NativeEngines990KnnVectorsFormatTests.java | 4 --- .../knn/index/query/KNNWeightTests.java | 2 +- .../QuantizationStateCacheManagerTests.java | 7 ++-- 11 files changed, 76 insertions(+), 69 deletions(-) rename src/main/java/org/opensearch/knn/index/codec/KNN990Codec/{KNNQuantizationStateReader.java => KNN990QuantizationStateReader.java} (93%) rename src/main/java/org/opensearch/knn/index/codec/KNN990Codec/{KNNQuantizationStateWriter.java => KNN990QuantizationStateWriter.java} (90%) rename src/test/java/org/opensearch/knn/index/codec/KNN990Codec/{KNNQuantizationStateReaderTests.java => KNN990QuantizationStateReaderTests.java} (84%) rename src/test/java/org/opensearch/knn/index/codec/KNN990Codec/{KNNQuantizationStateWriterTests.java => KNN990QuantizationStateWriterTests.java} (93%) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java similarity index 93% rename from src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java rename to src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java index b4013d11ff..11a34f02b2 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java @@ -28,7 +28,7 @@ /** * Reads quantization states */ -public final class KNNQuantizationStateReader { +public final class KNN990QuantizationStateReader { /** * Read quantization states and return list of fieldNames and bytes @@ -118,14 +118,15 @@ public static QuantizationState read(QuantizationStateReadConfig readConfig) thr // Deserialize the byte array to a quantization state object ScalarQuantizationType scalarQuantizationType = ((ScalarQuantizationParams) readConfig.getQuantizationParams()).getSqType(); - if (scalarQuantizationType == ScalarQuantizationType.ONE_BIT) { - return OneBitScalarQuantizationState.fromByteArray(stateBytes); - } else if (scalarQuantizationType == ScalarQuantizationType.TWO_BIT - || scalarQuantizationType == ScalarQuantizationType.FOUR_BIT) { + switch (scalarQuantizationType) { + case ONE_BIT: + return OneBitScalarQuantizationState.fromByteArray(stateBytes); + case TWO_BIT: + case FOUR_BIT: return MultiBitScalarQuantizationState.fromByteArray(stateBytes); - } else { + default: throw new IllegalArgumentException(String.format("Unexpected scalar quantization type: %s", scalarQuantizationType)); - } + } } } @@ -135,7 +136,6 @@ static int getNumFields(IndexInput input) throws IOException { long markerAndIndexPosition = footerStart - Integer.BYTES - Long.BYTES; input.seek(markerAndIndexPosition); long indexStartPosition = input.readLong(); - input.readInt(); input.seek(indexStartPosition); return input.readInt(); } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateWriter.java similarity index 90% rename from src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriter.java rename to src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateWriter.java index 32f38e0cad..a0cd16e4f7 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateWriter.java @@ -21,7 +21,7 @@ /** * Writes quantization states to off heap memory */ -public final class KNNQuantizationStateWriter { +public final class KNN990QuantizationStateWriter { private final IndexOutput output; private List fieldQuantizationStates = new ArrayList<>(); @@ -45,7 +45,7 @@ public final class KNNQuantizationStateWriter { * @param segmentWriteState segment write state containing segment information * @throws IOException exception could be thrown while creating the output */ - public KNNQuantizationStateWriter(SegmentWriteState segmentWriteState) throws IOException { + public KNN990QuantizationStateWriter(SegmentWriteState segmentWriteState) throws IOException { String quantizationStateFileName = IndexFileNames.segmentFileName( segmentWriteState.segmentInfo.name, segmentWriteState.segmentSuffix, @@ -61,7 +61,13 @@ public KNNQuantizationStateWriter(SegmentWriteState segmentWriteState) throws IO * @throws IOException exception could be thrown while writing header */ public void writeHeader(SegmentWriteState segmentWriteState) throws IOException { - CodecUtil.writeIndexHeader(output, "KNN990Codec", 0, segmentWriteState.segmentInfo.getId(), segmentWriteState.segmentSuffix); + CodecUtil.writeIndexHeader( + output, + "NativeEngines99KnnVectorsFormatQSData", + 0, + segmentWriteState.segmentInfo.getId(), + segmentWriteState.segmentSuffix + ); } /** diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java index 89f5aea972..06f705c1fa 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java @@ -46,11 +46,11 @@ public class NativeEngines990KnnVectorsReader extends KnnVectorsReader { private final FlatVectorsReader flatVectorsReader; private final SegmentReadState segmentReadState; private final QuantizationStateCacheManager quantizationStateCacheManager = QuantizationStateCacheManager.getInstance(); - private Map fieldToUniqueCacheId; + private Map quantizationStateCacheKeyPerField; public NativeEngines990KnnVectorsReader(final SegmentReadState state, final FlatVectorsReader flatVectorsReader) throws IOException { this.segmentReadState = state; - populateFieldMapAndQuantizationStateCache(); + primeQuantizationStateCache(); this.flatVectorsReader = flatVectorsReader; } @@ -119,7 +119,7 @@ public ByteVectorValues getByteVectorValues(final String field) throws IOExcepti public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { // TODO: This is a temporary hack where we are using KNNCollector to initialize the quantization state. if (knnCollector instanceof QuantizationConfigKNNCollector) { - String cacheKey = fieldToUniqueCacheId.get(field); + String cacheKey = quantizationStateCacheKeyPerField.get(field); FieldInfo fieldInfo = segmentReadState.fieldInfos.fieldInfo(field); QuantizationState quantizationState = QuantizationStateCacheManager.getInstance() .getQuantizationState( @@ -182,7 +182,7 @@ public void search(String field, byte[] target, KnnCollector knnCollector, Bits @Override public void close() throws IOException { IOUtils.close(flatVectorsReader); - for (String cacheKey : fieldToUniqueCacheId.values()) { + for (String cacheKey : quantizationStateCacheKeyPerField.values()) { QuantizationStateCacheManager.getInstance().evict(cacheKey); } } @@ -195,9 +195,9 @@ public long ramBytesUsed() { return flatVectorsReader.ramBytesUsed(); } - private void populateFieldMapAndQuantizationStateCache() throws IOException { - fieldToUniqueCacheId = new HashMap<>(); - Map stateMap = KNNQuantizationStateReader.read(segmentReadState); + private void primeQuantizationStateCache() throws IOException { + quantizationStateCacheKeyPerField = new HashMap<>(); + Map stateMap = KNN990QuantizationStateReader.read(segmentReadState); for (Map.Entry entry : stateMap.entrySet()) { FieldInfo fieldInfo = segmentReadState.fieldInfos.fieldInfo(entry.getKey()); QuantizationParams quantizationParams = QuantizationService.getInstance().getQuantizationParams(fieldInfo); @@ -216,7 +216,7 @@ private void populateFieldMapAndQuantizationStateCache() throws IOException { throw new IllegalArgumentException("Unknown Scalar Quantization Type"); } String cacheKey = UUIDs.base64UUID(); - fieldToUniqueCacheId.put(entry.getKey(), cacheKey); + quantizationStateCacheKeyPerField.put(entry.getKey(), cacheKey); quantizationStateCacheManager.addQuantizationState(cacheKey, quantizationState); } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index dd16de5654..664cd7f007 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -51,7 +51,7 @@ public class NativeEngines990KnnVectorsWriter extends KnnVectorsWriter { private final SegmentWriteState segmentWriteState; private final FlatVectorsWriter flatVectorsWriter; - private final KNNQuantizationStateWriter quantizationStateWriter; + private final KNN990QuantizationStateWriter quantizationStateWriter; private final List> fields = new ArrayList<>(); private boolean finished; private final QuantizationService quantizationService = QuantizationService.getInstance(); @@ -59,7 +59,7 @@ public class NativeEngines990KnnVectorsWriter extends KnnVectorsWriter { public NativeEngines990KnnVectorsWriter(SegmentWriteState segmentWriteState, FlatVectorsWriter flatVectorsWriter) throws IOException { this.segmentWriteState = segmentWriteState; this.flatVectorsWriter = flatVectorsWriter; - this.quantizationStateWriter = new KNNQuantizationStateWriter(segmentWriteState); + this.quantizationStateWriter = new KNN990QuantizationStateWriter(segmentWriteState); quantizationStateWriter.writeHeader(segmentWriteState); } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 8c24df569a..1769328fe6 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -261,22 +261,7 @@ private Map doANNSearch( QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo); // TODO: Change type of vector once more quantization methods are supported - byte[] quantizedVector = null; - - if (quantizationParams != null) { - QuantizationConfigKNNCollector tempCollector = new QuantizationConfigKNNCollector(); - reader.searchNearestVectors(knnQuery.getField(), new float[0], tempCollector, null); - if (tempCollector.getQuantizationState() == null) { - throw new IllegalStateException(String.format("No quantization state found for field %s", fieldInfo.getName())); - } - QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(quantizationParams); - // TODO: In the future, byte array will not be the only output type from this method - quantizedVector = (byte[]) quantizationService.quantize( - tempCollector.getQuantizationState(), - knnQuery.getQueryVector(), - quantizationOutput - ); - } + byte[] quantizedVector = getQuantizedVector(quantizationParams, reader, fieldInfo); List engineFiles = getEngineFiles(reader, knnEngine.getExtension()); if (engineFiles.isEmpty()) { @@ -477,4 +462,23 @@ private boolean isExactSearchThresholdSettingSet(int filterThresholdValue) { private boolean canDoExactSearchAfterANNSearch(final int filterIdsCount, final int annResultCount) { return filterWeight != null && filterIdsCount >= knnQuery.getK() && knnQuery.getK() > annResultCount; } + + // TODO: this will eventually return more types than just byte + private byte[] getQuantizedVector(QuantizationParams quantizationParams, SegmentReader reader, FieldInfo fieldInfo) throws IOException { + if (quantizationParams != null) { + QuantizationConfigKNNCollector tempCollector = new QuantizationConfigKNNCollector(); + reader.searchNearestVectors(knnQuery.getField(), new float[0], tempCollector, null); + if (tempCollector.getQuantizationState() == null) { + throw new IllegalStateException(String.format("No quantization state found for field %s", fieldInfo.getName())); + } + QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(quantizationParams); + // TODO: In the future, byte array will not be the only output type from this method + return (byte[]) quantizationService.quantize( + tempCollector.getQuantizationState(), + knnQuery.getQueryVector(), + quantizationOutput + ); + } + return null; + } } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java index 5873153134..21fde612fb 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java @@ -7,7 +7,7 @@ import lombok.AccessLevel; import lombok.NoArgsConstructor; -import org.opensearch.knn.index.codec.KNN990Codec.KNNQuantizationStateReader; +import org.opensearch.knn.index.codec.KNN990Codec.KNN990QuantizationStateReader; import java.io.IOException; @@ -44,7 +44,7 @@ public QuantizationState getQuantizationState(QuantizationStateReadConfig quanti QuantizationState quantizationState = QuantizationStateCache.getInstance() .getQuantizationState(quantizationStateReadConfig.getCacheKey()); if (quantizationState == null) { - quantizationState = KNNQuantizationStateReader.read(quantizationStateReadConfig); + quantizationState = KNN990QuantizationStateReader.read(quantizationStateReadConfig); addQuantizationState(quantizationStateReadConfig.getCacheKey(), quantizationState); } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReaderTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReaderTests.java similarity index 84% rename from src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReaderTests.java rename to src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReaderTests.java index 9a1426859d..0f73337144 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateReaderTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReaderTests.java @@ -37,7 +37,7 @@ import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.times; -public class KNNQuantizationStateReaderTests extends KNNTestCase { +public class KNN990QuantizationStateReaderTests extends KNNTestCase { @SneakyThrows public void testReadFromSegmentReadState() { @@ -77,11 +77,11 @@ public void testReadFromSegmentReadState() { segmentSuffix ); - try (MockedStatic mockedStaticReader = Mockito.mockStatic(KNNQuantizationStateReader.class)) { - mockedStaticReader.when(() -> KNNQuantizationStateReader.getNumFields(input)).thenReturn(2); - mockedStaticReader.when(() -> KNNQuantizationStateReader.read(segmentReadState)).thenCallRealMethod(); + try (MockedStatic mockedStaticReader = Mockito.mockStatic(KNN990QuantizationStateReader.class)) { + mockedStaticReader.when(() -> KNN990QuantizationStateReader.getNumFields(input)).thenReturn(2); + mockedStaticReader.when(() -> KNN990QuantizationStateReader.read(segmentReadState)).thenCallRealMethod(); try (MockedStatic mockedStaticCodecUtil = mockStatic(CodecUtil.class)) { - KNNQuantizationStateReader.read(segmentReadState); + KNN990QuantizationStateReader.read(segmentReadState); mockedStaticCodecUtil.verify(() -> CodecUtil.retrieveChecksum(input)); Mockito.verify(input, times(4)).readInt(); @@ -135,13 +135,13 @@ public void testReadFromQuantizationStateReadConfig() { Mockito.when(quantizationStateReadConfig.getSegmentReadState()).thenReturn(segmentReadState); Mockito.when(quantizationStateReadConfig.getField()).thenReturn(fieldName); - try (MockedStatic mockedStaticReader = Mockito.mockStatic(KNNQuantizationStateReader.class)) { - mockedStaticReader.when(() -> KNNQuantizationStateReader.getNumFields(input)).thenReturn(2); - mockedStaticReader.when(() -> KNNQuantizationStateReader.read(quantizationStateReadConfig)).thenCallRealMethod(); - mockedStaticReader.when(() -> KNNQuantizationStateReader.readStateBytes(any(IndexInput.class), anyLong(), anyInt())) + try (MockedStatic mockedStaticReader = Mockito.mockStatic(KNN990QuantizationStateReader.class)) { + mockedStaticReader.when(() -> KNN990QuantizationStateReader.getNumFields(input)).thenReturn(2); + mockedStaticReader.when(() -> KNN990QuantizationStateReader.read(quantizationStateReadConfig)).thenCallRealMethod(); + mockedStaticReader.when(() -> KNN990QuantizationStateReader.readStateBytes(any(IndexInput.class), anyLong(), anyInt())) .thenReturn(new byte[8]); try (MockedStatic mockedStaticCodecUtil = mockStatic(CodecUtil.class)) { - assertThrows(IllegalArgumentException.class, () -> KNNQuantizationStateReader.read(quantizationStateReadConfig)); + assertThrows(IllegalArgumentException.class, () -> KNN990QuantizationStateReader.read(quantizationStateReadConfig)); mockedStaticCodecUtil.verify(() -> CodecUtil.retrieveChecksum(input)); Mockito.verify(input, times(4)).readInt(); @@ -154,7 +154,7 @@ public void testReadFromQuantizationStateReadConfig() { OneBitScalarQuantizationState oneBitScalarQuantizationState = Mockito.mock(OneBitScalarQuantizationState.class); mockedStaticOneBit.when(() -> OneBitScalarQuantizationState.fromByteArray(any(byte[].class))) .thenReturn(oneBitScalarQuantizationState); - QuantizationState quantizationState = KNNQuantizationStateReader.read(quantizationStateReadConfig); + QuantizationState quantizationState = KNN990QuantizationStateReader.read(quantizationStateReadConfig); assertTrue(quantizationState instanceof OneBitScalarQuantizationState); } @@ -165,12 +165,12 @@ public void testReadFromQuantizationStateReadConfig() { Mockito.when(quantizationStateReadConfig.getQuantizationParams()).thenReturn(scalarQuantizationParams2); Mockito.when(quantizationStateReadConfig.getQuantizationParams()).thenReturn(scalarQuantizationParams2); - QuantizationState quantizationState = KNNQuantizationStateReader.read(quantizationStateReadConfig); + QuantizationState quantizationState = KNN990QuantizationStateReader.read(quantizationStateReadConfig); assertTrue(quantizationState instanceof MultiBitScalarQuantizationState); Mockito.when(quantizationStateReadConfig.getQuantizationParams()).thenReturn(scalarQuantizationParams4); Mockito.when(quantizationStateReadConfig.getQuantizationParams()).thenReturn(scalarQuantizationParams4); - quantizationState = KNNQuantizationStateReader.read(quantizationStateReadConfig); + quantizationState = KNN990QuantizationStateReader.read(quantizationStateReadConfig); assertTrue(quantizationState instanceof MultiBitScalarQuantizationState); } } @@ -180,9 +180,9 @@ public void testReadFromQuantizationStateReadConfig() { @SneakyThrows public void testGetNumFields() { IndexInput input = Mockito.mock(IndexInput.class); - KNNQuantizationStateReader.getNumFields(input); + KNN990QuantizationStateReader.getNumFields(input); - Mockito.verify(input, times(2)).readInt(); + Mockito.verify(input, times(1)).readInt(); Mockito.verify(input, times(1)).readLong(); Mockito.verify(input, times(2)).seek(anyLong()); Mockito.verify(input, times(1)).length(); @@ -194,7 +194,7 @@ public void testReadStateBytes() { long position = 1; int length = 2; byte[] stateBytes = new byte[length]; - KNNQuantizationStateReader.readStateBytes(input, position, length); + KNN990QuantizationStateReader.readStateBytes(input, position, length); Mockito.verify(input, times(1)).seek(position); Mockito.verify(input, times(1)).readBytes(stateBytes, 0, length); @@ -230,7 +230,7 @@ public void testGetQuantizationStateFileName() { segmentSuffix ); - assertEquals(expectedName, KNNQuantizationStateReader.getQuantizationStateFileName(segmentReadState)); + assertEquals(expectedName, KNN990QuantizationStateReader.getQuantizationStateFileName(segmentReadState)); } } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriterTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateWriterTests.java similarity index 93% rename from src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriterTests.java rename to src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateWriterTests.java index 4f018d1add..9664bca392 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNNQuantizationStateWriterTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateWriterTests.java @@ -35,7 +35,7 @@ import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.times; -public class KNNQuantizationStateWriterTests extends KNNTestCase { +public class KNN990QuantizationStateWriterTests extends KNNTestCase { @SneakyThrows public void testWriteHeader() { @@ -68,7 +68,7 @@ public void testWriteHeader() { null, Mockito.mock(IOContext.class) ); - KNNQuantizationStateWriter quantizationStateWriter = new KNNQuantizationStateWriter(segmentWriteState); + KNN990QuantizationStateWriter quantizationStateWriter = new KNN990QuantizationStateWriter(segmentWriteState); try (MockedStatic mockedStaticCodecUtil = Mockito.mockStatic(CodecUtil.class)) { mockedStaticCodecUtil.when( () -> CodecUtil.writeIndexHeader(any(IndexOutput.class), anyString(), anyInt(), any(byte[].class), anyString()) @@ -77,7 +77,7 @@ public void testWriteHeader() { mockedStaticCodecUtil.verify( () -> CodecUtil.writeIndexHeader( output, - "QuantizationCodec", + "NativeEngines99KnnVectorsFormatQSData", 0, segmentWriteState.segmentInfo.getId(), segmentWriteState.segmentSuffix @@ -117,7 +117,7 @@ public void testWriteState() { null, Mockito.mock(IOContext.class) ); - KNNQuantizationStateWriter quantizationStateWriter = new KNNQuantizationStateWriter(segmentWriteState); + KNN990QuantizationStateWriter quantizationStateWriter = new KNN990QuantizationStateWriter(segmentWriteState); int fieldNumber = 0; QuantizationState quantizationState = new OneBitScalarQuantizationState( @@ -160,7 +160,7 @@ public void testWriteFooter() { null, Mockito.mock(IOContext.class) ); - KNNQuantizationStateWriter quantizationStateWriter = new KNNQuantizationStateWriter(segmentWriteState); + KNN990QuantizationStateWriter quantizationStateWriter = new KNN990QuantizationStateWriter(segmentWriteState); int fieldNumber1 = 1; int fieldNumber2 = 2; diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java index b44c214063..8e524f3596 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java @@ -63,7 +63,6 @@ import org.opensearch.knn.index.engine.qframe.QuantizationConfigParser; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.engine.KNNEngine; -import org.opensearch.knn.plugin.stats.KNNGraphValue; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import java.io.IOException; @@ -204,8 +203,6 @@ public void testNativeEngineVectorFormat_whenMultipleVectorFieldIndexed_thenSucc indexWriter.commit(); indexWriter.close(); - assertNotEquals(0L, (long) KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue()); - // Validate to see if correct values are returned, assumption here is only 1 segment is getting created IndexSearcher searcher = new IndexSearcher(indexReader); final LeafReader leafReader = searcher.getLeafContexts().get(0).reader(); @@ -275,7 +272,6 @@ public void testNativeEngineVectorFormat_whenBinaryQuantizationApplied_thenSucce indexWriter.flush(); indexWriter.commit(); indexWriter.close(); - assertNotEquals(0L, (long) KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue()); IndexSearcher searcher = new IndexSearcher(indexReader); final LeafReader leafReader = searcher.getLeafContexts().get(0).reader(); diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index 5e5c5fa11f..a2b41804a5 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -1419,7 +1419,7 @@ public void testANNWithQuantizationParams_whenStateNotFound_thenFail() { } @SneakyThrows - public void testANN() { + public void testANNWithQuantizationParams_thenSuccess() { try (MockedStatic quantizationServiceMockedStatic = Mockito.mockStatic(QuantizationService.class)) { QuantizationService quantizationService = Mockito.mock(QuantizationService.class); ScalarQuantizationParams quantizationParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); diff --git a/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManagerTests.java b/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManagerTests.java index d1d3c329ad..14e55e627d 100644 --- a/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManagerTests.java +++ b/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManagerTests.java @@ -9,7 +9,7 @@ import org.mockito.MockedStatic; import org.mockito.Mockito; import org.opensearch.knn.KNNTestCase; -import org.opensearch.knn.index.codec.KNN990Codec.KNNQuantizationStateReader; +import org.opensearch.knn.index.codec.KNN990Codec.KNN990QuantizationStateReader; import static org.mockito.Mockito.times; @@ -36,8 +36,9 @@ public void testGetQuantizationState() { QuantizationStateCache quantizationStateCache = Mockito.mock(QuantizationStateCache.class); mockedStaticCache.when(QuantizationStateCache::getInstance).thenReturn(quantizationStateCache); Mockito.doNothing().when(quantizationStateCache).addQuantizationState(cacheKey, quantizationState); - try (MockedStatic mockedStaticReader = Mockito.mockStatic(KNNQuantizationStateReader.class)) { - mockedStaticReader.when(() -> KNNQuantizationStateReader.read(quantizationStateReadConfig)).thenReturn(quantizationState); + try (MockedStatic mockedStaticReader = Mockito.mockStatic(KNN990QuantizationStateReader.class)) { + mockedStaticReader.when(() -> KNN990QuantizationStateReader.read(quantizationStateReadConfig)) + .thenReturn(quantizationState); QuantizationStateCacheManager.getInstance().getQuantizationState(quantizationStateReadConfig); Mockito.verify(quantizationStateCache, times(1)).addQuantizationState(cacheKey, quantizationState); } From d8959d026afb2722928f3c5e77df0fe92467ae3e Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Wed, 4 Sep 2024 08:48:08 -0700 Subject: [PATCH 32/41] Address feedbackK Signed-off-by: Ryan Bogan --- .../KNN990Codec/KNN990QuantizationStateReaderTests.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReaderTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReaderTests.java index 0f73337144..2801560622 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReaderTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReaderTests.java @@ -155,7 +155,7 @@ public void testReadFromQuantizationStateReadConfig() { mockedStaticOneBit.when(() -> OneBitScalarQuantizationState.fromByteArray(any(byte[].class))) .thenReturn(oneBitScalarQuantizationState); QuantizationState quantizationState = KNN990QuantizationStateReader.read(quantizationStateReadConfig); - assertTrue(quantizationState instanceof OneBitScalarQuantizationState); + assertEquals(oneBitScalarQuantizationState, quantizationState); } try (MockedStatic mockedStaticOneBit = mockStatic(MultiBitScalarQuantizationState.class)) { @@ -166,12 +166,12 @@ public void testReadFromQuantizationStateReadConfig() { Mockito.when(quantizationStateReadConfig.getQuantizationParams()).thenReturn(scalarQuantizationParams2); Mockito.when(quantizationStateReadConfig.getQuantizationParams()).thenReturn(scalarQuantizationParams2); QuantizationState quantizationState = KNN990QuantizationStateReader.read(quantizationStateReadConfig); - assertTrue(quantizationState instanceof MultiBitScalarQuantizationState); + assertEquals(multiBitScalarQuantizationState, quantizationState); Mockito.when(quantizationStateReadConfig.getQuantizationParams()).thenReturn(scalarQuantizationParams4); Mockito.when(quantizationStateReadConfig.getQuantizationParams()).thenReturn(scalarQuantizationParams4); quantizationState = KNN990QuantizationStateReader.read(quantizationStateReadConfig); - assertTrue(quantizationState instanceof MultiBitScalarQuantizationState); + assertEquals(multiBitScalarQuantizationState, quantizationState); } } } From ba931da8fdea4a52ce7b1344021f3d83751bdb4f Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Wed, 4 Sep 2024 09:22:47 -0700 Subject: [PATCH 33/41] Fix bwc tests Signed-off-by: Ryan Bogan --- .../codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java index 626210f250..6582cdd1f0 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java @@ -29,7 +29,7 @@ public class NativeEngines990KnnVectorsFormat extends KnnVectorsFormat { /** The format for storing, reading, merging vectors on disk */ private static FlatVectorsFormat flatVectorsFormat; - private static final String FORMAT_NAME = "NativeEngines990KnnVectorsFormat"; + private static final String FORMAT_NAME = "NativeEngines99KnnVectorsFormat"; public NativeEngines990KnnVectorsFormat() { super(FORMAT_NAME); From fdfc301b0408008044b0b8d668e1c1b5750b7b08 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Wed, 4 Sep 2024 09:38:34 -0700 Subject: [PATCH 34/41] Revert previous change Signed-off-by: Ryan Bogan --- .../codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java index 6582cdd1f0..626210f250 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java @@ -29,7 +29,7 @@ public class NativeEngines990KnnVectorsFormat extends KnnVectorsFormat { /** The format for storing, reading, merging vectors on disk */ private static FlatVectorsFormat flatVectorsFormat; - private static final String FORMAT_NAME = "NativeEngines99KnnVectorsFormat"; + private static final String FORMAT_NAME = "NativeEngines990KnnVectorsFormat"; public NativeEngines990KnnVectorsFormat() { super(FORMAT_NAME); From 0e37d2da8c7a1c2c444203d99bd2707f5e9f7f86 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Wed, 4 Sep 2024 10:26:09 -0700 Subject: [PATCH 35/41] Condense into one loop while reading Signed-off-by: Ryan Bogan --- .../KNN990QuantizationStateReader.java | 19 ++++--------------- 1 file changed, 4 insertions(+), 15 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java index 11a34f02b2..68ef727f92 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java @@ -20,9 +20,7 @@ import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateReadConfig; import java.io.IOException; -import java.util.ArrayList; import java.util.HashMap; -import java.util.List; import java.util.Map; /** @@ -58,22 +56,13 @@ public static Map read(SegmentReadState state) throws IOExceptio int numFields = getNumFields(input); - List fieldNumbers = new ArrayList<>(); - List positions = new ArrayList<>(); - List lengths = new ArrayList<>(); - - // Read each field's metadata from the index section + // Read each field's metadata from the index section and then read bytes for (int i = 0; i < numFields; i++) { - fieldNumbers.add(input.readInt()); + int fieldNumber = input.readInt(); int length = input.readInt(); - lengths.add(length); long position = input.readVLong(); - positions.add(position); - } - // Read each field's bytes - for (int i = 0; i < numFields; i++) { - byte[] stateBytes = readStateBytes(input, positions.get(i), lengths.get(i)); - String fieldName = state.fieldInfos.fieldInfo(fieldNumbers.get(i)).getName(); + byte[] stateBytes = readStateBytes(input, position, length); + String fieldName = state.fieldInfos.fieldInfo(fieldNumber).getName(); readQuantizationStateInfos.put(fieldName, stateBytes); } } From eaff8f068801b7ad4139825a249909f8a0e81c2c Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Wed, 4 Sep 2024 12:57:28 -0700 Subject: [PATCH 36/41] Address PR Feedback Signed-off-by: Ryan Bogan --- .../KNN990QuantizationStateWriter.java | 3 ++- .../NativeEngines990KnnVectorsWriter.java | 19 +++++++++++++++---- .../QuantizationConfigKNNCollector.java | 18 ++++++++++-------- .../KNN990QuantizationStateWriterTests.java | 2 +- 4 files changed, 28 insertions(+), 14 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateWriter.java index a0cd16e4f7..49b1819c10 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateWriter.java @@ -25,6 +25,7 @@ public final class KNN990QuantizationStateWriter { private final IndexOutput output; private List fieldQuantizationStates = new ArrayList<>(); + static final String NATIVE_ENGINES_990_KNN_VECTORS_FORMAT_QS_DATA = "NativeEngines990KnnVectorsFormatQSData"; /** * Constructor @@ -63,7 +64,7 @@ public KNN990QuantizationStateWriter(SegmentWriteState segmentWriteState) throws public void writeHeader(SegmentWriteState segmentWriteState) throws IOException { CodecUtil.writeIndexHeader( output, - "NativeEngines99KnnVectorsFormatQSData", + NATIVE_ENGINES_990_KNN_VECTORS_FORMAT_QS_DATA, 0, segmentWriteState.segmentInfo.getId(), segmentWriteState.segmentSuffix diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index 664cd7f007..5b4d903361 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -51,16 +51,15 @@ public class NativeEngines990KnnVectorsWriter extends KnnVectorsWriter { private final SegmentWriteState segmentWriteState; private final FlatVectorsWriter flatVectorsWriter; - private final KNN990QuantizationStateWriter quantizationStateWriter; + private KNN990QuantizationStateWriter quantizationStateWriter; private final List> fields = new ArrayList<>(); private boolean finished; private final QuantizationService quantizationService = QuantizationService.getInstance(); - public NativeEngines990KnnVectorsWriter(SegmentWriteState segmentWriteState, FlatVectorsWriter flatVectorsWriter) throws IOException { + public NativeEngines990KnnVectorsWriter(SegmentWriteState segmentWriteState, FlatVectorsWriter flatVectorsWriter) { this.segmentWriteState = segmentWriteState; this.flatVectorsWriter = flatVectorsWriter; - this.quantizationStateWriter = new KNN990QuantizationStateWriter(segmentWriteState); - quantizationStateWriter.writeHeader(segmentWriteState); + this.quantizationStateWriter = null; } /** @@ -86,6 +85,8 @@ public KnnFieldVectorsWriter addField(final FieldInfo fieldInfo) throws IOExc public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { flatVectorsWriter.flush(maxDoc, sortMap); + initQuantizationStateWriterIfNecessary(); + for (final NativeEngineFieldVectorsWriter field : fields) { trainAndIndex( field.getFieldInfo(), @@ -102,6 +103,9 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState) throws IOException { // This will ensure that we are merging the FlatIndex during force merge. flatVectorsWriter.mergeOneField(fieldInfo, mergeState); + + initQuantizationStateWriterIfNecessary(); + // For merge, pick values from flat vector and reindex again. This will use the flush operation to create graphs trainAndIndex( fieldInfo, @@ -262,4 +266,11 @@ private void trainAndIndex( graphBuildTime.incrementBy(time_in_millis); log.warn("Graph build took " + time_in_millis + " ms for " + operationName); } + + private void initQuantizationStateWriterIfNecessary() throws IOException { + if (quantizationStateWriter == null) { + quantizationStateWriter = new KNN990QuantizationStateWriter(segmentWriteState); + quantizationStateWriter.writeHeader(segmentWriteState); + } + } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/QuantizationConfigKNNCollector.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/QuantizationConfigKNNCollector.java index 315206e33e..295b0fe585 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/QuantizationConfigKNNCollector.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/QuantizationConfigKNNCollector.java @@ -20,43 +20,45 @@ public class QuantizationConfigKNNCollector implements KnnCollector { private QuantizationState quantizationState; + private final String NATIVE_ENGINE_SEARCH_ERROR_MESSAGE = "Search functionality using codec is not supported with Native Engine Reader"; + @Override public boolean earlyTerminated() { - throw new UnsupportedOperationException("Search functionality using codec is not supported with Native Engine Reader"); + throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE); } @Override public void incVisitedCount(int i) { - throw new UnsupportedOperationException("Search functionality using codec is not supported with Native Engine Reader"); + throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE); } @Override public long visitedCount() { - throw new UnsupportedOperationException("Search functionality using codec is not supported with Native Engine Reader"); + throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE); } @Override public long visitLimit() { - throw new UnsupportedOperationException("Search functionality using codec is not supported with Native Engine Reader"); + throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE); } @Override public int k() { - throw new UnsupportedOperationException("Search functionality using codec is not supported with Native Engine Reader"); + throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE); } @Override public boolean collect(int i, float v) { - throw new UnsupportedOperationException("Search functionality using codec is not supported with Native Engine Reader"); + throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE); } @Override public float minCompetitiveSimilarity() { - throw new UnsupportedOperationException("Search functionality using codec is not supported with Native Engine Reader"); + throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE); } @Override public TopDocs topDocs() { - throw new UnsupportedOperationException("Search functionality using codec is not supported with Native Engine Reader"); + throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE); } } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateWriterTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateWriterTests.java index 9664bca392..2423a68277 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateWriterTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateWriterTests.java @@ -77,7 +77,7 @@ public void testWriteHeader() { mockedStaticCodecUtil.verify( () -> CodecUtil.writeIndexHeader( output, - "NativeEngines99KnnVectorsFormatQSData", + KNN990QuantizationStateWriter.NATIVE_ENGINES_990_KNN_VECTORS_FORMAT_QS_DATA, 0, segmentWriteState.segmentInfo.getId(), segmentWriteState.segmentSuffix From 414b2e4e897f0e2b2aaa1ffc93a909afcea49d4f Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Wed, 4 Sep 2024 14:01:55 -0700 Subject: [PATCH 37/41] Address PR Feedback Signed-off-by: Ryan Bogan --- .../KNN990Codec/KNN990QuantizationStateReader.java | 2 ++ .../NativeEngines990KnnVectorsWriter.java | 14 +++++++------- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java index 68ef727f92..9cbcd890bb 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java @@ -65,6 +65,8 @@ public static Map read(SegmentReadState state) throws IOExceptio String fieldName = state.fieldInfos.fieldInfo(fieldNumber).getName(); readQuantizationStateInfos.put(fieldName, stateBytes); } + } catch (Exception e) { + return readQuantizationStateInfos; } return readQuantizationStateInfos; } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index 5b4d903361..af7f1c5765 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -59,7 +59,6 @@ public class NativeEngines990KnnVectorsWriter extends KnnVectorsWriter { public NativeEngines990KnnVectorsWriter(SegmentWriteState segmentWriteState, FlatVectorsWriter flatVectorsWriter) { this.segmentWriteState = segmentWriteState; this.flatVectorsWriter = flatVectorsWriter; - this.quantizationStateWriter = null; } /** @@ -85,8 +84,6 @@ public KnnFieldVectorsWriter addField(final FieldInfo fieldInfo) throws IOExc public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { flatVectorsWriter.flush(maxDoc, sortMap); - initQuantizationStateWriterIfNecessary(); - for (final NativeEngineFieldVectorsWriter field : fields) { trainAndIndex( field.getFieldInfo(), @@ -104,8 +101,6 @@ public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState // This will ensure that we are merging the FlatIndex during force merge. flatVectorsWriter.mergeOneField(fieldInfo, mergeState); - initQuantizationStateWriterIfNecessary(); - // For merge, pick values from flat vector and reindex again. This will use the flush operation to create graphs trainAndIndex( fieldInfo, @@ -126,7 +121,9 @@ public void finish() throws IOException { throw new IllegalStateException("NativeEnginesKNNVectorsWriter is already finished"); } finished = true; - quantizationStateWriter.writeFooter(); + if (quantizationStateWriter != null) { + quantizationStateWriter.writeFooter(); + } flatVectorsWriter.finish(); } @@ -145,7 +142,9 @@ public void finish() throws IOException { */ @Override public void close() throws IOException { - quantizationStateWriter.closeOutput(); + if (quantizationStateWriter != null) { + quantizationStateWriter.closeOutput(); + } IOUtils.close(flatVectorsWriter); } @@ -250,6 +249,7 @@ private void trainAndIndex( QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo); QuantizationState quantizationState = null; if (quantizationParams != null) { + initQuantizationStateWriterIfNecessary(); quantizationState = quantizationService.train(quantizationParams, knnVectorValues); quantizationStateWriter.writeState(fieldInfo.getFieldNumber(), quantizationState); } From 2accfb187d7cbcf8f835a21af57c7e2218f9b181 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Wed, 4 Sep 2024 14:53:29 -0700 Subject: [PATCH 38/41] Address feedback Signed-off-by: Ryan Bogan --- .../KNN990Codec/KNN990QuantizationStateReader.java | 12 +++++++++--- .../NativeEngines990KnnVectorsReader.java | 3 +++ .../QuantizationStateCacheManager.java | 4 +++- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java index 9cbcd890bb..104b4fc3aa 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java @@ -19,7 +19,10 @@ import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateReadConfig; +import java.io.FileNotFoundException; import java.io.IOException; +import java.nio.file.NoSuchFileException; +import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -49,13 +52,14 @@ public final class KNN990QuantizationStateReader { */ public static Map read(SegmentReadState state) throws IOException { String quantizationStateFileName = getQuantizationStateFileName(state); - Map readQuantizationStateInfos = new HashMap<>(); try (IndexInput input = state.directory.openInput(quantizationStateFileName, IOContext.READ)) { CodecUtil.retrieveChecksum(input); int numFields = getNumFields(input); + Map readQuantizationStateInfos = new HashMap<>(); + // Read each field's metadata from the index section and then read bytes for (int i = 0; i < numFields; i++) { int fieldNumber = input.readInt(); @@ -65,10 +69,10 @@ public static Map read(SegmentReadState state) throws IOExceptio String fieldName = state.fieldInfos.fieldInfo(fieldNumber).getName(); readQuantizationStateInfos.put(fieldName, stateBytes); } - } catch (Exception e) { return readQuantizationStateInfos; + } catch (FileNotFoundException | NoSuchFileException e) { + return Collections.emptyMap(); } - return readQuantizationStateInfos; } /** @@ -118,6 +122,8 @@ public static QuantizationState read(QuantizationStateReadConfig readConfig) thr default: throw new IllegalArgumentException(String.format("Unexpected scalar quantization type: %s", scalarQuantizationType)); } + } catch (FileNotFoundException | NoSuchFileException e) { + return null; } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java index 06f705c1fa..b22ed58796 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java @@ -130,6 +130,9 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits cacheKey ) ); + if (quantizationState == null) { + return; + } ((QuantizationConfigKNNCollector) knnCollector).setQuantizationState(quantizationState); return; } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java index 21fde612fb..c71d77ae6a 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java @@ -45,7 +45,9 @@ public QuantizationState getQuantizationState(QuantizationStateReadConfig quanti .getQuantizationState(quantizationStateReadConfig.getCacheKey()); if (quantizationState == null) { quantizationState = KNN990QuantizationStateReader.read(quantizationStateReadConfig); - addQuantizationState(quantizationStateReadConfig.getCacheKey(), quantizationState); + if (quantizationState != null) { + addQuantizationState(quantizationStateReadConfig.getCacheKey(), quantizationState); + } } return quantizationState; From a6c87e3cb1a151c43fb433d643229e5c03a6d131 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Wed, 4 Sep 2024 14:55:45 -0700 Subject: [PATCH 39/41] Revert "Address feedback" This reverts commit 2accfb187d7cbcf8f835a21af57c7e2218f9b181. Signed-off-by: Ryan Bogan --- .../KNN990Codec/KNN990QuantizationStateReader.java | 12 +++--------- .../NativeEngines990KnnVectorsReader.java | 3 --- .../QuantizationStateCacheManager.java | 4 +--- 3 files changed, 4 insertions(+), 15 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java index 104b4fc3aa..9cbcd890bb 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java @@ -19,10 +19,7 @@ import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateReadConfig; -import java.io.FileNotFoundException; import java.io.IOException; -import java.nio.file.NoSuchFileException; -import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -52,14 +49,13 @@ public final class KNN990QuantizationStateReader { */ public static Map read(SegmentReadState state) throws IOException { String quantizationStateFileName = getQuantizationStateFileName(state); + Map readQuantizationStateInfos = new HashMap<>(); try (IndexInput input = state.directory.openInput(quantizationStateFileName, IOContext.READ)) { CodecUtil.retrieveChecksum(input); int numFields = getNumFields(input); - Map readQuantizationStateInfos = new HashMap<>(); - // Read each field's metadata from the index section and then read bytes for (int i = 0; i < numFields; i++) { int fieldNumber = input.readInt(); @@ -69,10 +65,10 @@ public static Map read(SegmentReadState state) throws IOExceptio String fieldName = state.fieldInfos.fieldInfo(fieldNumber).getName(); readQuantizationStateInfos.put(fieldName, stateBytes); } + } catch (Exception e) { return readQuantizationStateInfos; - } catch (FileNotFoundException | NoSuchFileException e) { - return Collections.emptyMap(); } + return readQuantizationStateInfos; } /** @@ -122,8 +118,6 @@ public static QuantizationState read(QuantizationStateReadConfig readConfig) thr default: throw new IllegalArgumentException(String.format("Unexpected scalar quantization type: %s", scalarQuantizationType)); } - } catch (FileNotFoundException | NoSuchFileException e) { - return null; } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java index b22ed58796..06f705c1fa 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java @@ -130,9 +130,6 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits cacheKey ) ); - if (quantizationState == null) { - return; - } ((QuantizationConfigKNNCollector) knnCollector).setQuantizationState(quantizationState); return; } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java index c71d77ae6a..21fde612fb 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java @@ -45,9 +45,7 @@ public QuantizationState getQuantizationState(QuantizationStateReadConfig quanti .getQuantizationState(quantizationStateReadConfig.getCacheKey()); if (quantizationState == null) { quantizationState = KNN990QuantizationStateReader.read(quantizationStateReadConfig); - if (quantizationState != null) { - addQuantizationState(quantizationStateReadConfig.getCacheKey(), quantizationState); - } + addQuantizationState(quantizationStateReadConfig.getCacheKey(), quantizationState); } return quantizationState; From ff07704796eb175c1bee70bf280df064ce664f7c Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Wed, 4 Sep 2024 16:24:51 -0700 Subject: [PATCH 40/41] Address feedback Signed-off-by: Ryan Bogan --- .../KNN990Codec/KNN990QuantizationStateReader.java | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java index 9cbcd890bb..c894769637 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java @@ -6,6 +6,7 @@ package org.opensearch.knn.index.codec.KNN990Codec; import com.google.common.annotations.VisibleForTesting; +import lombok.extern.log4j.Log4j2; import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.SegmentReadState; @@ -20,12 +21,14 @@ import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateReadConfig; import java.io.IOException; +import java.util.Collections; import java.util.HashMap; import java.util.Map; /** * Reads quantization states */ +@Log4j2 public final class KNN990QuantizationStateReader { /** @@ -49,13 +52,15 @@ public final class KNN990QuantizationStateReader { */ public static Map read(SegmentReadState state) throws IOException { String quantizationStateFileName = getQuantizationStateFileName(state); - Map readQuantizationStateInfos = new HashMap<>(); + Map readQuantizationStateInfos = null; try (IndexInput input = state.directory.openInput(quantizationStateFileName, IOContext.READ)) { CodecUtil.retrieveChecksum(input); int numFields = getNumFields(input); + readQuantizationStateInfos = new HashMap<>(); + // Read each field's metadata from the index section and then read bytes for (int i = 0; i < numFields; i++) { int fieldNumber = input.readInt(); @@ -66,7 +71,8 @@ public static Map read(SegmentReadState state) throws IOExceptio readQuantizationStateInfos.put(fieldName, stateBytes); } } catch (Exception e) { - return readQuantizationStateInfos; + log.error(e.getMessage()); + return Collections.emptyMap(); } return readQuantizationStateInfos; } From 34cd60b6e07d6db6315aa8b02a33d5d8d4a9eff4 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Wed, 4 Sep 2024 16:41:35 -0700 Subject: [PATCH 41/41] Address feedback Signed-off-by: Ryan Bogan --- .../codec/KNN990Codec/KNN990QuantizationStateReader.java | 5 ++++- .../quantizationState/QuantizationStateCacheManager.java | 5 +++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java index c894769637..5ae4e7b3b7 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java @@ -71,7 +71,7 @@ public static Map read(SegmentReadState state) throws IOExceptio readQuantizationStateInfos.put(fieldName, stateBytes); } } catch (Exception e) { - log.error(e.getMessage()); + log.warn(String.format("Unable to read the quantization state file for segment %s", state.segmentInfo.name), e); return Collections.emptyMap(); } return readQuantizationStateInfos; @@ -124,6 +124,9 @@ public static QuantizationState read(QuantizationStateReadConfig readConfig) thr default: throw new IllegalArgumentException(String.format("Unexpected scalar quantization type: %s", scalarQuantizationType)); } + } catch (Exception e) { + log.warn(String.format("Unable to read the quantization state file for segment %s", segmentReadState.segmentInfo.name), e); + return null; } } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java index 21fde612fb..932d5cde06 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java @@ -45,9 +45,10 @@ public QuantizationState getQuantizationState(QuantizationStateReadConfig quanti .getQuantizationState(quantizationStateReadConfig.getCacheKey()); if (quantizationState == null) { quantizationState = KNN990QuantizationStateReader.read(quantizationStateReadConfig); - addQuantizationState(quantizationStateReadConfig.getCacheKey(), quantizationState); + if (quantizationState != null) { + addQuantizationState(quantizationStateReadConfig.getCacheKey(), quantizationState); + } } - return quantizationState; }