Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.index.mapper;

import java.util.Arrays;

import org.apache.lucene.document.StoredField;
import org.apache.lucene.util.BytesRef;
import org.junit.Assert;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import org.opensearch.Version;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;

public class KNNVectorFieldMapperUtilTests extends KNNTestCase {

private static final String TEST_FIELD_NAME = "test_field_name";
private static final byte[] TEST_BYTE_VECTOR = new byte[] { -128, 0, 1, 127 };
private static final float[] TEST_FLOAT_VECTOR = new float[] { -100.0f, 100.0f, 0f, 1f };

public void testStoredFields_whenVectorIsByteType_thenSucceed() {
StoredField storedField = KNNVectorFieldMapperUtil.createStoredFieldForByteVector(TEST_FIELD_NAME, TEST_BYTE_VECTOR);
assertEquals(TEST_FIELD_NAME, storedField.name());
assertEquals(TEST_BYTE_VECTOR, storedField.binaryValue().bytes);
Object vector = KNNVectorFieldMapperUtil.deserializeStoredVector(storedField.binaryValue(), VectorDataType.BYTE);
assertTrue(vector instanceof int[]);
int[] byteAsIntArray = new int[TEST_BYTE_VECTOR.length];
Arrays.setAll(byteAsIntArray, i -> TEST_BYTE_VECTOR[i]);
assertArrayEquals(byteAsIntArray, (int[]) vector);
}

public void testStoredFields_whenVectorIsBinaryType_thenSucceed() {
StoredField storedField = KNNVectorFieldMapperUtil.createStoredFieldForByteVector(TEST_FIELD_NAME, TEST_BYTE_VECTOR);
assertEquals(TEST_FIELD_NAME, storedField.name());
assertEquals(TEST_BYTE_VECTOR, storedField.binaryValue().bytes);
Object vector = KNNVectorFieldMapperUtil.deserializeStoredVector(storedField.binaryValue(), VectorDataType.BINARY);
assertTrue(vector instanceof int[]);
int[] byteAsIntArray = new int[TEST_BYTE_VECTOR.length];
Arrays.setAll(byteAsIntArray, i -> TEST_BYTE_VECTOR[i]);
assertArrayEquals(byteAsIntArray, (int[]) vector);
}

public void testStoredFields_whenVectorIsFloatType_thenSucceed() {
StoredField storedField = KNNVectorFieldMapperUtil.createStoredFieldForFloatVector(TEST_FIELD_NAME, TEST_FLOAT_VECTOR);
assertEquals(TEST_FIELD_NAME, storedField.name());
BytesRef bytes = new BytesRef(storedField.binaryValue().bytes);
assertArrayEquals(TEST_FLOAT_VECTOR, KNNVectorSerializerFactory.getDefaultSerializer().byteToFloatArray(bytes), 0.001f);

Object vector = KNNVectorFieldMapperUtil.deserializeStoredVector(storedField.binaryValue(), VectorDataType.FLOAT);
assertTrue(vector instanceof float[]);
assertArrayEquals(TEST_FLOAT_VECTOR, (float[]) vector, 0.001f);
}

public void testGetExpectedVectorLengthSuccess() {
KNNVectorFieldType knnVectorFieldType = mock(KNNVectorFieldType.class);
when(knnVectorFieldType.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(getDefaultKNNMethodContext(), 3));
KNNVectorFieldType knnVectorFieldTypeBinary = mock(KNNVectorFieldType.class);
when(knnVectorFieldTypeBinary.getKnnMappingConfig()).thenReturn(
getMappingConfigForMethodMapping(getDefaultBinaryKNNMethodContext(), 8)
);
when(knnVectorFieldTypeBinary.getVectorDataType()).thenReturn(VectorDataType.BINARY);

assertEquals(3, KNNVectorFieldMapperUtil.getExpectedVectorLength(knnVectorFieldType));
assertEquals(1, KNNVectorFieldMapperUtil.getExpectedVectorLength(knnVectorFieldTypeBinary));
}

public void testUseLuceneKNNVectorsFormat_withDifferentInputs_thenSuccess() {
Assert.assertFalse(KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Version.V_2_16_0));
Assert.assertTrue(KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Version.V_2_17_0));
Assert.assertTrue(KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Version.V_3_0_0));
}

