Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.util.BytesRef;
import org.junit.Assert;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.opensearch.Version;
Expand Down Expand Up @@ -193,6 +194,156 @@ public void testTypeParser_build_fromKnnMethodContext() throws IOException {
assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getModelId().isEmpty());
}

public void testTypeParser_withDifferentSpaceTypeCombinations_thenSuccess() throws IOException {
// Check that knnMethodContext takes precedent over both model and legacy
ModelDao modelDao = mock(ModelDao.class);
int mForSetting = 71;
// Setup settings
Settings settings = Settings.builder()
.put(settings(CURRENT).build())
.put(KNNSettings.KNN_ALGO_PARAM_M, mForSetting)
.put(KNN_INDEX, true)
.build();
SpaceType methodSpaceType = SpaceType.COSINESIMIL;
SpaceType topLevelSpaceType = SpaceType.INNER_PRODUCT;
KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao);

// space type provided at top level but not in the method
XContentBuilder xContentBuilder = createXContentForFieldMapping(topLevelSpaceType, null, null, null, TEST_DIMENSION);

KNNVectorFieldMapper.Builder builder = (KNNVectorFieldMapper.Builder) typeParser.parse(
"test-field-name-1",
xContentBuilderToMap(xContentBuilder),
buildParserContext("test", settings)
);

Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath());
KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext);
assertTrue(knnVectorFieldMapper instanceof MethodFieldMapper);
assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().isPresent());
assertEquals(topLevelSpaceType, knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().get().getSpaceType());
assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getModelId().isEmpty());

// not setting any space type
xContentBuilder = createXContentForFieldMapping(null, null, null, null, TEST_DIMENSION);

builder = (KNNVectorFieldMapper.Builder) typeParser.parse(
"test-field-name-1",
xContentBuilderToMap(xContentBuilder),
buildParserContext("test", settings)
);

builderContext = new Mapper.BuilderContext(settings, new ContentPath());
knnVectorFieldMapper = builder.build(builderContext);
assertTrue(knnVectorFieldMapper instanceof MethodFieldMapper);
assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().isPresent());
assertEquals(SpaceType.DEFAULT, knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().get().getSpaceType());
assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getModelId().isEmpty());

// if space types are same
xContentBuilder = createXContentForFieldMapping(topLevelSpaceType, topLevelSpaceType, null, null, TEST_DIMENSION);
builder = (KNNVectorFieldMapper.Builder) typeParser.parse(
"test-field-name-1",
xContentBuilderToMap(xContentBuilder),
buildParserContext("test", settings)
);

builderContext = new Mapper.BuilderContext(settings, new ContentPath());
knnVectorFieldMapper = builder.build(builderContext);
assertTrue(knnVectorFieldMapper instanceof MethodFieldMapper);
assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().isPresent());
assertEquals(topLevelSpaceType, knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().get().getSpaceType());
assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getModelId().isEmpty());

// if space types are not same
xContentBuilder = createXContentForFieldMapping(topLevelSpaceType, methodSpaceType, null, null, TEST_DIMENSION);

XContentBuilder finalXContentBuilder = xContentBuilder;
Assert.assertThrows(
MapperParsingException.class,
() -> typeParser.parse("test-field-name-1", xContentBuilderToMap(finalXContentBuilder), buildParserContext("test", settings))
);

// if space types not provided and field is binary
xContentBuilder = createXContentForFieldMapping(null, null, KNNEngine.FAISS, VectorDataType.BINARY, 8);
builder = (KNNVectorFieldMapper.Builder) typeParser.parse(
"test-field-name-1",
xContentBuilderToMap(xContentBuilder),
buildParserContext("test", settings)
);

builderContext = new Mapper.BuilderContext(settings, new ContentPath());
knnVectorFieldMapper = builder.build(builderContext);
assertTrue(knnVectorFieldMapper instanceof MethodFieldMapper);
assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().isPresent());
assertEquals(
SpaceType.DEFAULT_BINARY,
knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().get().getSpaceType()
);
assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getModelId().isEmpty());

// if space type is provided and legacy mappings is hit
xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE)
.field(DIMENSION_FIELD_NAME, TEST_DIMENSION)
.field(KNNConstants.TOP_LEVEL_PARAMETER_SPACE_TYPE, topLevelSpaceType.getValue())
.endObject();
builder = (KNNVectorFieldMapper.Builder) typeParser.parse(
"test-field-name-1",
xContentBuilderToMap(xContentBuilder),
buildParserContext("test", settings)
);

builderContext = new Mapper.BuilderContext(settings, new ContentPath());
knnVectorFieldMapper = builder.build(builderContext);
assertTrue(knnVectorFieldMapper instanceof MethodFieldMapper);
assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().isPresent());
assertEquals(topLevelSpaceType, knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().get().getSpaceType());
// this check ensures that legacy mapping is hit, as in legacy mapping we pick M from index settings
assertEquals(
mForSetting,
knnVectorFieldMapper.fieldType()
.getKnnMappingConfig()
.getKnnMethodContext()
.get()
.getMethodComponentContext()
.getParameters()
.get(METHOD_PARAMETER_M)
);
assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getModelId().isEmpty());
}

