diff --git a/CHANGELOG.md b/CHANGELOG.md index d18afce95c..70bcb63ce3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,4 +21,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Infrastructure ### Documentation ### Maintenance -### Refactoring \ No newline at end of file +### Refactoring diff --git a/release-notes/opensearch-knn.release-notes-2.17.0.0.md b/release-notes/opensearch-knn.release-notes-2.17.0.0.md index e042bc4e16..8dea9422b6 100644 --- a/release-notes/opensearch-knn.release-notes-2.17.0.0.md +++ b/release-notes/opensearch-knn.release-notes-2.17.0.0.md @@ -30,4 +30,5 @@ Compatible with OpenSearch 2.17.0 * Restructure mappers to better handle null cases and avoid branching in parsing [#1939](https://github.com/opensearch-project/k-NN/pull/1939) * Added Quantization Framework and implemented 1Bit and multibit quantizer[#1889](https://github.com/opensearch-project/k-NN/issues/1889) * Encapsulate dimension, vector data type validation/processing inside Library [#1957](https://github.com/opensearch-project/k-NN/pull/1957) -* Add quantization state cache [#1960](https://github.com/opensearch-project/k-NN/pull/1960) \ No newline at end of file +* Add quantization state cache [#1960](https://github.com/opensearch-project/k-NN/pull/1960) +* Add quantization state reader and writer [#1997](https://github.com/opensearch-project/k-NN/pull/1997) \ No newline at end of file diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index d05079f4cb..441b4ea1cb 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -74,6 +74,7 @@ public class KNNConstants { public static final String MINIMAL_MODE_AND_COMPRESSION_FEATURE = "mode_and_compression_feature"; public static final String RADIAL_SEARCH_KEY = "radial_search"; + public static final String QUANTIZATION_STATE_FILE_SUFFIX = "osknnqstate"; // Lucene specific constants public static final String LUCENE_NAME = "lucene"; diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index a70a17d858..4da11a2adc 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -24,7 +24,7 @@ import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.memory.NativeMemoryCacheManagerDto; import org.opensearch.knn.index.util.IndexHyperParametersUtil; -import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCache; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCacheManager; import org.opensearch.monitor.jvm.JvmInfo; import org.opensearch.monitor.os.OsProbe; @@ -60,6 +60,7 @@ public class KNNSettings { private static final OsProbe osProbe = OsProbe.getInstance(); private static final int INDEX_THREAD_QTY_MAX = 32; + private static final QuantizationStateCacheManager quantizationStateCacheManager = QuantizationStateCacheManager.getInstance(); /** * Settings name @@ -379,11 +380,11 @@ private void setSettingsUpdateConsumers() { NativeMemoryCacheManager.getInstance().rebuildCache(builder.build()); }, Stream.concat(dynamicCacheSettings.values().stream(), FEATURE_FLAGS.values().stream()).collect(Collectors.toUnmodifiableList())); clusterService.getClusterSettings().addSettingsUpdateConsumer(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, it -> { - QuantizationStateCache.getInstance().setMaxCacheSizeInKB(it.getKb()); - QuantizationStateCache.getInstance().rebuildCache(); + quantizationStateCacheManager.setMaxCacheSizeInKB(it.getKb()); + quantizationStateCacheManager.rebuildCache(); }); clusterService.getClusterSettings().addSettingsUpdateConsumer(QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING, it -> { - QuantizationStateCache.getInstance().rebuildCache(); + quantizationStateCacheManager.rebuildCache(); }); } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java new file mode 100644 index 0000000000..5ae4e7b3b7 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java @@ -0,0 +1,155 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN990Codec; + +import com.google.common.annotations.VisibleForTesting; +import lombok.extern.log4j.Log4j2; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateReadConfig; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** + * Reads quantization states + */ +@Log4j2 +public final class KNN990QuantizationStateReader { + + /** + * Read quantization states and return list of fieldNames and bytes + * File format: + * Header + * QS1 state bytes + * QS2 state bytes + * Number of quantization states + * QS1 field number + * QS1 state bytes length + * QS1 position of state bytes + * QS2 field number + * QS2 state bytes length + * QS2 position of state bytes + * Position of index section (where QS1 field name is located) + * -1 (marker) + * Footer + * + * @param state the read state to read from + */ + public static Map read(SegmentReadState state) throws IOException { + String quantizationStateFileName = getQuantizationStateFileName(state); + Map readQuantizationStateInfos = null; + + try (IndexInput input = state.directory.openInput(quantizationStateFileName, IOContext.READ)) { + CodecUtil.retrieveChecksum(input); + + int numFields = getNumFields(input); + + readQuantizationStateInfos = new HashMap<>(); + + // Read each field's metadata from the index section and then read bytes + for (int i = 0; i < numFields; i++) { + int fieldNumber = input.readInt(); + int length = input.readInt(); + long position = input.readVLong(); + byte[] stateBytes = readStateBytes(input, position, length); + String fieldName = state.fieldInfos.fieldInfo(fieldNumber).getName(); + readQuantizationStateInfos.put(fieldName, stateBytes); + } + } catch (Exception e) { + log.warn(String.format("Unable to read the quantization state file for segment %s", state.segmentInfo.name), e); + return Collections.emptyMap(); + } + return readQuantizationStateInfos; + } + + /** + * Reads an individual quantization state for a given field + * @param readConfig a config class that contains necessary information for reading the state + * @return quantization state + */ + public static QuantizationState read(QuantizationStateReadConfig readConfig) throws IOException { + SegmentReadState segmentReadState = readConfig.getSegmentReadState(); + String field = readConfig.getField(); + String quantizationStateFileName = getQuantizationStateFileName(segmentReadState); + int fieldNumber = segmentReadState.fieldInfos.fieldInfo(field).getFieldNumber(); + + try (IndexInput input = segmentReadState.directory.openInput(quantizationStateFileName, IOContext.READ)) { + CodecUtil.retrieveChecksum(input); + int numFields = getNumFields(input); + + long position = -1; + int length = 0; + + // Read each field's metadata from the index section, break when correct field is found + for (int i = 0; i < numFields; i++) { + int tempFieldNumber = input.readInt(); + int tempLength = input.readInt(); + long tempPosition = input.readVLong(); + if (tempFieldNumber == fieldNumber) { + position = tempPosition; + length = tempLength; + break; + } + } + + if (position == -1 || length == 0) { + throw new IllegalArgumentException(String.format("Field %s not found", field)); + } + + byte[] stateBytes = readStateBytes(input, position, length); + + // Deserialize the byte array to a quantization state object + ScalarQuantizationType scalarQuantizationType = ((ScalarQuantizationParams) readConfig.getQuantizationParams()).getSqType(); + switch (scalarQuantizationType) { + case ONE_BIT: + return OneBitScalarQuantizationState.fromByteArray(stateBytes); + case TWO_BIT: + case FOUR_BIT: + return MultiBitScalarQuantizationState.fromByteArray(stateBytes); + default: + throw new IllegalArgumentException(String.format("Unexpected scalar quantization type: %s", scalarQuantizationType)); + } + } catch (Exception e) { + log.warn(String.format("Unable to read the quantization state file for segment %s", segmentReadState.segmentInfo.name), e); + return null; + } + } + + @VisibleForTesting + static int getNumFields(IndexInput input) throws IOException { + long footerStart = input.length() - CodecUtil.footerLength(); + long markerAndIndexPosition = footerStart - Integer.BYTES - Long.BYTES; + input.seek(markerAndIndexPosition); + long indexStartPosition = input.readLong(); + input.seek(indexStartPosition); + return input.readInt(); + } + + @VisibleForTesting + static byte[] readStateBytes(IndexInput input, long position, int length) throws IOException { + input.seek(position); + byte[] stateBytes = new byte[length]; + input.readBytes(stateBytes, 0, length); + return stateBytes; + } + + @VisibleForTesting + static String getQuantizationStateFileName(SegmentReadState state) { + return IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, KNNConstants.QUANTIZATION_STATE_FILE_SUFFIX); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateWriter.java new file mode 100644 index 0000000000..49b1819c10 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateWriter.java @@ -0,0 +1,116 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN990Codec; + +import lombok.AllArgsConstructor; +import lombok.Setter; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.store.IndexOutput; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +/** + * Writes quantization states to off heap memory + */ +public final class KNN990QuantizationStateWriter { + + private final IndexOutput output; + private List fieldQuantizationStates = new ArrayList<>(); + static final String NATIVE_ENGINES_990_KNN_VECTORS_FORMAT_QS_DATA = "NativeEngines990KnnVectorsFormatQSData"; + + /** + * Constructor + * Overall file format for writer: + * Header + * QS1 state bytes + * QS2 state bytes + * Number of quantization states + * QS1 field number + * QS1 state bytes length + * QS1 position of state bytes + * QS2 field number + * QS2 state bytes length + * QS2 position of state bytes + * Position of index section (where QS1 field name is located) + * -1 (marker) + * Footer + * @param segmentWriteState segment write state containing segment information + * @throws IOException exception could be thrown while creating the output + */ + public KNN990QuantizationStateWriter(SegmentWriteState segmentWriteState) throws IOException { + String quantizationStateFileName = IndexFileNames.segmentFileName( + segmentWriteState.segmentInfo.name, + segmentWriteState.segmentSuffix, + KNNConstants.QUANTIZATION_STATE_FILE_SUFFIX + ); + + output = segmentWriteState.directory.createOutput(quantizationStateFileName, segmentWriteState.context); + } + + /** + * Writes an index header + * @param segmentWriteState state containing segment information + * @throws IOException exception could be thrown while writing header + */ + public void writeHeader(SegmentWriteState segmentWriteState) throws IOException { + CodecUtil.writeIndexHeader( + output, + NATIVE_ENGINES_990_KNN_VECTORS_FORMAT_QS_DATA, + 0, + segmentWriteState.segmentInfo.getId(), + segmentWriteState.segmentSuffix + ); + } + + /** + * Writes a quantization state as bytes + * + * @param fieldNumber field number + * @param quantizationState quantization state + * @throws IOException could be thrown while writing + */ + public void writeState(int fieldNumber, QuantizationState quantizationState) throws IOException { + byte[] stateBytes = quantizationState.toByteArray(); + long position = output.getFilePointer(); + output.writeBytes(stateBytes, stateBytes.length); + fieldQuantizationStates.add(new FieldQuantizationState(fieldNumber, stateBytes, position)); + } + + /** + * Writes index footer and other index information for parsing later + * @throws IOException could be thrown while writing + */ + public void writeFooter() throws IOException { + long indexStartPosition = output.getFilePointer(); + output.writeInt(fieldQuantizationStates.size()); + for (FieldQuantizationState fieldQuantizationState : fieldQuantizationStates) { + output.writeInt(fieldQuantizationState.fieldNumber); + output.writeInt(fieldQuantizationState.stateBytes.length); + output.writeVLong(fieldQuantizationState.position); + } + output.writeLong(indexStartPosition); + output.writeInt(-1); + CodecUtil.writeFooter(output); + } + + @AllArgsConstructor + private static class FieldQuantizationState { + final int fieldNumber; + final byte[] stateBytes; + @Setter + Long position; + } + + public void closeOutput() throws IOException { + output.close(); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java index 74b158fa5a..06f705c1fa 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java @@ -23,8 +23,19 @@ import org.apache.lucene.search.TotalHits; import org.apache.lucene.util.Bits; import org.apache.lucene.util.IOUtils; +import org.opensearch.common.UUIDs; +import org.opensearch.knn.index.quantizationservice.QuantizationService; +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCacheManager; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateReadConfig; import java.io.IOException; +import java.util.HashMap; +import java.util.Map; /** * Vectors reader class for reading the flat vectors for native engines. The class provides methods for iterating @@ -33,8 +44,13 @@ public class NativeEngines990KnnVectorsReader extends KnnVectorsReader { private final FlatVectorsReader flatVectorsReader; + private final SegmentReadState segmentReadState; + private final QuantizationStateCacheManager quantizationStateCacheManager = QuantizationStateCacheManager.getInstance(); + private Map quantizationStateCacheKeyPerField; - public NativeEngines990KnnVectorsReader(final SegmentReadState state, final FlatVectorsReader flatVectorsReader) { + public NativeEngines990KnnVectorsReader(final SegmentReadState state, final FlatVectorsReader flatVectorsReader) throws IOException { + this.segmentReadState = state; + primeQuantizationStateCache(); this.flatVectorsReader = flatVectorsReader; } @@ -101,6 +117,22 @@ public ByteVectorValues getByteVectorValues(final String field) throws IOExcepti */ @Override public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + // TODO: This is a temporary hack where we are using KNNCollector to initialize the quantization state. + if (knnCollector instanceof QuantizationConfigKNNCollector) { + String cacheKey = quantizationStateCacheKeyPerField.get(field); + FieldInfo fieldInfo = segmentReadState.fieldInfos.fieldInfo(field); + QuantizationState quantizationState = QuantizationStateCacheManager.getInstance() + .getQuantizationState( + new QuantizationStateReadConfig( + segmentReadState, + QuantizationService.getInstance().getQuantizationParams(fieldInfo), + field, + cacheKey + ) + ); + ((QuantizationConfigKNNCollector) knnCollector).setQuantizationState(quantizationState); + return; + } throw new UnsupportedOperationException("Search functionality using codec is not supported with Native Engine Reader"); } @@ -150,6 +182,9 @@ public void search(String field, byte[] target, KnnCollector knnCollector, Bits @Override public void close() throws IOException { IOUtils.close(flatVectorsReader); + for (String cacheKey : quantizationStateCacheKeyPerField.values()) { + QuantizationStateCacheManager.getInstance().evict(cacheKey); + } } /** @@ -159,4 +194,31 @@ public void close() throws IOException { public long ramBytesUsed() { return flatVectorsReader.ramBytesUsed(); } + + private void primeQuantizationStateCache() throws IOException { + quantizationStateCacheKeyPerField = new HashMap<>(); + Map stateMap = KNN990QuantizationStateReader.read(segmentReadState); + for (Map.Entry entry : stateMap.entrySet()) { + FieldInfo fieldInfo = segmentReadState.fieldInfos.fieldInfo(entry.getKey()); + QuantizationParams quantizationParams = QuantizationService.getInstance().getQuantizationParams(fieldInfo); + if (quantizationParams instanceof ScalarQuantizationParams) { + QuantizationState quantizationState; + ScalarQuantizationParams scalarQuantizationParams = (ScalarQuantizationParams) quantizationParams; + switch (scalarQuantizationParams.getSqType()) { + case ONE_BIT: + quantizationState = OneBitScalarQuantizationState.fromByteArray(entry.getValue()); + break; + case TWO_BIT: + case FOUR_BIT: + quantizationState = MultiBitScalarQuantizationState.fromByteArray(entry.getValue()); + break; + default: + throw new IllegalArgumentException("Unknown Scalar Quantization Type"); + } + String cacheKey = UUIDs.base64UUID(); + quantizationStateCacheKeyPerField.put(entry.getKey(), cacheKey); + quantizationStateCacheManager.addQuantizationState(cacheKey, quantizationState); + } + } + } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index 1d3ff368aa..af7f1c5765 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -11,7 +11,6 @@ package org.opensearch.knn.index.codec.KNN990Codec; -import lombok.RequiredArgsConstructor; import lombok.extern.log4j.Log4j2; import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.codecs.KnnVectorsWriter; @@ -44,7 +43,6 @@ * A KNNVectorsWriter class for writing the vector data strcutures and flat vectors for Native Engines. */ @Log4j2 -@RequiredArgsConstructor public class NativeEngines990KnnVectorsWriter extends KnnVectorsWriter { private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(NativeEngines990KnnVectorsWriter.class); @@ -53,10 +51,16 @@ public class NativeEngines990KnnVectorsWriter extends KnnVectorsWriter { private final SegmentWriteState segmentWriteState; private final FlatVectorsWriter flatVectorsWriter; + private KNN990QuantizationStateWriter quantizationStateWriter; private final List> fields = new ArrayList<>(); private boolean finished; private final QuantizationService quantizationService = QuantizationService.getInstance(); + public NativeEngines990KnnVectorsWriter(SegmentWriteState segmentWriteState, FlatVectorsWriter flatVectorsWriter) { + this.segmentWriteState = segmentWriteState; + this.flatVectorsWriter = flatVectorsWriter; + } + /** * Add new field for indexing. * In Lucene, we use single file for all the vector fields so here we need to see how we are going to make things @@ -79,6 +83,7 @@ public KnnFieldVectorsWriter addField(final FieldInfo fieldInfo) throws IOExc @Override public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { flatVectorsWriter.flush(maxDoc, sortMap); + for (final NativeEngineFieldVectorsWriter field : fields) { trainAndIndex( field.getFieldInfo(), @@ -95,6 +100,7 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState) throws IOException { // This will ensure that we are merging the FlatIndex during force merge. flatVectorsWriter.mergeOneField(fieldInfo, mergeState); + // For merge, pick values from flat vector and reindex again. This will use the flush operation to create graphs trainAndIndex( fieldInfo, @@ -104,7 +110,6 @@ public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS, MERGE_OPERATION ); - } /** @@ -116,6 +121,9 @@ public void finish() throws IOException { throw new IllegalStateException("NativeEnginesKNNVectorsWriter is already finished"); } finished = true; + if (quantizationStateWriter != null) { + quantizationStateWriter.writeFooter(); + } flatVectorsWriter.finish(); } @@ -134,6 +142,9 @@ public void finish() throws IOException { */ @Override public void close() throws IOException { + if (quantizationStateWriter != null) { + quantizationStateWriter.closeOutput(); + } IOUtils.close(flatVectorsWriter); } @@ -238,7 +249,9 @@ private void trainAndIndex( QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo); QuantizationState quantizationState = null; if (quantizationParams != null) { + initQuantizationStateWriterIfNecessary(); quantizationState = quantizationService.train(quantizationParams, knnVectorValues); + quantizationStateWriter.writeState(fieldInfo.getFieldNumber(), quantizationState); } NativeIndexWriter writer = (quantizationParams != null) ? NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState) @@ -253,4 +266,11 @@ private void trainAndIndex( graphBuildTime.incrementBy(time_in_millis); log.warn("Graph build took " + time_in_millis + " ms for " + operationName); } + + private void initQuantizationStateWriterIfNecessary() throws IOException { + if (quantizationStateWriter == null) { + quantizationStateWriter = new KNN990QuantizationStateWriter(segmentWriteState); + quantizationStateWriter.writeHeader(segmentWriteState); + } + } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/QuantizationConfigKNNCollector.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/QuantizationConfigKNNCollector.java new file mode 100644 index 0000000000..295b0fe585 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/QuantizationConfigKNNCollector.java @@ -0,0 +1,64 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN990Codec; + +import lombok.Getter; +import lombok.Setter; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.TopDocs; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; + +/** + * Collector used for passing the quantization state during query flow. + */ +@Setter +@Getter +public class QuantizationConfigKNNCollector implements KnnCollector { + + private QuantizationState quantizationState; + + private final String NATIVE_ENGINE_SEARCH_ERROR_MESSAGE = "Search functionality using codec is not supported with Native Engine Reader"; + + @Override + public boolean earlyTerminated() { + throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE); + } + + @Override + public void incVisitedCount(int i) { + throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE); + } + + @Override + public long visitedCount() { + throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE); + } + + @Override + public long visitLimit() { + throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE); + } + + @Override + public int k() { + throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE); + } + + @Override + public boolean collect(int i, float v) { + throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE); + } + + @Override + public float minCompetitiveSimilarity() { + throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE); + } + + @Override + public TopDocs topDocs() { + throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE); + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 6cc1108390..1769328fe6 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -23,22 +23,24 @@ import org.apache.lucene.util.FixedBitSet; import org.opensearch.common.io.PathUtils; import org.opensearch.common.lucene.Lucene; -import org.opensearch.knn.common.FieldInfoExtractor; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.engine.qframe.QuantizationConfig; +import org.opensearch.knn.index.codec.KNN990Codec.QuantizationConfigKNNCollector; import org.opensearch.knn.index.memory.NativeMemoryAllocation; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.memory.NativeMemoryEntryContext; import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.quantizationservice.QuantizationService; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelUtil; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.plugin.stats.KNNCounter; +import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; import java.io.IOException; import java.nio.file.Path; @@ -72,6 +74,7 @@ public class KNNWeight extends Weight { private final ExactSearcher exactSearcher; private static ExactSearcher DEFAULT_EXACT_SEARCHER; + private final QuantizationService quantizationService; public KNNWeight(KNNQuery query, float boost) { super(query); @@ -80,6 +83,7 @@ public KNNWeight(KNNQuery query, float boost) { this.nativeMemoryCacheManager = NativeMemoryCacheManager.getInstance(); this.filterWeight = null; this.exactSearcher = DEFAULT_EXACT_SEARCHER; + this.quantizationService = QuantizationService.getInstance(); } public KNNWeight(KNNQuery query, float boost, Weight filterWeight) { @@ -89,6 +93,7 @@ public KNNWeight(KNNQuery query, float boost, Weight filterWeight) { this.nativeMemoryCacheManager = NativeMemoryCacheManager.getInstance(); this.filterWeight = filterWeight; this.exactSearcher = DEFAULT_EXACT_SEARCHER; + this.quantizationService = QuantizationService.getInstance(); } public static void initialize(ModelDao modelDao) { @@ -227,9 +232,6 @@ private Map doANNSearch( return null; } - // TODO: Use this to get quantization config - QuantizationConfig quantizationConfig = FieldInfoExtractor.extractQuantizationConfig(fieldInfo); - KNNEngine knnEngine; SpaceType spaceType; VectorDataType vectorDataType; @@ -256,6 +258,11 @@ private Map doANNSearch( ); } + QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo); + + // TODO: Change type of vector once more quantization methods are supported + byte[] quantizedVector = getQuantizedVector(quantizationParams, reader, fieldInfo); + List engineFiles = getEngineFiles(reader, knnEngine.getExtension()); if (engineFiles.isEmpty()) { log.debug("[KNN] No engine index found for field {} for segment {}", knnQuery.getField(), reader.getSegmentName()); @@ -273,7 +280,13 @@ private Map doANNSearch( new NativeMemoryEntryContext.IndexEntryContext( indexPath.toString(), NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(), - getParametersAtLoading(spaceType, knnEngine, knnQuery.getIndexName(), vectorDataType), + getParametersAtLoading( + spaceType, + knnEngine, + knnQuery.getIndexName(), + // TODO: In the future, more vector data types will be supported with quantization + quantizationParams == null ? vectorDataType : VectorDataType.BINARY + ), knnQuery.getIndexName(), modelId ), @@ -296,10 +309,12 @@ private Map doANNSearch( } int[] parentIds = getParentIdsArray(context); if (k > 0) { - if (knnQuery.getVectorDataType() == VectorDataType.BINARY) { + if (knnQuery.getVectorDataType() == VectorDataType.BINARY + || quantizationParams != null && quantizationService.getVectorDataTypeForTransfer(fieldInfo) == VectorDataType.BINARY) { results = JNIService.queryBinaryIndex( indexAllocation.getMemoryAddress(), - knnQuery.getByteQueryVector(), + // TODO: In the future, quantizedVector can have other data types than byte + quantizationParams == null ? knnQuery.getByteQueryVector() : quantizedVector, k, knnQuery.getMethodParameters(), knnEngine, @@ -447,4 +462,23 @@ private boolean isExactSearchThresholdSettingSet(int filterThresholdValue) { private boolean canDoExactSearchAfterANNSearch(final int filterIdsCount, final int annResultCount) { return filterWeight != null && filterIdsCount >= knnQuery.getK() && knnQuery.getK() > annResultCount; } + + // TODO: this will eventually return more types than just byte + private byte[] getQuantizedVector(QuantizationParams quantizationParams, SegmentReader reader, FieldInfo fieldInfo) throws IOException { + if (quantizationParams != null) { + QuantizationConfigKNNCollector tempCollector = new QuantizationConfigKNNCollector(); + reader.searchNearestVectors(knnQuery.getField(), new float[0], tempCollector, null); + if (tempCollector.getQuantizationState() == null) { + throw new IllegalStateException(String.format("No quantization state found for field %s", fieldInfo.getName())); + } + QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(quantizationParams); + // TODO: In the future, byte array will not be the only output type from this method + return (byte[]) quantizationService.quantize( + tempCollector.getQuantizationState(), + knnQuery.getQueryVector(), + quantizationOutput + ); + } + return null; + } } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCache.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCache.java index ba26d517d0..cc5a34bcd8 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCache.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCache.java @@ -11,7 +11,6 @@ import com.google.common.cache.RemovalCause; import com.google.common.cache.RemovalNotification; import lombok.Getter; -import lombok.Setter; import lombok.extern.log4j.Log4j2; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.common.unit.ByteSizeValue; @@ -33,7 +32,6 @@ public class QuantizationStateCache { private static volatile QuantizationStateCache instance; private Cache cache; @Getter - @Setter private long maxCacheSizeInKB; @Getter private Instant evictedDueToSizeAt; @@ -48,7 +46,7 @@ public class QuantizationStateCache { * Gets the singleton instance of the cache. * @return QuantizationStateCache */ - public static QuantizationStateCache getInstance() { + static QuantizationStateCache getInstance() { if (instance == null) { synchronized (QuantizationStateCache.class) { if (instance == null) { @@ -75,7 +73,7 @@ private void buildCache() { .build(); } - public synchronized void rebuildCache() { + synchronized void rebuildCache() { clear(); buildCache(); } @@ -85,7 +83,7 @@ public synchronized void rebuildCache() { * @param fieldName The name of the field. * @return The associated QuantizationState, or null if not present. */ - public QuantizationState getQuantizationState(String fieldName) { + QuantizationState getQuantizationState(String fieldName) { return cache.getIfPresent(fieldName); } @@ -94,7 +92,7 @@ public QuantizationState getQuantizationState(String fieldName) { * @param fieldName The name of the field. * @param quantizationState The quantization state to store. */ - public void addQuantizationState(String fieldName, QuantizationState quantizationState) { + void addQuantizationState(String fieldName, QuantizationState quantizationState) { cache.put(fieldName, quantizationState); } @@ -112,6 +110,10 @@ private void onRemoval(RemovalNotification removalNot } } + void setMaxCacheSizeInKB(long maxCacheSizeInKB) { + this.maxCacheSizeInKB = maxCacheSizeInKB; + } + private void updateEvictedDueToSizeAt() { evictedDueToSizeAt = Instant.now(); } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java new file mode 100644 index 0000000000..932d5cde06 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java @@ -0,0 +1,82 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.quantizationState; + +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import org.opensearch.knn.index.codec.KNN990Codec.KNN990QuantizationStateReader; + +import java.io.IOException; + +@NoArgsConstructor(access = AccessLevel.PRIVATE) +public final class QuantizationStateCacheManager { + + private static volatile QuantizationStateCacheManager instance; + + /** + * Gets the singleton instance of the cache. + * @return QuantizationStateCache + */ + public static QuantizationStateCacheManager getInstance() { + if (instance == null) { + synchronized (QuantizationStateCacheManager.class) { + if (instance == null) { + instance = new QuantizationStateCacheManager(); + } + } + } + return instance; + } + + public synchronized void rebuildCache() { + QuantizationStateCache.getInstance().rebuildCache(); + } + + /** + * Retrieves the quantization state associated with a given field name. Reads from cache first, then from disk if necessary. + * @param quantizationStateReadConfig information required from reading from off-heap if necessary + * @return The associated QuantizationState + */ + public QuantizationState getQuantizationState(QuantizationStateReadConfig quantizationStateReadConfig) throws IOException { + QuantizationState quantizationState = QuantizationStateCache.getInstance() + .getQuantizationState(quantizationStateReadConfig.getCacheKey()); + if (quantizationState == null) { + quantizationState = KNN990QuantizationStateReader.read(quantizationStateReadConfig); + if (quantizationState != null) { + addQuantizationState(quantizationStateReadConfig.getCacheKey(), quantizationState); + } + } + return quantizationState; + } + + /** + * Adds or updates a quantization state in the cache. + * @param fieldName The name of the field. + * @param quantizationState The quantization state to store. + */ + public void addQuantizationState(String fieldName, QuantizationState quantizationState) { + QuantizationStateCache.getInstance().addQuantizationState(fieldName, quantizationState); + } + + /** + * Removes the quantization state associated with a given field name. + * @param fieldName The name of the field. + */ + public void evict(String fieldName) { + QuantizationStateCache.getInstance().evict(fieldName); + } + + public void setMaxCacheSizeInKB(long maxCacheSizeInKB) { + QuantizationStateCache.getInstance().setMaxCacheSizeInKB(maxCacheSizeInKB); + } + + /** + * Clears all entries from the cache. + */ + public void clear() { + QuantizationStateCache.getInstance().clear(); + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateReadConfig.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateReadConfig.java new file mode 100644 index 0000000000..d13e4f3f52 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateReadConfig.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.quantizationState; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import org.apache.lucene.index.SegmentReadState; +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; + +@Getter +@AllArgsConstructor +public class QuantizationStateReadConfig { + private SegmentReadState segmentReadState; + private QuantizationParams quantizationParams; + private String field; + private String cacheKey; +} diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReaderTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReaderTests.java new file mode 100644 index 0000000000..2801560622 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReaderTests.java @@ -0,0 +1,236 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN990Codec; + +import lombok.SneakyThrows; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.SegmentInfo; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.search.Sort; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.Version; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateReadConfig; + +import java.util.Map; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.times; + +public class KNN990QuantizationStateReaderTests extends KNNTestCase { + + @SneakyThrows + public void testReadFromSegmentReadState() { + final String segmentName = "test-segment-name"; + final String segmentSuffix = "test-segment-suffix"; + + final SegmentInfo segmentInfo = new SegmentInfo( + Mockito.mock(Directory.class), + Mockito.mock(Version.class), + Mockito.mock(Version.class), + segmentName, + 0, + false, + false, + Mockito.mock(Codec.class), + Mockito.mock(Map.class), + new byte[16], + Mockito.mock(Map.class), + Mockito.mock(Sort.class) + ); + + Directory directory = Mockito.mock(Directory.class); + IndexInput input = Mockito.mock(IndexInput.class); + Mockito.when(directory.openInput(any(), any())).thenReturn(input); + + String fieldName = "test-field"; + FieldInfos fieldInfos = Mockito.mock(FieldInfos.class); + FieldInfo fieldInfo = Mockito.mock(FieldInfo.class); + Mockito.when(fieldInfo.getName()).thenReturn(fieldName); + Mockito.when(fieldInfos.fieldInfo(anyInt())).thenReturn(fieldInfo); + + final SegmentReadState segmentReadState = new SegmentReadState( + directory, + segmentInfo, + fieldInfos, + Mockito.mock(IOContext.class), + segmentSuffix + ); + + try (MockedStatic mockedStaticReader = Mockito.mockStatic(KNN990QuantizationStateReader.class)) { + mockedStaticReader.when(() -> KNN990QuantizationStateReader.getNumFields(input)).thenReturn(2); + mockedStaticReader.when(() -> KNN990QuantizationStateReader.read(segmentReadState)).thenCallRealMethod(); + try (MockedStatic mockedStaticCodecUtil = mockStatic(CodecUtil.class)) { + KNN990QuantizationStateReader.read(segmentReadState); + + mockedStaticCodecUtil.verify(() -> CodecUtil.retrieveChecksum(input)); + Mockito.verify(input, times(4)).readInt(); + Mockito.verify(input, times(2)).readVLong(); + } + } + } + + @SneakyThrows + public void testReadFromQuantizationStateReadConfig() { + String fieldName = "test-field"; + int fieldNumber = 4; + FieldInfos fieldInfos = Mockito.mock(FieldInfos.class); + FieldInfo fieldInfo = Mockito.mock(FieldInfo.class); + Mockito.when(fieldInfo.getFieldNumber()).thenReturn(fieldNumber); + Mockito.when(fieldInfos.fieldInfo(fieldName)).thenReturn(fieldInfo); + + final String segmentName = "test-segment-name"; + final String segmentSuffix = "test-segment-suffix"; + + final SegmentInfo segmentInfo = new SegmentInfo( + Mockito.mock(Directory.class), + Mockito.mock(Version.class), + Mockito.mock(Version.class), + segmentName, + 0, + false, + false, + Mockito.mock(Codec.class), + Mockito.mock(Map.class), + new byte[16], + Mockito.mock(Map.class), + Mockito.mock(Sort.class) + ); + + Directory directory = Mockito.mock(Directory.class); + IndexInput input = Mockito.mock(IndexInput.class); + Mockito.when(directory.openInput(any(), any())).thenReturn(input); + + final SegmentReadState segmentReadState = new SegmentReadState( + directory, + segmentInfo, + fieldInfos, + Mockito.mock(IOContext.class), + segmentSuffix + ); + ScalarQuantizationParams scalarQuantizationParams1 = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + ScalarQuantizationParams scalarQuantizationParams2 = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); + ScalarQuantizationParams scalarQuantizationParams4 = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); + QuantizationStateReadConfig quantizationStateReadConfig = Mockito.mock(QuantizationStateReadConfig.class); + Mockito.when(quantizationStateReadConfig.getSegmentReadState()).thenReturn(segmentReadState); + Mockito.when(quantizationStateReadConfig.getField()).thenReturn(fieldName); + + try (MockedStatic mockedStaticReader = Mockito.mockStatic(KNN990QuantizationStateReader.class)) { + mockedStaticReader.when(() -> KNN990QuantizationStateReader.getNumFields(input)).thenReturn(2); + mockedStaticReader.when(() -> KNN990QuantizationStateReader.read(quantizationStateReadConfig)).thenCallRealMethod(); + mockedStaticReader.when(() -> KNN990QuantizationStateReader.readStateBytes(any(IndexInput.class), anyLong(), anyInt())) + .thenReturn(new byte[8]); + try (MockedStatic mockedStaticCodecUtil = mockStatic(CodecUtil.class)) { + assertThrows(IllegalArgumentException.class, () -> KNN990QuantizationStateReader.read(quantizationStateReadConfig)); + + mockedStaticCodecUtil.verify(() -> CodecUtil.retrieveChecksum(input)); + Mockito.verify(input, times(4)).readInt(); + Mockito.verify(input, times(2)).readVLong(); + + Mockito.when(input.readInt()).thenReturn(fieldNumber); + + try (MockedStatic mockedStaticOneBit = mockStatic(OneBitScalarQuantizationState.class)) { + Mockito.when(quantizationStateReadConfig.getQuantizationParams()).thenReturn(scalarQuantizationParams1); + OneBitScalarQuantizationState oneBitScalarQuantizationState = Mockito.mock(OneBitScalarQuantizationState.class); + mockedStaticOneBit.when(() -> OneBitScalarQuantizationState.fromByteArray(any(byte[].class))) + .thenReturn(oneBitScalarQuantizationState); + QuantizationState quantizationState = KNN990QuantizationStateReader.read(quantizationStateReadConfig); + assertEquals(oneBitScalarQuantizationState, quantizationState); + } + + try (MockedStatic mockedStaticOneBit = mockStatic(MultiBitScalarQuantizationState.class)) { + MultiBitScalarQuantizationState multiBitScalarQuantizationState = Mockito.mock(MultiBitScalarQuantizationState.class); + mockedStaticOneBit.when(() -> MultiBitScalarQuantizationState.fromByteArray(any(byte[].class))) + .thenReturn(multiBitScalarQuantizationState); + + Mockito.when(quantizationStateReadConfig.getQuantizationParams()).thenReturn(scalarQuantizationParams2); + Mockito.when(quantizationStateReadConfig.getQuantizationParams()).thenReturn(scalarQuantizationParams2); + QuantizationState quantizationState = KNN990QuantizationStateReader.read(quantizationStateReadConfig); + assertEquals(multiBitScalarQuantizationState, quantizationState); + + Mockito.when(quantizationStateReadConfig.getQuantizationParams()).thenReturn(scalarQuantizationParams4); + Mockito.when(quantizationStateReadConfig.getQuantizationParams()).thenReturn(scalarQuantizationParams4); + quantizationState = KNN990QuantizationStateReader.read(quantizationStateReadConfig); + assertEquals(multiBitScalarQuantizationState, quantizationState); + } + } + } + } + + @SneakyThrows + public void testGetNumFields() { + IndexInput input = Mockito.mock(IndexInput.class); + KNN990QuantizationStateReader.getNumFields(input); + + Mockito.verify(input, times(1)).readInt(); + Mockito.verify(input, times(1)).readLong(); + Mockito.verify(input, times(2)).seek(anyLong()); + Mockito.verify(input, times(1)).length(); + } + + @SneakyThrows + public void testReadStateBytes() { + IndexInput input = Mockito.mock(IndexInput.class); + long position = 1; + int length = 2; + byte[] stateBytes = new byte[length]; + KNN990QuantizationStateReader.readStateBytes(input, position, length); + + Mockito.verify(input, times(1)).seek(position); + Mockito.verify(input, times(1)).readBytes(stateBytes, 0, length); + + } + + @SneakyThrows + public void testGetQuantizationStateFileName() { + String segmentName = "test-segment"; + String segmentSuffix = "test-suffix"; + String expectedName = IndexFileNames.segmentFileName(segmentName, segmentSuffix, KNNConstants.QUANTIZATION_STATE_FILE_SUFFIX); + + final SegmentInfo segmentInfo = new SegmentInfo( + Mockito.mock(Directory.class), + Mockito.mock(Version.class), + Mockito.mock(Version.class), + segmentName, + 0, + false, + false, + Mockito.mock(Codec.class), + Mockito.mock(Map.class), + new byte[16], + Mockito.mock(Map.class), + Mockito.mock(Sort.class) + ); + + final SegmentReadState segmentReadState = new SegmentReadState( + Mockito.mock(Directory.class), + segmentInfo, + Mockito.mock(FieldInfos.class), + Mockito.mock(IOContext.class), + segmentSuffix + ); + + assertEquals(expectedName, KNN990QuantizationStateReader.getQuantizationStateFileName(segmentReadState)); + + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateWriterTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateWriterTests.java new file mode 100644 index 0000000000..2423a68277 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateWriterTests.java @@ -0,0 +1,187 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN990Codec; + +import lombok.SneakyThrows; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.SegmentInfo; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.search.Sort; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.InfoStream; +import org.apache.lucene.util.Version; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.mockito.stubbing.Answer; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; + +import java.util.Map; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.times; + +public class KNN990QuantizationStateWriterTests extends KNNTestCase { + + @SneakyThrows + public void testWriteHeader() { + final String segmentName = "test-segment-name"; + + final SegmentInfo segmentInfo = new SegmentInfo( + Mockito.mock(Directory.class), + Mockito.mock(Version.class), + Mockito.mock(Version.class), + segmentName, + 0, + false, + false, + Mockito.mock(Codec.class), + Mockito.mock(Map.class), + new byte[16], + Mockito.mock(Map.class), + Mockito.mock(Sort.class) + ); + + Directory directory = Mockito.mock(Directory.class); + IndexOutput output = Mockito.mock(IndexOutput.class); + Mockito.when(directory.createOutput(any(), any())).thenReturn(output); + + final SegmentWriteState segmentWriteState = new SegmentWriteState( + Mockito.mock(InfoStream.class), + directory, + segmentInfo, + Mockito.mock(FieldInfos.class), + null, + Mockito.mock(IOContext.class) + ); + KNN990QuantizationStateWriter quantizationStateWriter = new KNN990QuantizationStateWriter(segmentWriteState); + try (MockedStatic mockedStaticCodecUtil = Mockito.mockStatic(CodecUtil.class)) { + mockedStaticCodecUtil.when( + () -> CodecUtil.writeIndexHeader(any(IndexOutput.class), anyString(), anyInt(), any(byte[].class), anyString()) + ).thenAnswer((Answer) invocation -> null); + quantizationStateWriter.writeHeader(segmentWriteState); + mockedStaticCodecUtil.verify( + () -> CodecUtil.writeIndexHeader( + output, + KNN990QuantizationStateWriter.NATIVE_ENGINES_990_KNN_VECTORS_FORMAT_QS_DATA, + 0, + segmentWriteState.segmentInfo.getId(), + segmentWriteState.segmentSuffix + ) + ); + } + } + + @SneakyThrows + public void testWriteState() { + final String segmentName = "test-segment-name"; + + final SegmentInfo segmentInfo = new SegmentInfo( + Mockito.mock(Directory.class), + Mockito.mock(Version.class), + Mockito.mock(Version.class), + segmentName, + 0, + false, + false, + Mockito.mock(Codec.class), + Mockito.mock(Map.class), + new byte[16], + Mockito.mock(Map.class), + Mockito.mock(Sort.class) + ); + + Directory directory = Mockito.mock(Directory.class); + IndexOutput output = Mockito.mock(IndexOutput.class); + Mockito.when(directory.createOutput(any(), any())).thenReturn(output); + + final SegmentWriteState segmentWriteState = new SegmentWriteState( + Mockito.mock(InfoStream.class), + directory, + segmentInfo, + Mockito.mock(FieldInfos.class), + null, + Mockito.mock(IOContext.class) + ); + KNN990QuantizationStateWriter quantizationStateWriter = new KNN990QuantizationStateWriter(segmentWriteState); + + int fieldNumber = 0; + QuantizationState quantizationState = new OneBitScalarQuantizationState( + new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), + new float[] { 1.2f, 2.3f, 3.4f, 4.5f } + ); + quantizationStateWriter.writeState(fieldNumber, quantizationState); + byte[] stateBytes = quantizationState.toByteArray(); + Mockito.verify(output, times(1)).writeBytes(stateBytes, stateBytes.length); + } + + @SneakyThrows + public void testWriteFooter() { + final String segmentName = "test-segment-name"; + + final SegmentInfo segmentInfo = new SegmentInfo( + Mockito.mock(Directory.class), + Mockito.mock(Version.class), + Mockito.mock(Version.class), + segmentName, + 0, + false, + false, + Mockito.mock(Codec.class), + Mockito.mock(Map.class), + new byte[16], + Mockito.mock(Map.class), + Mockito.mock(Sort.class) + ); + + Directory directory = Mockito.mock(Directory.class); + IndexOutput output = Mockito.mock(IndexOutput.class); + Mockito.when(directory.createOutput(any(), any())).thenReturn(output); + + final SegmentWriteState segmentWriteState = new SegmentWriteState( + Mockito.mock(InfoStream.class), + directory, + segmentInfo, + Mockito.mock(FieldInfos.class), + null, + Mockito.mock(IOContext.class) + ); + KNN990QuantizationStateWriter quantizationStateWriter = new KNN990QuantizationStateWriter(segmentWriteState); + + int fieldNumber1 = 1; + int fieldNumber2 = 2; + QuantizationState quantizationState1 = new OneBitScalarQuantizationState( + new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), + new float[] { 1.2f, 2.3f, 3.4f, 4.5f } + ); + QuantizationState quantizationState2 = new OneBitScalarQuantizationState( + new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), + new float[] { 2.3f, 3.4f, 4.5f, 5.6f } + ); + quantizationStateWriter.writeState(fieldNumber1, quantizationState1); + quantizationStateWriter.writeState(fieldNumber2, quantizationState2); + + try (MockedStatic mockedStaticCodecUtil = mockStatic(CodecUtil.class)) { + quantizationStateWriter.writeFooter(); + + Mockito.verify(output, times(6)).writeInt(anyInt()); + Mockito.verify(output, times(2)).writeVLong(anyLong()); + Mockito.verify(output, times(1)).writeLong(anyLong()); + mockedStaticCodecUtil.verify(() -> CodecUtil.writeFooter(output)); + } + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java index f4f02ba918..21bd4c1bd1 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java @@ -14,6 +14,7 @@ import lombok.SneakyThrows; import lombok.extern.log4j.Log4j2; import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; @@ -24,25 +25,36 @@ import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexOptions; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.NoMergePolicy; +import org.apache.lucene.index.SegmentInfo; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentReader; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.SerialMergeScheduler; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Sort; import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.store.BaseDirectoryWrapper; import org.apache.lucene.util.Bits; +import org.apache.lucene.util.InfoStream; +import org.apache.lucene.util.Version; import org.junit.After; import org.junit.Assert; +import org.mockito.MockedStatic; import org.mockito.Mockito; +import org.mockito.stubbing.Answer; import org.opensearch.common.lucene.Lucene; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.common.KNNConstants; @@ -52,14 +64,18 @@ import org.opensearch.knn.index.engine.qframe.QuantizationConfigParser; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.engine.KNNEngine; -import org.opensearch.knn.plugin.stats.KNNGraphValue; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import java.io.IOException; import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyString; + @Log4j2 public class NativeEngines990KnnVectorsFormatTests extends KNNTestCase { private static final Codec TESTING_CODEC = new UnitTestCodec(); @@ -82,21 +98,72 @@ public void tearDown() throws Exception { @SneakyThrows public void testReaderAndWriter_whenValidInput_thenSuccess() { final Lucene99FlatVectorsFormat mockedFlatVectorsFormat = Mockito.mock(Lucene99FlatVectorsFormat.class); - final SegmentWriteState mockedSegmentWriteState = Mockito.mock(SegmentWriteState.class); - final SegmentReadState mockedSegmentReadState = Mockito.mock(SegmentReadState.class); + final String segmentName = "test-segment-name"; + + final SegmentInfo mockedSegmentInfo = new SegmentInfo( + Mockito.mock(Directory.class), + Mockito.mock(Version.class), + Mockito.mock(Version.class), + segmentName, + 0, + false, + false, + Mockito.mock(Codec.class), + Mockito.mock(Map.class), + new byte[16], + Mockito.mock(Map.class), + Mockito.mock(Sort.class) + ); + + final String segmentSuffix = "test-segment-suffix"; + + Directory directory = Mockito.mock(Directory.class); + IndexInput input = Mockito.mock(IndexInput.class); + Mockito.when(directory.openInput(any(), any())).thenReturn(input); + + String fieldName = "test-field"; + FieldInfos fieldInfos = Mockito.mock(FieldInfos.class); + FieldInfo fieldInfo = Mockito.mock(FieldInfo.class); + Mockito.when(fieldInfo.getName()).thenReturn(fieldName); + Mockito.when(fieldInfos.fieldInfo(anyInt())).thenReturn(fieldInfo); + + final SegmentReadState mockedSegmentReadState = new SegmentReadState( + directory, + mockedSegmentInfo, + fieldInfos, + Mockito.mock(IOContext.class), + segmentSuffix + ); + + final SegmentWriteState mockedSegmentWriteState = new SegmentWriteState( + Mockito.mock(InfoStream.class), + Mockito.mock(Directory.class), + mockedSegmentInfo, + Mockito.mock(FieldInfos.class), + null, + Mockito.mock(IOContext.class) + ); Mockito.when(mockedFlatVectorsFormat.fieldsReader(mockedSegmentReadState)).thenReturn(Mockito.mock(FlatVectorsReader.class)); Mockito.when(mockedFlatVectorsFormat.fieldsWriter(mockedSegmentWriteState)).thenReturn(Mockito.mock(FlatVectorsWriter.class)); final NativeEngines990KnnVectorsFormat nativeEngines990KnnVectorsFormat = new NativeEngines990KnnVectorsFormat( mockedFlatVectorsFormat ); - Assert.assertTrue( - nativeEngines990KnnVectorsFormat.fieldsReader(mockedSegmentReadState) instanceof NativeEngines990KnnVectorsReader - ); - Assert.assertTrue( - nativeEngines990KnnVectorsFormat.fieldsWriter(mockedSegmentWriteState) instanceof NativeEngines990KnnVectorsWriter - ); + try (MockedStatic mockedStaticCodecUtil = Mockito.mockStatic(CodecUtil.class)) { + mockedStaticCodecUtil.when( + () -> CodecUtil.writeIndexHeader(any(IndexOutput.class), anyString(), anyInt(), any(byte[].class), anyString()) + ).thenAnswer((Answer) invocation -> null); + mockedStaticCodecUtil.when(() -> CodecUtil.retrieveChecksum(any(IndexInput.class))) + .thenAnswer((Answer) invocation -> null); + Assert.assertTrue( + nativeEngines990KnnVectorsFormat.fieldsReader(mockedSegmentReadState) instanceof NativeEngines990KnnVectorsReader + ); + + Assert.assertTrue( + nativeEngines990KnnVectorsFormat.fieldsWriter(mockedSegmentWriteState) instanceof NativeEngines990KnnVectorsWriter + ); + } } @SneakyThrows @@ -137,8 +204,6 @@ public void testNativeEngineVectorFormat_whenMultipleVectorFieldIndexed_thenSucc indexWriter.commit(); indexWriter.close(); - assertNotEquals(0L, (long) KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue()); - // Validate to see if correct values are returned, assumption here is only 1 segment is getting created IndexSearcher searcher = new IndexSearcher(indexReader); final LeafReader leafReader = searcher.getLeafContexts().get(0).reader(); @@ -208,7 +273,6 @@ public void testNativeEngineVectorFormat_whenBinaryQuantizationApplied_thenSucce indexWriter.flush(); indexWriter.commit(); indexWriter.close(); - assertNotEquals(0L, (long) KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue()); IndexSearcher searcher = new IndexSearcher(indexReader); final LeafReader leafReader = searcher.getLeafContexts().get(0).reader(); diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index 249ae04cec..a2b41804a5 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -32,6 +32,7 @@ import org.junit.After; import org.junit.Before; import org.junit.BeforeClass; +import org.mockito.MockedConstruction; import org.mockito.MockedStatic; import org.mockito.Mockito; import org.opensearch.common.io.PathUtils; @@ -39,6 +40,7 @@ import org.opensearch.core.common.unit.ByteSizeValue; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.codec.KNN990Codec.QuantizationConfigKNNCollector; import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; @@ -47,6 +49,7 @@ import org.opensearch.knn.index.memory.NativeMemoryAllocation; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.quantizationservice.QuantizationService; import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; @@ -54,6 +57,11 @@ import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelState; import org.opensearch.knn.jni.JNIService; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import java.io.IOException; import java.nio.file.Path; @@ -1064,7 +1072,7 @@ public void testANNWithParentsFilter_whenDoingANN_thenBitSetIsPassedToJNI() { .parentsFilter(bitSetProducer) .build(); - final KNNWeight knnWeight = new KNNWeight(query, 0.0f, null); + final KNNWeight knnWeight = new KNNWeight(query, 0.0f); jniServiceMockedStatic.when( () -> JNIService.queryIndex( @@ -1350,4 +1358,157 @@ private KNNQueryResult[] getFilteredKNNQueryResults() { .collect(Collectors.toList()) .toArray(new KNNQueryResult[0]); } + + @SneakyThrows + public void testANNWithQuantizationParams_whenStateNotFound_thenFail() { + try (MockedStatic quantizationServiceMockedStatic = Mockito.mockStatic(QuantizationService.class)) { + QuantizationService quantizationService = Mockito.mock(QuantizationService.class); + QuantizationParams quantizationParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + Mockito.when(quantizationService.getQuantizationParams(any(FieldInfo.class))).thenReturn(quantizationParams); + quantizationServiceMockedStatic.when(QuantizationService::getInstance).thenReturn(quantizationService); + + // Given + int k = 3; + jniServiceMockedStatic.when( + () -> JNIService.queryIndex(anyLong(), eq(QUERY_VECTOR), eq(k), eq(HNSW_METHOD_PARAMETERS), any(), any(), anyInt(), any()) + ).thenReturn(getFilteredKNNQueryResults()); + + jniServiceMockedStatic.when( + () -> JNIService.queryBinaryIndex( + anyLong(), + eq(BYTE_QUERY_VECTOR), + eq(k), + eq(HNSW_METHOD_PARAMETERS), + any(), + any(), + anyInt(), + any() + ) + ).thenReturn(getFilteredKNNQueryResults()); + final SegmentReader reader = mockSegmentReader(); + final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); + when(leafReaderContext.reader()).thenReturn(reader); + + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(QUERY_VECTOR) + .k(k) + .indexName(INDEX_NAME) + .filterQuery(FILTER_QUERY) + .methodParameters(HNSW_METHOD_PARAMETERS) + .vectorDataType(VectorDataType.FLOAT) + .build(); + + final float boost = (float) randomDoubleBetween(0, 10, true); + final KNNWeight knnWeight = new KNNWeight(query, boost); + final FieldInfos fieldInfos = mock(FieldInfos.class); + final FieldInfo fieldInfo = mock(FieldInfo.class); + final Map attributesMap = ImmutableMap.of( + KNN_ENGINE, + KNNEngine.FAISS.getName(), + PARAMETERS, + String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") + ); + + when(reader.getFieldInfos()).thenReturn(fieldInfos); + when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); + when(fieldInfo.attributes()).thenReturn(attributesMap); + + expectThrows(IllegalStateException.class, () -> knnWeight.scorer(leafReaderContext)); + } + } + + @SneakyThrows + public void testANNWithQuantizationParams_thenSuccess() { + try (MockedStatic quantizationServiceMockedStatic = Mockito.mockStatic(QuantizationService.class)) { + QuantizationService quantizationService = Mockito.mock(QuantizationService.class); + ScalarQuantizationParams quantizationParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + Mockito.when(quantizationService.getQuantizationParams(any(FieldInfo.class))).thenReturn(quantizationParams); + quantizationServiceMockedStatic.when(QuantizationService::getInstance).thenReturn(quantizationService); + + float[] meanThresholds = new float[] { 1.2f, 2.3f, 3.4f, 4.5f }; + QuantizationState quantizationState = new OneBitScalarQuantizationState(quantizationParams, meanThresholds); + + try ( + MockedConstruction quantizationCollectorMockedConstruction = Mockito.mockConstruction( + QuantizationConfigKNNCollector.class, + (mock, context) -> Mockito.when(mock.getQuantizationState()).thenReturn(quantizationState) + ) + ) { + + // Given + int k = 3; + jniServiceMockedStatic.when( + () -> JNIService.queryIndex( + anyLong(), + eq(QUERY_VECTOR), + eq(k), + eq(HNSW_METHOD_PARAMETERS), + any(), + any(), + anyInt(), + any() + ) + ).thenReturn(getFilteredKNNQueryResults()); + + jniServiceMockedStatic.when( + () -> JNIService.queryBinaryIndex( + anyLong(), + eq(BYTE_QUERY_VECTOR), + eq(k), + eq(HNSW_METHOD_PARAMETERS), + any(), + any(), + anyInt(), + any() + ) + ).thenReturn(getFilteredKNNQueryResults()); + final SegmentReader reader = mockSegmentReader(); + final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); + when(leafReaderContext.reader()).thenReturn(reader); + + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(QUERY_VECTOR) + .k(k) + .indexName(INDEX_NAME) + .filterQuery(FILTER_QUERY) + .methodParameters(HNSW_METHOD_PARAMETERS) + .vectorDataType(VectorDataType.FLOAT) + .build(); + + final float boost = (float) randomDoubleBetween(0, 10, true); + final KNNWeight knnWeight = new KNNWeight(query, boost); + final FieldInfos fieldInfos = mock(FieldInfos.class); + final FieldInfo fieldInfo = mock(FieldInfo.class); + final Map attributesMap = ImmutableMap.of( + KNN_ENGINE, + KNNEngine.FAISS.getName(), + PARAMETERS, + String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") + ); + + when(reader.getFieldInfos()).thenReturn(fieldInfos); + when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); + when(fieldInfo.attributes()).thenReturn(attributesMap); + + KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + + assertNotNull(knnScorer); + jniServiceMockedStatic.verify( + () -> JNIService.queryIndex( + anyLong(), + eq(QUERY_VECTOR), + eq(k), + eq(HNSW_METHOD_PARAMETERS), + any(), + any(), + anyInt(), + any() + ), + times(1) + ); + } + } + } } diff --git a/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManagerTests.java b/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManagerTests.java new file mode 100644 index 0000000000..14e55e627d --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManagerTests.java @@ -0,0 +1,98 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.quantizationState; + +import lombok.SneakyThrows; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.codec.KNN990Codec.KNN990QuantizationStateReader; + +import static org.mockito.Mockito.times; + +public class QuantizationStateCacheManagerTests extends KNNTestCase { + + @SneakyThrows + public void testRebuildCache() { + try (MockedStatic mockedStaticCache = Mockito.mockStatic(QuantizationStateCache.class)) { + QuantizationStateCache quantizationStateCache = Mockito.mock(QuantizationStateCache.class); + mockedStaticCache.when(QuantizationStateCache::getInstance).thenReturn(quantizationStateCache); + Mockito.doNothing().when(quantizationStateCache).rebuildCache(); + QuantizationStateCacheManager.getInstance().rebuildCache(); + Mockito.verify(quantizationStateCache, times(1)).rebuildCache(); + } + } + + @SneakyThrows + public void testGetQuantizationState() { + try (MockedStatic mockedStaticCache = Mockito.mockStatic(QuantizationStateCache.class)) { + QuantizationStateReadConfig quantizationStateReadConfig = Mockito.mock(QuantizationStateReadConfig.class); + String cacheKey = "test-key"; + Mockito.when(quantizationStateReadConfig.getCacheKey()).thenReturn(cacheKey); + QuantizationState quantizationState = Mockito.mock(QuantizationState.class); + QuantizationStateCache quantizationStateCache = Mockito.mock(QuantizationStateCache.class); + mockedStaticCache.when(QuantizationStateCache::getInstance).thenReturn(quantizationStateCache); + Mockito.doNothing().when(quantizationStateCache).addQuantizationState(cacheKey, quantizationState); + try (MockedStatic mockedStaticReader = Mockito.mockStatic(KNN990QuantizationStateReader.class)) { + mockedStaticReader.when(() -> KNN990QuantizationStateReader.read(quantizationStateReadConfig)) + .thenReturn(quantizationState); + QuantizationStateCacheManager.getInstance().getQuantizationState(quantizationStateReadConfig); + Mockito.verify(quantizationStateCache, times(1)).addQuantizationState(cacheKey, quantizationState); + } + Mockito.when(quantizationStateCache.getQuantizationState(cacheKey)).thenReturn(quantizationState); + QuantizationStateCacheManager.getInstance().getQuantizationState(quantizationStateReadConfig); + Mockito.verify(quantizationStateCache, times(1)).addQuantizationState(cacheKey, quantizationState); + } + } + + @SneakyThrows + public void testEvict() { + try (MockedStatic mockedStaticCache = Mockito.mockStatic(QuantizationStateCache.class)) { + String field = "test-field"; + QuantizationStateCache quantizationStateCache = Mockito.mock(QuantizationStateCache.class); + mockedStaticCache.when(QuantizationStateCache::getInstance).thenReturn(quantizationStateCache); + Mockito.doNothing().when(quantizationStateCache).evict(field); + QuantizationStateCacheManager.getInstance().evict(field); + Mockito.verify(quantizationStateCache, times(1)).evict(field); + } + } + + @SneakyThrows + public void testAddQuantizationState() { + try (MockedStatic mockedStaticCache = Mockito.mockStatic(QuantizationStateCache.class)) { + String field = "test-field"; + QuantizationState quantizationState = Mockito.mock(QuantizationState.class); + QuantizationStateCache quantizationStateCache = Mockito.mock(QuantizationStateCache.class); + mockedStaticCache.when(QuantizationStateCache::getInstance).thenReturn(quantizationStateCache); + Mockito.doNothing().when(quantizationStateCache).addQuantizationState(field, quantizationState); + QuantizationStateCacheManager.getInstance().addQuantizationState(field, quantizationState); + Mockito.verify(quantizationStateCache, times(1)).addQuantizationState(field, quantizationState); + } + } + + @SneakyThrows + public void testSetMaxCacheSizeInKB() { + try (MockedStatic mockedStaticCache = Mockito.mockStatic(QuantizationStateCache.class)) { + long maxCacheSizeInKB = 1024; + QuantizationStateCache quantizationStateCache = Mockito.mock(QuantizationStateCache.class); + mockedStaticCache.when(QuantizationStateCache::getInstance).thenReturn(quantizationStateCache); + Mockito.doNothing().when(quantizationStateCache).setMaxCacheSizeInKB(maxCacheSizeInKB); + QuantizationStateCacheManager.getInstance().setMaxCacheSizeInKB(1024); + Mockito.verify(quantizationStateCache, times(1)).setMaxCacheSizeInKB(1024); + } + } + + @SneakyThrows + public void testClear() { + try (MockedStatic mockedStaticCache = Mockito.mockStatic(QuantizationStateCache.class)) { + QuantizationStateCache quantizationStateCache = Mockito.mock(QuantizationStateCache.class); + mockedStaticCache.when(QuantizationStateCache::getInstance).thenReturn(quantizationStateCache); + Mockito.doNothing().when(quantizationStateCache).clear(); + QuantizationStateCacheManager.getInstance().clear(); + Mockito.verify(quantizationStateCache, times(1)).clear(); + } + } +}