/**
* Test useFullFieldNameValidation method for different OpenSearch versions
*/
public void testUseFullFieldNameValidation() {
Assert.assertFalse(KNNVectorFieldMapperUtil.useFullFieldNameValidation(Version.V_2_16_0));
Assert.assertTrue(KNNVectorFieldMapperUtil.useFullFieldNameValidation(Version.V_2_17_0));
Assert.assertTrue(KNNVectorFieldMapperUtil.useFullFieldNameValidation(Version.V_2_18_0));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.mapper;

import java.util.Collections;
import java.util.Optional;

import org.apache.lucene.search.FieldExistsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef;
import static org.mockito.Mockito.mock;
import org.opensearch.index.fielddata.IndexFieldData;
import org.opensearch.index.mapper.ArraySourceValueFetcher;
import org.opensearch.index.mapper.ValueFetcher;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.QueryShardException;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.KNNVectorIndexFieldData;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.search.lookup.SearchLookup;

public class KNNVectorFieldTypeTests extends KNNTestCase {

private static final String FIELD_NAME = "test-field";
private static final int DIMENSION = 3;

public void testValueFetcher() {
KNNMethodContext knnMethodContext = getDefaultKNNMethodContext();
KNNVectorFieldType knnVectorFieldType = new KNNVectorFieldType(
FIELD_NAME,
Collections.emptyMap(),
VectorDataType.FLOAT,
getMappingConfigForMethodMapping(knnMethodContext, DIMENSION)
);
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
ValueFetcher valueFetcher = knnVectorFieldType.valueFetcher(mockQueryShardContext, null, null);
assertTrue(valueFetcher instanceof ArraySourceValueFetcher);
}

public void testTypeName() {
KNNMethodContext knnMethodContext = getDefaultKNNMethodContext();
KNNVectorFieldType knnVectorFieldType = new KNNVectorFieldType(
FIELD_NAME,
Collections.emptyMap(),
VectorDataType.FLOAT,
getMappingConfigForMethodMapping(knnMethodContext, DIMENSION)
);
assertEquals(KNNVectorFieldMapper.CONTENT_TYPE, knnVectorFieldType.typeName());
}

public void testExistsQuery() {
KNNMethodContext knnMethodContext = getDefaultKNNMethodContext();
KNNVectorFieldType knnVectorFieldType = new KNNVectorFieldType(
FIELD_NAME,
Collections.emptyMap(),
VectorDataType.FLOAT,
getMappingConfigForMethodMapping(knnMethodContext, DIMENSION)
);
QueryShardContext mockContext = mock(QueryShardContext.class);
Query query = knnVectorFieldType.existsQuery(mockContext);
assertTrue(query instanceof FieldExistsQuery);
assertEquals(FIELD_NAME, ((FieldExistsQuery) query).getField());
}

public void testTermQuery_throwsException() {
KNNMethodContext knnMethodContext = getDefaultKNNMethodContext();
KNNVectorFieldType knnVectorFieldType = new KNNVectorFieldType(
FIELD_NAME,
Collections.emptyMap(),
VectorDataType.FLOAT,
getMappingConfigForMethodMapping(knnMethodContext, DIMENSION)
);
QueryShardContext mockContext = mock(QueryShardContext.class);
expectThrows(QueryShardException.class, () -> knnVectorFieldType.termQuery(new float[] { 1.0f, 2.0f, 3.0f }, mockContext));
}

public void testFielddataBuilder() {
KNNMethodContext knnMethodContext = getDefaultKNNMethodContext();
KNNVectorFieldType knnVectorFieldType = new KNNVectorFieldType(
FIELD_NAME,
Collections.emptyMap(),
VectorDataType.FLOAT,
getMappingConfigForMethodMapping(knnMethodContext, DIMENSION)
);
SearchLookup mockSearchLookup = mock(SearchLookup.class);
IndexFieldData.Builder builder = knnVectorFieldType.fielddataBuilder("test-index", () -> mockSearchLookup);
assertTrue(builder instanceof KNNVectorIndexFieldData.Builder);
}

public void testValueForDisplay_whenFloatVector() {
KNNMethodContext knnMethodContext = getDefaultKNNMethodContext();
KNNVectorFieldType knnVectorFieldType = new KNNVectorFieldType(
FIELD_NAME,
Collections.emptyMap(),
VectorDataType.FLOAT,
getMappingConfigForMethodMapping(knnMethodContext, DIMENSION)
);
float[] testVector = new float[] { 1.0f, 2.0f, 3.0f };
BytesRef serializedVector = new BytesRef(new byte[testVector.length * Float.BYTES]);
for (int i = 0; i < testVector.length; i++) {
int bits = Float.floatToIntBits(testVector[i]);
int offset = i * Float.BYTES;
serializedVector.bytes[offset] = (byte) (bits & 0xFF);
serializedVector.bytes[offset + 1] = (byte) ((bits >> 8) & 0xFF);
serializedVector.bytes[offset + 2] = (byte) ((bits >> 16) & 0xFF);
serializedVector.bytes[offset + 3] = (byte) ((bits >> 24) & 0xFF);
}
serializedVector.length = testVector.length * Float.BYTES;
Object result = knnVectorFieldType.valueForDisplay(serializedVector);
assertTrue(result instanceof float[]);
assertEquals(testVector.length, ((float[]) result).length);
}

public void testValueForDisplay_whenByteVector() {
KNNMethodContext knnMethodContext = getDefaultKNNMethodContext();
KNNVectorFieldType knnVectorFieldType = new KNNVectorFieldType(
FIELD_NAME,
Collections.emptyMap(),
VectorDataType.BYTE,
getMappingConfigForMethodMapping(knnMethodContext, DIMENSION)
);
byte[] testVector = new byte[] { 1, 2, 3 };
BytesRef serializedVector = new BytesRef(testVector);
Object result = knnVectorFieldType.valueForDisplay(serializedVector);
assertTrue(result instanceof int[]);
}

public void testResolveRescoreContext_whenUserProvidedContext() {
KNNMethodContext knnMethodContext = getDefaultKNNMethodContext();
KNNVectorFieldType knnVectorFieldType = new KNNVectorFieldType(
FIELD_NAME,
Collections.emptyMap(),
VectorDataType.FLOAT,
getMappingConfigForMethodMapping(knnMethodContext, DIMENSION)
);
RescoreContext userContext = RescoreContext.builder().oversampleFactor(2.5f).userProvided(true).build();
RescoreContext result = knnVectorFieldType.resolveRescoreContext(userContext);
assertSame(userContext, result);
}

public void testResolveRescoreContext_whenNullContext() {
KNNMethodContext knnMethodContext = getDefaultKNNMethodContext();
KNNMappingConfig mappingConfig = new KNNMappingConfig() {
@Override
public int getDimension() {
return DIMENSION;
}

@Override
public Optional<KNNMethodContext> getKnnMethodContext() {
return Optional.of(knnMethodContext);
}

@Override
public CompressionLevel getCompressionLevel() {
return CompressionLevel.x32;
}

@Override
public Mode getMode() {
return Mode.ON_DISK;
}
};
KNNVectorFieldType knnVectorFieldType = new KNNVectorFieldType(
FIELD_NAME,
Collections.emptyMap(),
VectorDataType.FLOAT,
mappingConfig
);
RescoreContext result = knnVectorFieldType.resolveRescoreContext(null);
assertNotNull(result);
}

public void testTransformQueryVector_whenFloatVector() {
KNNMethodContext knnMethodContext = getDefaultKNNMethodContext();
KNNVectorFieldType knnVectorFieldType = new KNNVectorFieldType(
FIELD_NAME,
Collections.emptyMap(),
VectorDataType.FLOAT,
getMappingConfigForMethodMapping(knnMethodContext, DIMENSION)
);
float[] queryVector = new float[] { 3.0f, 4.0f, 0.0f };
knnVectorFieldType.transformQueryVector(queryVector);
assertNotNull(queryVector);
assertEquals(3, queryVector.length);
}

public void testTransformQueryVector_whenByteVector() {
KNNMethodContext knnMethodContext = getDefaultByteKNNMethodContext();
KNNVectorFieldType knnVectorFieldType = new KNNVectorFieldType(
FIELD_NAME,
Collections.emptyMap(),
VectorDataType.BYTE,
getMappingConfigForMethodMapping(knnMethodContext, DIMENSION)
);
float[] queryVector = new float[] { 1.0f, 2.0f, 3.0f };
float[] originalVector = queryVector.clone();
knnVectorFieldType.transformQueryVector(queryVector);
assertArrayEquals(originalVector, queryVector, 0.0001f);
}

public void testTransformQueryVector_whenNoMethodContext_throwsException() {
KNNMappingConfig mappingConfig = getMappingConfigForFlatMapping(DIMENSION);
KNNVectorFieldType knnVectorFieldType = new KNNVectorFieldType(
FIELD_NAME,
Collections.emptyMap(),
VectorDataType.FLOAT,
mappingConfig
);
float[] queryVector = new float[] { 1.0f, 2.0f, 3.0f };
expectThrows(IllegalStateException.class, () -> knnVectorFieldType.transformQueryVector(queryVector));
}
}
Loading