public void testTypeParser_withSpaceTypeAndMode_thenSuccess() throws IOException {
// Check that knnMethodContext takes precedent over both model and legacy
ModelDao modelDao = mock(ModelDao.class);
// Setup settings
Settings settings = Settings.builder().put(settings(CURRENT).build()).put(KNN_INDEX, true).build();

SpaceType topLevelSpaceType = SpaceType.INNER_PRODUCT;
KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao);
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE)
.field(DIMENSION, TEST_DIMENSION)
.field(MODE_PARAMETER, Mode.ON_DISK.getName())
.field(COMPRESSION_LEVEL_PARAMETER, CompressionLevel.x16.getName())
.field(KNNConstants.TOP_LEVEL_PARAMETER_SPACE_TYPE, topLevelSpaceType.getValue())
.endObject();
KNNVectorFieldMapper.Builder builder = (KNNVectorFieldMapper.Builder) typeParser.parse(
"test-field-name-1",
xContentBuilderToMap(xContentBuilder),
buildParserContext("test", settings)
);

Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath());
KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext);
assertTrue(knnVectorFieldMapper instanceof MethodFieldMapper);
assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().isPresent());
assertEquals(topLevelSpaceType, knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().get().getSpaceType());
assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getModelId().isEmpty());
}

public void testBuilder_build_fromModel() {
// Check that modelContext takes precedent over legacy
ModelDao modelDao = mock(ModelDao.class);
Expand Down Expand Up @@ -1445,6 +1596,35 @@ private LuceneFieldMapper.CreateLuceneFieldMapperInput.CreateLuceneFieldMapperIn
.originalKnnMethodContext(getDefaultKNNMethodContext());
}

private XContentBuilder createXContentForFieldMapping(
SpaceType topLevelSpaceType,
SpaceType methodSpaceType,
KNNEngine knnEngine,
VectorDataType vectorDataType,
int dimension
) throws IOException {
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE)
.field(DIMENSION_FIELD_NAME, dimension);

if (topLevelSpaceType != null && topLevelSpaceType != SpaceType.UNDEFINED) {
xContentBuilder.field(KNNConstants.TOP_LEVEL_PARAMETER_SPACE_TYPE, topLevelSpaceType.getValue());
}
if (vectorDataType != null) {
xContentBuilder.field(VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue());
}
xContentBuilder.startObject(KNN_METHOD).field(NAME, METHOD_HNSW);
if (knnEngine != null) {
xContentBuilder.field(KNN_ENGINE, knnEngine.getName());
}
if (methodSpaceType != null && methodSpaceType != SpaceType.UNDEFINED) {
xContentBuilder.field(METHOD_PARAMETER_SPACE_TYPE, methodSpaceType.getValue());
}
xContentBuilder.endObject().endObject();
return xContentBuilder;
}

private static float[] createInitializedFloatArray(int dimension, float value) {
float[] array = new float[dimension];
Arrays.fill(array, value);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.integ;

import lombok.SneakyThrows;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.SpaceType;

import java.io.IOException;

import static org.opensearch.knn.common.KNNConstants.FAISS_NAME;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.KNN_METHOD;
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
import static org.opensearch.knn.common.KNNConstants.NAME;

public class TopLevelSpaceTypeParameterIT extends KNNRestTestCase {
private final static float[] TEST_VECTOR = new float[] { 1.0f, 2.0f };
private final static int DIMENSION = 2;
private final static int K = 1;
private static final String INDEX_NAME = "top-level-space-type-index";

@SneakyThrows
public void testBaseCase() {
createTestIndexWithTopLevelSpaceTypeOnly();
addKnnDoc(INDEX_NAME, "0", FIELD_NAME, TEST_VECTOR);
validateKNNSearch(INDEX_NAME, FIELD_NAME, DIMENSION, 1, K);
deleteIndex(INDEX_NAME);

createTestIndexWithTopLevelSpaceTypeAndMethodSpaceType();
addKnnDoc(INDEX_NAME, "0", FIELD_NAME, TEST_VECTOR);
validateKNNSearch(INDEX_NAME, FIELD_NAME, DIMENSION, 1, K);
deleteIndex(INDEX_NAME);

createTestIndexWithNoSpaceType();
addKnnDoc(INDEX_NAME, "0", FIELD_NAME, TEST_VECTOR);
validateKNNSearch(INDEX_NAME, FIELD_NAME, DIMENSION, 1, K);
deleteIndex(INDEX_NAME);
}

private void createTestIndexWithTopLevelSpaceTypeOnly() throws IOException {
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.startObject("properties")
.startObject(FIELD_NAME)
.field("type", "knn_vector")
.field("dimension", DIMENSION)
.field(KNNConstants.TOP_LEVEL_PARAMETER_SPACE_TYPE, SpaceType.INNER_PRODUCT.getValue())
.startObject(KNN_METHOD)
.field(NAME, METHOD_HNSW)
.field(KNN_ENGINE, FAISS_NAME)
.endObject()
.endObject()
.endObject()
.endObject();

String mapping = builder.toString();
createKnnIndex(INDEX_NAME, mapping);
}

private void createTestIndexWithTopLevelSpaceTypeAndMethodSpaceType() throws IOException {
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.startObject("properties")
.startObject(FIELD_NAME)
.field("type", "knn_vector")
.field("dimension", DIMENSION)
.field(KNNConstants.TOP_LEVEL_PARAMETER_SPACE_TYPE, SpaceType.INNER_PRODUCT.getValue())
.startObject(KNN_METHOD)
.field(NAME, METHOD_HNSW)
.field(KNN_ENGINE, FAISS_NAME)
.field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, SpaceType.INNER_PRODUCT.getValue())
.endObject()
.endObject()
.endObject()
.endObject();

String mapping = builder.toString();
createKnnIndex(INDEX_NAME, mapping);
}

private void createTestIndexWithNoSpaceType() throws IOException {
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.startObject("properties")
.startObject(FIELD_NAME)
.field("type", "knn_vector")
.field("dimension", DIMENSION)
.startObject(KNN_METHOD)
.field(NAME, METHOD_HNSW)
.field(KNN_ENGINE, FAISS_NAME)
.endObject()
.endObject()
.endObject()
.endObject();

String mapping = builder.toString();
createKnnIndex(INDEX_NAME, mapping);
}
}