From 6b27be7272ef61ccab44f97496b6f005f9821c84 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Wed, 17 Jul 2024 09:34:06 -0700 Subject: [PATCH 1/7] Fix graph merge stats size calculation Signed-off-by: Ryan Bogan --- .../knn/index/codec/util/KNNCodecUtil.java | 31 ++++++++----------- .../index/codec/util/KNNCodecUtilTests.java | 18 +++++++++++ 2 files changed, 31 insertions(+), 18 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java index 04aeb337fd..b3106c9b59 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java @@ -70,39 +70,34 @@ public static KNNCodecUtil.Pair getPair(final BinaryDocValues values, final Vect public static long calculateArraySize(int numVectors, int vectorLength, SerializationMode serializationMode) { if (serializationMode == SerializationMode.ARRAY) { int vectorSize = vectorLength * FLOAT_BYTE_SIZE + JAVA_ARRAY_HEADER_SIZE; - if (vectorSize % JAVA_ROUNDING_NUMBER != 0) { - vectorSize += vectorSize % JAVA_ROUNDING_NUMBER; - } + vectorSize = roundVectorSize(vectorSize); int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE) + JAVA_ARRAY_HEADER_SIZE; - if (vectorsSize % JAVA_ROUNDING_NUMBER != 0) { - vectorsSize += vectorsSize % JAVA_ROUNDING_NUMBER; - } + vectorsSize = roundVectorSize(vectorsSize); return vectorsSize; } else if (serializationMode == SerializationMode.COLLECTION_OF_FLOATS) { int vectorSize = vectorLength * FLOAT_BYTE_SIZE; - if (vectorSize % JAVA_ROUNDING_NUMBER != 0) { - vectorSize += vectorSize % JAVA_ROUNDING_NUMBER; - } + vectorSize = roundVectorSize(vectorSize); int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE); - if (vectorsSize % JAVA_ROUNDING_NUMBER != 0) { - vectorsSize += vectorsSize % JAVA_ROUNDING_NUMBER; - } + vectorsSize = roundVectorSize(vectorsSize); return vectorsSize; } else if (serializationMode == SerializationMode.COLLECTIONS_OF_BYTES) { int vectorSize = vectorLength; - if (vectorSize % JAVA_ROUNDING_NUMBER != 0) { - vectorSize += vectorSize % JAVA_ROUNDING_NUMBER; - } + vectorSize = roundVectorSize(vectorSize); int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE); - if (vectorsSize % JAVA_ROUNDING_NUMBER != 0) { - vectorsSize += vectorsSize % JAVA_ROUNDING_NUMBER; - } + vectorsSize = roundVectorSize(vectorsSize); return vectorsSize; } else { throw new IllegalStateException("Unreachable code"); } } + private static int roundVectorSize(int vectorSize) { + if (vectorSize % JAVA_ROUNDING_NUMBER != 0) { + return vectorSize + (JAVA_ROUNDING_NUMBER - vectorSize % JAVA_ROUNDING_NUMBER); + } + return vectorSize; + } + public static String buildEngineFileName(String segmentName, String latestBuildVersion, String fieldName, String extension) { return String.format("%s%s%s", buildEngineFilePrefix(segmentName), latestBuildVersion, buildEngineFileSuffix(fieldName, extension)); } diff --git a/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java b/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java index 2ff0f08e51..70eef70fea 100644 --- a/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java @@ -18,6 +18,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.knn.index.codec.util.KNNCodecUtil.calculateArraySize; public class KNNCodecUtilTests extends TestCase { @SneakyThrows @@ -52,4 +53,21 @@ public void testGetPair_whenCalled_thenReturn() { assertEquals(dimension, pair.getDimension()); assertEquals(SerializationMode.COLLECTIONS_OF_BYTES, pair.serializationMode); } + + public void testCalculateArraySize() { + int numVectors = 4; + int vectorLength = 10; + + // Array + SerializationMode serializationMode = SerializationMode.ARRAY; + assertEquals(256, calculateArraySize(numVectors, vectorLength, serializationMode)); + + // Collection of floats + serializationMode = SerializationMode.COLLECTION_OF_FLOATS; + assertEquals(176, calculateArraySize(numVectors, vectorLength, serializationMode)); + + // Collection of bytes + serializationMode = SerializationMode.COLLECTIONS_OF_BYTES; + assertEquals(80, calculateArraySize(numVectors, vectorLength, serializationMode)); + } } From b0863f248ec60ef401ad5f3a3acbb08a63bde255 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Wed, 17 Jul 2024 09:38:21 -0700 Subject: [PATCH 2/7] Add changelog entry Signed-off-by: Ryan Bogan --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 29c0f1841a..0ce3e47362 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Fixed LeafReaders casting errors to SegmentReaders when segment replication is enabled during search.[#1808](https://github.com/opensearch-project/k-NN/pull/1808) * Release memory properly for an array type [#1820](https://github.com/opensearch-project/k-NN/pull/1820) * FIX Same Suffix Cause Recall Drop to zero [#1802](https://github.com/opensearch-project/k-NN/pull/1802) +* Fix graph merge stats size calculation [#1844](https://github.com/opensearch-project/k-NN/pull/1844) ### Infrastructure * Apply custom patch only once by comparing the last patch id [#1833](https://github.com/opensearch-project/k-NN/pull/1833) ### Documentation From 609301e9f5bb7f7498beeaf40f6399a8b062b81d Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Thu, 18 Jul 2024 10:51:19 -0700 Subject: [PATCH 3/7] Add javadocs Signed-off-by: Ryan Bogan --- .../org/opensearch/knn/index/codec/util/KNNCodecUtil.java | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java index b3106c9b59..892fc28155 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java @@ -67,6 +67,13 @@ public static KNNCodecUtil.Pair getPair(final BinaryDocValues values, final Vect ); } + /** + * This method provides a rough estimate of the number of bytes used for storing an array with the given parameters. + * @param numVectors number of vectors in the array + * @param vectorLength the length of each vector + * @param serializationMode serialization mode + * @return rough estimate of number of bytes used to store an array with the given parameters + */ public static long calculateArraySize(int numVectors, int vectorLength, SerializationMode serializationMode) { if (serializationMode == SerializationMode.ARRAY) { int vectorSize = vectorLength * FLOAT_BYTE_SIZE + JAVA_ARRAY_HEADER_SIZE; From c1e3109f7bb56700c24b6a83f4d602ef97656f57 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Thu, 18 Jul 2024 12:13:36 -0700 Subject: [PATCH 4/7] Make calculations easier to read Signed-off-by: Ryan Bogan --- .../knn/index/codec/util/KNNCodecUtil.java | 36 ++++++++++--------- .../index/codec/util/KNNCodecUtilTests.java | 6 ++-- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java index 892fc28155..7eaa72acba 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java @@ -21,8 +21,8 @@ public class KNNCodecUtil { // Floats are 4 bytes in size public static final int FLOAT_BYTE_SIZE = 4; - // References to objects are 4 bytes in size - public static final int JAVA_REFERENCE_SIZE = 4; + // References to objects are 8 bytes in size + public static final int JAVA_REFERENCE_SIZE = 8; // Each array in Java has a header that is 12 bytes public static final int JAVA_ARRAY_HEADER_SIZE = 12; // Java rounds each array size up to multiples of 8 bytes @@ -75,24 +75,26 @@ public static KNNCodecUtil.Pair getPair(final BinaryDocValues values, final Vect * @return rough estimate of number of bytes used to store an array with the given parameters */ public static long calculateArraySize(int numVectors, int vectorLength, SerializationMode serializationMode) { + // For more information about array storage in Java, visit https://www.javamex.com/tutorials/memory/array_memory_usage.shtml + // Note: java reference size is 8 bytes for 64 bit machines and 4 bytes for 32 bit machines, this method assumes 64 bit if (serializationMode == SerializationMode.ARRAY) { - int vectorSize = vectorLength * FLOAT_BYTE_SIZE + JAVA_ARRAY_HEADER_SIZE; - vectorSize = roundVectorSize(vectorSize); - int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE) + JAVA_ARRAY_HEADER_SIZE; - vectorsSize = roundVectorSize(vectorsSize); - return vectorsSize; + int sizeOfVector = vectorLength * FLOAT_BYTE_SIZE + JAVA_ARRAY_HEADER_SIZE; + int sizeOfVectorArray = roundVectorSize(sizeOfVector) * numVectors; + int sizeOfReferenceArray = numVectors * JAVA_REFERENCE_SIZE + JAVA_ARRAY_HEADER_SIZE; + sizeOfReferenceArray = roundVectorSize(sizeOfReferenceArray); + return sizeOfReferenceArray + sizeOfVectorArray; } else if (serializationMode == SerializationMode.COLLECTION_OF_FLOATS) { - int vectorSize = vectorLength * FLOAT_BYTE_SIZE; - vectorSize = roundVectorSize(vectorSize); - int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE); - vectorsSize = roundVectorSize(vectorsSize); - return vectorsSize; + int sizeOfVector = vectorLength * FLOAT_BYTE_SIZE; + int sizeOfVectorArray = roundVectorSize(sizeOfVector) * numVectors; + int sizeOfReferenceArray = numVectors * JAVA_REFERENCE_SIZE; + sizeOfReferenceArray = roundVectorSize(sizeOfReferenceArray); + return sizeOfReferenceArray + sizeOfVectorArray; } else if (serializationMode == SerializationMode.COLLECTIONS_OF_BYTES) { - int vectorSize = vectorLength; - vectorSize = roundVectorSize(vectorSize); - int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE); - vectorsSize = roundVectorSize(vectorsSize); - return vectorsSize; + int sizeOfVector = vectorLength; + int sizeOfVectorArray = roundVectorSize(sizeOfVector) * numVectors; + int sizeOfReferenceArray = numVectors * JAVA_REFERENCE_SIZE; + sizeOfReferenceArray = roundVectorSize(sizeOfReferenceArray); + return sizeOfReferenceArray + sizeOfVectorArray; } else { throw new IllegalStateException("Unreachable code"); } diff --git a/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java b/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java index 70eef70fea..25affc48cd 100644 --- a/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java @@ -60,14 +60,14 @@ public void testCalculateArraySize() { // Array SerializationMode serializationMode = SerializationMode.ARRAY; - assertEquals(256, calculateArraySize(numVectors, vectorLength, serializationMode)); + assertEquals(272, calculateArraySize(numVectors, vectorLength, serializationMode)); // Collection of floats serializationMode = SerializationMode.COLLECTION_OF_FLOATS; - assertEquals(176, calculateArraySize(numVectors, vectorLength, serializationMode)); + assertEquals(192, calculateArraySize(numVectors, vectorLength, serializationMode)); // Collection of bytes serializationMode = SerializationMode.COLLECTIONS_OF_BYTES; - assertEquals(80, calculateArraySize(numVectors, vectorLength, serializationMode)); + assertEquals(96, calculateArraySize(numVectors, vectorLength, serializationMode)); } } From 1db8365bcbb5e96a283b6a99a889c47b629d23fd Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Fri, 26 Jul 2024 10:41:56 -0700 Subject: [PATCH 5/7] Remove java overhead from calculations Signed-off-by: Ryan Bogan --- .../knn/index/codec/util/KNNCodecUtil.java | 37 ++----------------- .../index/codec/util/KNNCodecUtilTests.java | 6 +-- 2 files changed, 6 insertions(+), 37 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java index 7eaa72acba..585a9eac0a 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java @@ -21,12 +21,6 @@ public class KNNCodecUtil { // Floats are 4 bytes in size public static final int FLOAT_BYTE_SIZE = 4; - // References to objects are 8 bytes in size - public static final int JAVA_REFERENCE_SIZE = 8; - // Each array in Java has a header that is 12 bytes - public static final int JAVA_ARRAY_HEADER_SIZE = 12; - // Java rounds each array size up to multiples of 8 bytes - public static final int JAVA_ROUNDING_NUMBER = 8; @AllArgsConstructor public static final class Pair { @@ -75,38 +69,13 @@ public static KNNCodecUtil.Pair getPair(final BinaryDocValues values, final Vect * @return rough estimate of number of bytes used to store an array with the given parameters */ public static long calculateArraySize(int numVectors, int vectorLength, SerializationMode serializationMode) { - // For more information about array storage in Java, visit https://www.javamex.com/tutorials/memory/array_memory_usage.shtml - // Note: java reference size is 8 bytes for 64 bit machines and 4 bytes for 32 bit machines, this method assumes 64 bit - if (serializationMode == SerializationMode.ARRAY) { - int sizeOfVector = vectorLength * FLOAT_BYTE_SIZE + JAVA_ARRAY_HEADER_SIZE; - int sizeOfVectorArray = roundVectorSize(sizeOfVector) * numVectors; - int sizeOfReferenceArray = numVectors * JAVA_REFERENCE_SIZE + JAVA_ARRAY_HEADER_SIZE; - sizeOfReferenceArray = roundVectorSize(sizeOfReferenceArray); - return sizeOfReferenceArray + sizeOfVectorArray; - } else if (serializationMode == SerializationMode.COLLECTION_OF_FLOATS) { - int sizeOfVector = vectorLength * FLOAT_BYTE_SIZE; - int sizeOfVectorArray = roundVectorSize(sizeOfVector) * numVectors; - int sizeOfReferenceArray = numVectors * JAVA_REFERENCE_SIZE; - sizeOfReferenceArray = roundVectorSize(sizeOfReferenceArray); - return sizeOfReferenceArray + sizeOfVectorArray; - } else if (serializationMode == SerializationMode.COLLECTIONS_OF_BYTES) { - int sizeOfVector = vectorLength; - int sizeOfVectorArray = roundVectorSize(sizeOfVector) * numVectors; - int sizeOfReferenceArray = numVectors * JAVA_REFERENCE_SIZE; - sizeOfReferenceArray = roundVectorSize(sizeOfReferenceArray); - return sizeOfReferenceArray + sizeOfVectorArray; + if (serializationMode == SerializationMode.COLLECTIONS_OF_BYTES) { + return numVectors * vectorLength; } else { - throw new IllegalStateException("Unreachable code"); + return numVectors * vectorLength * FLOAT_BYTE_SIZE; } } - private static int roundVectorSize(int vectorSize) { - if (vectorSize % JAVA_ROUNDING_NUMBER != 0) { - return vectorSize + (JAVA_ROUNDING_NUMBER - vectorSize % JAVA_ROUNDING_NUMBER); - } - return vectorSize; - } - public static String buildEngineFileName(String segmentName, String latestBuildVersion, String fieldName, String extension) { return String.format("%s%s%s", buildEngineFilePrefix(segmentName), latestBuildVersion, buildEngineFileSuffix(fieldName, extension)); } diff --git a/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java b/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java index 25affc48cd..3274bb731b 100644 --- a/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java @@ -60,14 +60,14 @@ public void testCalculateArraySize() { // Array SerializationMode serializationMode = SerializationMode.ARRAY; - assertEquals(272, calculateArraySize(numVectors, vectorLength, serializationMode)); + assertEquals(160, calculateArraySize(numVectors, vectorLength, serializationMode)); // Collection of floats serializationMode = SerializationMode.COLLECTION_OF_FLOATS; - assertEquals(192, calculateArraySize(numVectors, vectorLength, serializationMode)); + assertEquals(160, calculateArraySize(numVectors, vectorLength, serializationMode)); // Collection of bytes serializationMode = SerializationMode.COLLECTIONS_OF_BYTES; - assertEquals(96, calculateArraySize(numVectors, vectorLength, serializationMode)); + assertEquals(40, calculateArraySize(numVectors, vectorLength, serializationMode)); } } From a3a1dc17f371e37ee2ca171d91bd74133dfe534b Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Tue, 6 Aug 2024 11:59:59 -0700 Subject: [PATCH 6/7] Change from serialization mode to vector data type for calculations Signed-off-by: Ryan Bogan --- .../KNN80Codec/KNN80DocValuesConsumer.java | 7 ++++--- .../knn/index/codec/util/KNNCodecUtil.java | 11 ++++++----- .../index/codec/util/KNNCodecUtilTests.java | 19 ++++++++++--------- 3 files changed, 20 insertions(+), 17 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java index 8e191ac5f3..989af4063b 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java @@ -131,6 +131,7 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, NativeIndexCreator indexCreator; KNNCodecUtil.Pair pair; Map fieldAttributes = field.attributes(); + VectorDataType vectorDataType; if (fieldAttributes.containsKey(MODEL_ID)) { String modelId = fieldAttributes.get(MODEL_ID); @@ -138,12 +139,12 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, if (model.getModelBlob() == null) { throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId)); } - VectorDataType vectorDataType = model.getModelMetadata().getVectorDataType(); + vectorDataType = model.getModelMetadata().getVectorDataType(); pair = KNNCodecUtil.getPair(values, getVectorTransfer(vectorDataType)); indexCreator = () -> createKNNIndexFromTemplate(model, pair, knnEngine, indexPath); } else { // get vector data type from field attributes or provide default value - VectorDataType vectorDataType = VectorDataType.get( + vectorDataType = VectorDataType.get( fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue()) ); pair = KNNCodecUtil.getPair(values, getVectorTransfer(vectorDataType)); @@ -156,7 +157,7 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, return; } - long arraySize = calculateArraySize(pair.docs.length, pair.getDimension(), pair.serializationMode); + long arraySize = calculateArraySize(pair.docs.length, pair.getDimension(), vectorDataType); if (isMerge) { KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment(); diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java index 83a6095ee4..3133331f49 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java @@ -11,6 +11,7 @@ import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.BytesRef; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.KNN80Codec.KNN80BinaryDocValues; import org.opensearch.knn.index.codec.transfer.VectorTransfer; @@ -65,14 +66,14 @@ public static KNNCodecUtil.Pair getPair(final BinaryDocValues values, final Vect * This method provides a rough estimate of the number of bytes used for storing an array with the given parameters. * @param numVectors number of vectors in the array * @param vectorLength the length of each vector - * @param serializationMode serialization mode + * @param vectorDataType type of data stored in each vector * @return rough estimate of number of bytes used to store an array with the given parameters */ - public static long calculateArraySize(int numVectors, int vectorLength, SerializationMode serializationMode) { - if (serializationMode == SerializationMode.COLLECTIONS_OF_BYTES) { - return numVectors * vectorLength; - } else { + public static long calculateArraySize(int numVectors, int vectorLength, VectorDataType vectorDataType) { + if (vectorDataType == VectorDataType.FLOAT) { return numVectors * vectorLength * FLOAT_BYTE_SIZE; + } else { + return numVectors * vectorLength; } } diff --git a/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java b/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java index 3274bb731b..47dd1dda99 100644 --- a/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java @@ -9,6 +9,7 @@ import lombok.SneakyThrows; import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.util.BytesRef; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.transfer.VectorTransfer; import java.util.Arrays; @@ -58,16 +59,16 @@ public void testCalculateArraySize() { int numVectors = 4; int vectorLength = 10; - // Array - SerializationMode serializationMode = SerializationMode.ARRAY; - assertEquals(160, calculateArraySize(numVectors, vectorLength, serializationMode)); + // Float data type + VectorDataType vectorDataType = VectorDataType.FLOAT; + assertEquals(160, calculateArraySize(numVectors, vectorLength, vectorDataType)); - // Collection of floats - serializationMode = SerializationMode.COLLECTION_OF_FLOATS; - assertEquals(160, calculateArraySize(numVectors, vectorLength, serializationMode)); + // Byte data type + vectorDataType = VectorDataType.BYTE; + assertEquals(40, calculateArraySize(numVectors, vectorLength, vectorDataType)); - // Collection of bytes - serializationMode = SerializationMode.COLLECTIONS_OF_BYTES; - assertEquals(40, calculateArraySize(numVectors, vectorLength, serializationMode)); + // Binary data type + vectorDataType = VectorDataType.BINARY; + assertEquals(40, calculateArraySize(numVectors, vectorLength, vectorDataType)); } } From acddd8723b4d139a319ca6a0ddb0d892b1e29daf Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Wed, 7 Aug 2024 09:32:40 -0700 Subject: [PATCH 7/7] Minor change to if statements Signed-off-by: Ryan Bogan --- CHANGELOG.md | 2 +- .../org/opensearch/knn/index/codec/util/KNNCodecUtil.java | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6309105f48..44f387533a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,7 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Bug Fixes * Corrected search logic for scenario with non-existent fields in filter [#1874](https://github.com/opensearch-project/k-NN/pull/1874) * Add script_fields context to KNNAllowlist [#1917] (https://github.com/opensearch-project/k-NN/pull/1917) -* * Fix graph merge stats size calculation [#1844](https://github.com/opensearch-project/k-NN/pull/1844) +* Fix graph merge stats size calculation [#1844](https://github.com/opensearch-project/k-NN/pull/1844) ### Infrastructure ### Documentation ### Maintenance diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java index 3133331f49..ea14fe8834 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java @@ -72,8 +72,12 @@ public static KNNCodecUtil.Pair getPair(final BinaryDocValues values, final Vect public static long calculateArraySize(int numVectors, int vectorLength, VectorDataType vectorDataType) { if (vectorDataType == VectorDataType.FLOAT) { return numVectors * vectorLength * FLOAT_BYTE_SIZE; - } else { + } else if (vectorDataType == VectorDataType.BINARY || vectorDataType == VectorDataType.BYTE) { return numVectors * vectorLength; + } else { + throw new IllegalArgumentException( + "Float, binary, and byte are the only supported vector data types for array size calculation." + ); } }