diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index 4c1dc6e645..b9319a434f 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -71,6 +71,7 @@ public class KNNConstants { public static final String VECTOR_DATA_TYPE_FIELD = "data_type"; public static final String MODEL_VECTOR_DATA_TYPE_KEY = VECTOR_DATA_TYPE_FIELD; public static final VectorDataType DEFAULT_VECTOR_DATA_TYPE_FIELD = VectorDataType.FLOAT; + public static final String MINIMAL_MODE_AND_COMPRESSION_FEATURE = "mode_and_compression_feature"; public static final String RADIAL_SEARCH_KEY = "radial_search"; @@ -149,4 +150,7 @@ public class KNNConstants { public static final Float DEFAULT_LUCENE_RADIAL_SEARCH_TRAVERSAL_SIMILARITY_RATIO = 0.95f; public static final String MIN_SCORE = "min_score"; public static final String MAX_DISTANCE = "max_distance"; + + public static final String MODE_PARAMETER = "mode"; + public static final String COMPRESSION_LEVEL_PARAMETER = "compression_level"; } diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNMethodConfigContext.java b/src/main/java/org/opensearch/knn/index/engine/KNNMethodConfigContext.java index 731085f0ba..1ba2777dd7 100644 --- a/src/main/java/org/opensearch/knn/index/engine/KNNMethodConfigContext.java +++ b/src/main/java/org/opensearch/knn/index/engine/KNNMethodConfigContext.java @@ -7,10 +7,9 @@ import lombok.AllArgsConstructor; import lombok.Builder; +import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.Setter; -import org.apache.commons.lang.builder.EqualsBuilder; -import org.apache.commons.lang.builder.HashCodeBuilder; import org.opensearch.Version; import org.opensearch.knn.index.VectorDataType; @@ -23,29 +22,10 @@ @Getter @Builder @AllArgsConstructor +@EqualsAndHashCode public final class KNNMethodConfigContext { private VectorDataType vectorDataType; private 
Integer dimension; private Version versionCreated; - - @Override - public boolean equals(Object obj) { - if (this == obj) return true; - if (obj == null || getClass() != obj.getClass()) return false; - KNNMethodConfigContext other = (KNNMethodConfigContext) obj; - - EqualsBuilder equalsBuilder = new EqualsBuilder(); - equalsBuilder.append(vectorDataType, other.vectorDataType); - equalsBuilder.append(dimension, other.dimension); - equalsBuilder.append(versionCreated, other.versionCreated); - - return equalsBuilder.isEquals(); - } - - @Override - public int hashCode() { - return new HashCodeBuilder().append(vectorDataType).append(dimension).append(versionCreated).toHashCode(); - } - public static final KNNMethodConfigContext EMPTY = KNNMethodConfigContext.builder().build(); } diff --git a/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java b/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java new file mode 100644 index 0000000000..b5ce81af98 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java @@ -0,0 +1,89 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import org.opensearch.core.common.Strings; + +import java.util.Arrays; +import java.util.Locale; +import java.util.stream.Collectors; + +/** + * Enum representing the compression level for float vectors. Compression in this sense refers to compressing a + * full precision value into a smaller number of bits. For instance, "16x" compression would mean that 2 bits would + * need to be used to represent a 32-bit floating point number. + */ +@AllArgsConstructor +public enum CompressionLevel { + NOT_CONFIGURED(-1, ""), + x1(1, "1x"), + x2(2, "2x"), + x4(4, "4x"), + x8(8, "8x"), + x16(16, "16x"), + x32(32, "32x"); + + // Internally, an empty string is easier to deal with than null. 
However, from the mapping, + // we do not want users to pass in the empty string and instead want null. So we make the conversion here + static final String[] NAMES_ARRAY = Arrays.stream(CompressionLevel.values()) + .map(compressionLevel -> compressionLevel == NOT_CONFIGURED ? null : compressionLevel.getName()) + .collect(Collectors.toList()) + .toArray(new String[0]); + + /** + * Default is set to 1x and is a noop + */ + private static final CompressionLevel DEFAULT = x1; + + /** + * Get the compression level from a string representation. The format for the string should be "Nx", where N is + * the factor by which compression should take place + * + * @param name String representation of the compression level + * @return CompressionLevel enum value + */ + public static CompressionLevel fromName(String name) { + if (Strings.isEmpty(name)) { + return NOT_CONFIGURED; + } + for (CompressionLevel config : CompressionLevel.values()) { + if (config.getName() != null && config.getName().equals(name)) { + return config; + } + } + throw new IllegalArgumentException(String.format(Locale.ROOT, "Invalid compression level: \"[%s]\"", name)); + } + + private final int compressionLevel; + @Getter + private final String name; + + /** + * Gets the number of bits used to represent a float in order to achieve this compression. For instance, for + * 32x compression, each float would need to be encoded in a single bit. + * + * @return number of bits to represent a float at this compression level + */ + public int numBitsForFloat32() { + if (this == NOT_CONFIGURED) { + return DEFAULT.numBitsForFloat32(); + } + + return (Float.BYTES * Byte.SIZE) / compressionLevel; + } + + /** + * Utility method that checks if compression is configured. 
+ * + * @param compressionLevel Compression to check + * @return true if compression is configured, false otherwise + */ + public static boolean isConfigured(CompressionLevel compressionLevel) { + return compressionLevel != null && compressionLevel != NOT_CONFIGURED; + } +} diff --git a/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java index d37ab9b869..8da41aa599 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java @@ -31,7 +31,8 @@ public static FlatVectorFieldMapper createFieldMapper( CopyTo copyTo, Explicit ignoreMalformed, boolean stored, - boolean hasDocValues + boolean hasDocValues, + OriginalMappingParameters originalMappingParameters ) { final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType( fullname, @@ -47,7 +48,8 @@ public static FlatVectorFieldMapper createFieldMapper( ignoreMalformed, stored, hasDocValues, - knnMethodConfigContext.getVersionCreated() + knnMethodConfigContext.getVersionCreated(), + originalMappingParameters ); } @@ -59,9 +61,20 @@ private FlatVectorFieldMapper( Explicit ignoreMalformed, boolean stored, boolean hasDocValues, - Version indexCreatedVersion + Version indexCreatedVersion, + OriginalMappingParameters originalMappingParameters ) { - super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues, indexCreatedVersion, null); + super( + simpleName, + mappedFieldType, + multiFields, + copyTo, + ignoreMalformed, + stored, + hasDocValues, + indexCreatedVersion, + originalMappingParameters + ); // setting it explicitly false here to ensure that when flatmapper is used Lucene based Vector field is not created. 
this.useLuceneBasedVectorField = false; this.perDimensionValidator = selectPerDimensionValidator(vectorDataType); diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 65c3cfb660..0eab5a7bb4 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -104,13 +104,7 @@ public static class Builder extends ParametrizedFieldMapper.Builder { } return value; }, - m -> { - KNNMappingConfig knnMappingConfig = toType(m).fieldType().getKnnMappingConfig(); - if (knnMappingConfig.getModelId().isPresent()) { - return UNSET_MODEL_DIMENSION_IDENTIFIER; - } - return knnMappingConfig.getDimension(); - } + m -> toType(m).originalMappingParameters.getDimension() ); /** @@ -122,7 +116,7 @@ public static class Builder extends ParametrizedFieldMapper.Builder { false, () -> DEFAULT_VECTOR_DATA_TYPE_FIELD, (n, c, o) -> VectorDataType.get((String) o), - m -> toType(m).vectorDataType + m -> toType(m).originalMappingParameters.getVectorDataType() ); /** @@ -133,7 +127,7 @@ public static class Builder extends ParametrizedFieldMapper.Builder { protected final Parameter modelId = Parameter.stringParam( KNNConstants.MODEL_ID, false, - m -> toType(m).fieldType().getKnnMappingConfig().getModelId().orElse(null), + m -> toType(m).originalMappingParameters.getModelId(), null ); @@ -146,7 +140,7 @@ public static class Builder extends ParametrizedFieldMapper.Builder { false, () -> null, (n, c, o) -> KNNMethodContext.parse(o), - m -> toType(m).originalKNNMethodContext + m -> toType(m).originalMappingParameters.getKnnMethodContext() ).setSerializer(((b, n, v) -> { b.startObject(n); v.toXContent(b, ToXContent.EMPTY_PARAMS); @@ -162,48 +156,47 @@ public static class Builder extends ParametrizedFieldMapper.Builder { } }); + protected final Parameter mode = Parameter.restrictedStringParam( + 
KNNConstants.MODE_PARAMETER, + false, + m -> toType(m).originalMappingParameters.getMode(), + Mode.NAMES_ARRAY + ).acceptsNull(); + + protected final Parameter compressionLevel = Parameter.restrictedStringParam( + KNNConstants.COMPRESSION_LEVEL_PARAMETER, + false, + m -> toType(m).originalMappingParameters.getCompressionLevel(), + CompressionLevel.NAMES_ARRAY + ).acceptsNull(); + protected final Parameter> meta = Parameter.metaParam(); protected ModelDao modelDao; protected Version indexCreatedVersion; - // KNNMethodContext that allows us to properly configure a KNNVectorFieldMapper from another - // KNNVectorFieldMapper. To support our legacy field mapping, on parsing, if index.knn=true and no method is - // passed, we build a KNNMethodContext using the space type, ef_construction and m that are set in the index - // settings. However, for fieldmappers for merging, we need to be able to initialize one field mapper from - // another (see - // https://github.com/opensearch-project/OpenSearch/blob/2.16.0/server/src/main/java/org/opensearch/index/mapper/ParametrizedFieldMapper.java#L98). - // The problem is that in this case, the settings are set to empty so we cannot properly resolve the KNNMethodContext. - // (see - // https://github.com/opensearch-project/OpenSearch/blob/2.16.0/server/src/main/java/org/opensearch/index/mapper/ParametrizedFieldMapper.java#L130). - // While we could override the KNNMethodContext parameter initializer to set the knnMethodContext based on the - // constructed KNNMethodContext from the other field mapper, this can result in merge conflict/serialization - // exceptions. See - // (https://github.com/opensearch-project/OpenSearch/blob/2.16.0/server/src/main/java/org/opensearch/index/mapper/ParametrizedFieldMapper.java#L322-L324). - // So, what we do is pass in a "resolvedKNNMethodContext" that will either be null or be set via the merge builder - // constructor. 
A similar approach was taken for https://github.com/opendistro-for-elasticsearch/k-NN/issues/288 - @Setter - @Getter - private KNNMethodContext resolvedKNNMethodContext; @Setter private KNNMethodConfigContext knnMethodConfigContext; + @Setter + @Getter + private OriginalMappingParameters originalParameters; public Builder( String name, ModelDao modelDao, Version indexCreatedVersion, - KNNMethodContext resolvedKNNMethodContext, - KNNMethodConfigContext knnMethodConfigContext + KNNMethodConfigContext knnMethodConfigContext, + OriginalMappingParameters originalParameters ) { super(name); this.modelDao = modelDao; this.indexCreatedVersion = indexCreatedVersion; - this.resolvedKNNMethodContext = resolvedKNNMethodContext; this.knnMethodConfigContext = knnMethodConfigContext; + this.originalParameters = originalParameters; } @Override protected List> getParameters() { - return Arrays.asList(stored, hasDocValues, dimension, vectorDataType, meta, knnMethodContext, modelId); + return Arrays.asList(stored, hasDocValues, dimension, vectorDataType, meta, knnMethodContext, modelId, mode, compressionLevel); } protected Explicit ignoreMalformed(BuilderContext context) { @@ -231,18 +224,18 @@ public KNNVectorFieldMapper build(BuilderContext context) { name, metaValue, vectorDataType.getValue(), - modelId.get(), multiFieldsBuilder, copyToBuilder, ignoreMalformed, stored.get(), hasDocValues.get(), modelDao, - indexCreatedVersion + indexCreatedVersion, + originalParameters ); } - if (resolvedKNNMethodContext == null) { + if (originalParameters.getResolvedKnnMethodContext() == null) { return FlatVectorFieldMapper.createFieldMapper( buildFullName(context), name, @@ -256,11 +249,12 @@ public KNNVectorFieldMapper build(BuilderContext context) { copyToBuilder, ignoreMalformed, stored.get(), - hasDocValues.get() + hasDocValues.get(), + originalParameters ); } - if (resolvedKNNMethodContext.getKnnEngine() == KNNEngine.LUCENE) { + if 
(originalParameters.getResolvedKnnMethodContext().getKnnEngine() == KNNEngine.LUCENE) { log.debug(String.format(Locale.ROOT, "Use [LuceneFieldMapper] mapper for field [%s]", name)); LuceneFieldMapper.CreateLuceneFieldMapperInput createLuceneFieldMapperInput = LuceneFieldMapper.CreateLuceneFieldMapperInput .builder() @@ -275,9 +269,9 @@ public KNNVectorFieldMapper build(BuilderContext context) { return LuceneFieldMapper.createFieldMapper( buildFullName(context), metaValue, - resolvedKNNMethodContext, knnMethodConfigContext, - createLuceneFieldMapperInput + createLuceneFieldMapperInput, + originalParameters ); } @@ -285,14 +279,13 @@ public KNNVectorFieldMapper build(BuilderContext context) { buildFullName(context), name, metaValue, - resolvedKNNMethodContext, knnMethodConfigContext, - knnMethodContext.get(), multiFieldsBuilder, copyToBuilder, ignoreMalformed, stored.getValue(), - hasDocValues.getValue() + hasDocValues.getValue(), + originalParameters ); } @@ -343,6 +336,7 @@ public Mapper.Builder parse(String name, Map node, ParserCont null ); builder.parse(name, parserContext, node); + builder.setOriginalParameters(new OriginalMappingParameters(builder)); // All parsing @@ -372,6 +366,7 @@ private void validateFromFlat(KNNVectorFieldMapper.Builder builder) { throw new IllegalArgumentException("Cannot set modelId or method parameters when index.knn setting is false"); } validateDimensionSet(builder); + validateCompressionAndModeNotSet(builder, builder.name(), "flat"); } private void validateFromModel(KNNVectorFieldMapper.Builder builder) { @@ -379,11 +374,13 @@ private void validateFromModel(KNNVectorFieldMapper.Builder builder) { if (builder.dimension.getValue() == UNSET_MODEL_DIMENSION_IDENTIFIER && builder.modelId.get() == null) { throw new IllegalArgumentException(String.format(Locale.ROOT, "Dimension value missing for vector: %s", builder.name())); } + validateCompressionAndModeNotSet(builder, builder.name(), "model"); } private void 
validateFromKNNMethod(KNNVectorFieldMapper.Builder builder) { - if (builder.resolvedKNNMethodContext != null) { - ValidationException validationException = builder.resolvedKNNMethodContext.validate(builder.knnMethodConfigContext); + if (builder.originalParameters.getResolvedKnnMethodContext() != null) { + ValidationException validationException = builder.originalParameters.getResolvedKnnMethodContext() + .validate(builder.knnMethodConfigContext); if (validationException != null) { throw validationException; } @@ -397,6 +394,19 @@ private void validateDimensionSet(KNNVectorFieldMapper.Builder builder) { } } + private void validateCompressionAndModeNotSet(KNNVectorFieldMapper.Builder builder, String name, String context) { + if (builder.mode.isConfigured() || builder.compressionLevel.isConfigured()) { + throw new MapperParsingException( + String.format( + Locale.ROOT, + "Compression and mode can not be specified in a %s mapping configuration for field: %s", + context, + name + ) + ); + } + } + private void resolveKNNMethodComponents(KNNVectorFieldMapper.Builder builder, ParserContext parserContext) { builder.setKnnMethodConfigContext( KNNMethodConfigContext.builder() @@ -407,13 +417,12 @@ private void resolveKNNMethodComponents(KNNVectorFieldMapper.Builder builder, Pa ); // Configure method from map or legacy - builder.setResolvedKNNMethodContext( - builder.knnMethodContext.getValue() != null - ? 
builder.knnMethodContext.getValue() - : createKNNMethodContextFromLegacy(parserContext.getSettings(), parserContext.indexVersionCreated()) - ); - // TODO: We should remove this and set it based on the KNNMethodContext - setDefaultSpaceType(builder.resolvedKNNMethodContext, builder.vectorDataType.getValue()); + if (builder.originalParameters.isLegacyMapping()) { + builder.originalParameters.setResolvedKnnMethodContext( + createKNNMethodContextFromLegacy(parserContext.getSettings(), parserContext.indexVersionCreated()) + ); + } + setDefaultSpaceType(builder.originalParameters.getResolvedKnnMethodContext(), builder.vectorDataType.getValue()); } private boolean isKNNDisabled(Settings settings) { @@ -449,7 +458,7 @@ private void setDefaultSpaceType(final KNNMethodContext knnMethodContext, final // We need to ensure that the original KNNMethodContext as parsed is stored to initialize the // Builder for serialization. So, we need to store it here. This is mainly to ensure that the legacy field mapper // can use KNNMethodContext without messing up serialization on mapper merge - protected KNNMethodContext originalKNNMethodContext; + protected OriginalMappingParameters originalMappingParameters; public KNNVectorFieldMapper( String simpleName, @@ -460,7 +469,7 @@ public KNNVectorFieldMapper( boolean stored, boolean hasDocValues, Version indexCreatedVersion, - KNNMethodContext originalKNNMethodContext + OriginalMappingParameters originalMappingParameters ) { super(simpleName, mappedFieldType, multiFields, copyTo); this.ignoreMalformed = ignoreMalformed; @@ -469,7 +478,7 @@ public KNNVectorFieldMapper( this.vectorDataType = mappedFieldType.getVectorDataType(); updateEngineStats(); this.indexCreatedVersion = indexCreatedVersion; - this.originalKNNMethodContext = originalKNNMethodContext; + this.originalMappingParameters = originalMappingParameters; } public KNNVectorFieldMapper clone() { @@ -680,8 +689,8 @@ public ParametrizedFieldMapper.Builder getMergeBuilder() { 
simpleName(), modelDao, indexCreatedVersion, - fieldType().getKnnMappingConfig().getKnnMethodContext().orElse(null), - knnMethodConfigContext + knnMethodConfigContext, + originalMappingParameters ).init(this); } diff --git a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java index 744ba4bd53..3da2745acd 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java @@ -45,9 +45,9 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper { static LuceneFieldMapper createFieldMapper( String fullname, Map metaValue, - KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext, - CreateLuceneFieldMapperInput createLuceneFieldMapperInput + CreateLuceneFieldMapperInput createLuceneFieldMapperInput, + OriginalMappingParameters originalMappingParameters ) { final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType( fullname, @@ -56,7 +56,7 @@ static LuceneFieldMapper createFieldMapper( new KNNMappingConfig() { @Override public Optional getKnnMethodContext() { - return Optional.of(knnMethodContext); + return Optional.of(originalMappingParameters.getResolvedKnnMethodContext()); } @Override @@ -66,13 +66,14 @@ public int getDimension() { } ); - return new LuceneFieldMapper(mappedFieldType, createLuceneFieldMapperInput, knnMethodConfigContext); + return new LuceneFieldMapper(mappedFieldType, createLuceneFieldMapperInput, knnMethodConfigContext, originalMappingParameters); } private LuceneFieldMapper( final KNNVectorFieldType mappedFieldType, final CreateLuceneFieldMapperInput input, - KNNMethodConfigContext knnMethodConfigContext + KNNMethodConfigContext knnMethodConfigContext, + OriginalMappingParameters originalMappingParameters ) { super( input.getName(), @@ -83,7 +84,7 @@ private LuceneFieldMapper( input.isStored(), input.isHasDocValues(), 
knnMethodConfigContext.getVersionCreated(), - mappedFieldType.knnMappingConfig.getKnnMethodContext().orElse(null) + originalMappingParameters ); KNNMappingConfig knnMappingConfig = mappedFieldType.getKnnMappingConfig(); KNNMethodContext knnMethodContext = knnMappingConfig.getKnnMethodContext() diff --git a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java index 90d4ca879f..f1a87c64bd 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java @@ -43,14 +43,13 @@ public static MethodFieldMapper createFieldMapper( String fullname, String simpleName, Map metaValue, - KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext, - KNNMethodContext originalKNNMethodContext, MultiFields multiFields, CopyTo copyTo, Explicit ignoreMalformed, boolean stored, - boolean hasDocValues + boolean hasDocValues, + OriginalMappingParameters originalMappingParameters ) { final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType( fullname, @@ -59,7 +58,7 @@ public static MethodFieldMapper createFieldMapper( new KNNMappingConfig() { @Override public Optional getKnnMethodContext() { - return Optional.of(knnMethodContext); + return Optional.of(originalMappingParameters.getResolvedKnnMethodContext()); } @Override @@ -76,8 +75,8 @@ public int getDimension() { ignoreMalformed, stored, hasDocValues, - originalKNNMethodContext, - knnMethodConfigContext + knnMethodConfigContext, + originalMappingParameters ); } @@ -89,8 +88,8 @@ private MethodFieldMapper( Explicit ignoreMalformed, boolean stored, boolean hasDocValues, - KNNMethodContext originalKNNMethodContext, - KNNMethodConfigContext knnMethodConfigContext + KNNMethodConfigContext knnMethodConfigContext, + OriginalMappingParameters originalMappingParameters ) { super( @@ -102,7 +101,7 @@ private MethodFieldMapper( stored, 
hasDocValues, knnMethodConfigContext.getVersionCreated(), - originalKNNMethodContext + originalMappingParameters ); this.useLuceneBasedVectorField = KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(indexCreatedVersion); KNNMappingConfig annConfig = mappedFieldType.getKnnMappingConfig(); diff --git a/src/main/java/org/opensearch/knn/index/mapper/Mode.java b/src/main/java/org/opensearch/knn/index/mapper/Mode.java new file mode 100644 index 0000000000..0798ab9419 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/mapper/Mode.java @@ -0,0 +1,68 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import org.opensearch.core.common.Strings; + +import java.util.Arrays; +import java.util.Locale; +import java.util.stream.Collectors; + +/** + * Enum representing the intended workload optimization a user wants their k-NN system to have. Based on this value, + * default parameter resolution will be determined. + */ +@Getter +@AllArgsConstructor +public enum Mode { + NOT_CONFIGURED(""), + IN_MEMORY("in_memory"), + ON_DISK("on_disk"); + + // Internally, an empty string is easier to deal with than null. However, from the mapping, + // we do not want users to pass in the empty string and instead want null. So we make the conversion here + static final String[] NAMES_ARRAY = Arrays.stream(Mode.values()) + .map(mode -> mode == NOT_CONFIGURED ? 
null : mode.getName()) + .collect(Collectors.toList()) + .toArray(new String[0]); + + private static final Mode DEFAULT = IN_MEMORY; + + /** + * Convert a string to a Mode enum value + * + * @param name String value to convert + * @return Mode enum value + */ + public static Mode fromName(String name) { + if (Strings.isEmpty(name)) { + return NOT_CONFIGURED; + } + + if (IN_MEMORY.name.equalsIgnoreCase(name)) { + return IN_MEMORY; + } + + if (ON_DISK.name.equalsIgnoreCase(name)) { + return ON_DISK; + } + throw new IllegalArgumentException(String.format(Locale.ROOT, "Invalid mode: \"[%s]\"", name)); + } + + private final String name; + + /** + * Utility method that checks if mode is configured. + * + * @param mode Mode to check + * @return true if mode is configured, false otherwise + */ + public static boolean isConfigured(Mode mode) { + return mode != null && mode != NOT_CONFIGURED; + } +} diff --git a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java index b29466eefc..bfb188a754 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java @@ -49,25 +49,25 @@ public static ModelFieldMapper createFieldMapper( String simpleName, Map metaValue, VectorDataType vectorDataType, - String modelId, MultiFields multiFields, CopyTo copyTo, Explicit ignoreMalformed, boolean stored, boolean hasDocValues, ModelDao modelDao, - Version indexCreatedVersion + Version indexCreatedVersion, + OriginalMappingParameters originalMappingParameters ) { final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType(fullname, metaValue, vectorDataType, new KNNMappingConfig() { @Override public Optional getModelId() { - return Optional.of(modelId); + return Optional.of(originalMappingParameters.getModelId()); } @Override public int getDimension() { - return getModelMetadata(modelDao, modelId).getDimension(); + 
return getModelMetadata(modelDao, originalMappingParameters.getModelId()).getDimension(); } }); return new ModelFieldMapper( @@ -79,7 +79,8 @@ public int getDimension() { stored, hasDocValues, modelDao, - indexCreatedVersion + indexCreatedVersion, + originalMappingParameters ); } @@ -92,9 +93,20 @@ private ModelFieldMapper( boolean stored, boolean hasDocValues, ModelDao modelDao, - Version indexCreatedVersion + Version indexCreatedVersion, + OriginalMappingParameters originalMappingParameters ) { - super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues, indexCreatedVersion, null); + super( + simpleName, + mappedFieldType, + multiFields, + copyTo, + ignoreMalformed, + stored, + hasDocValues, + indexCreatedVersion, + originalMappingParameters + ); KNNMappingConfig annConfig = mappedFieldType.getKnnMappingConfig(); modelId = annConfig.getModelId().orElseThrow(() -> new IllegalArgumentException("KNN method context cannot be empty")); this.modelDao = modelDao; diff --git a/src/main/java/org/opensearch/knn/index/mapper/OriginalMappingParameters.java b/src/main/java/org/opensearch/knn/index/mapper/OriginalMappingParameters.java new file mode 100644 index 0000000000..f7235620f6 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/mapper/OriginalMappingParameters.java @@ -0,0 +1,79 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.Setter; +import org.opensearch.core.common.Strings; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.KNNMethodContext; + +/** + * Utility class to store the original mapping parameters for a KNNVectorFieldMapper. 
These parameters need to be + * kept around for when a {@link KNNVectorFieldMapper} is built from merge + */ +@Getter +@RequiredArgsConstructor +public final class OriginalMappingParameters { + private final VectorDataType vectorDataType; + private final int dimension; + private final KNNMethodContext knnMethodContext; + + // To support our legacy field mapping, on parsing, if index.knn=true and no method is + // passed, we build a KNNMethodContext using the space type, ef_construction and m that are set in the index + // settings. However, for fieldmappers for merging, we need to be able to initialize one field mapper from + // another (see + // https://github.com/opensearch-project/OpenSearch/blob/2.16.0/server/src/main/java/org/opensearch/index/mapper/ParametrizedFieldMapper.java#L98). + // The problem is that in this case, the settings are set to empty so we cannot properly resolve the KNNMethodContext. + // (see + // https://github.com/opensearch-project/OpenSearch/blob/2.16.0/server/src/main/java/org/opensearch/index/mapper/ParametrizedFieldMapper.java#L130). + // While we could override the KNNMethodContext parameter initializer to set the knnMethodContext based on the + // constructed KNNMethodContext from the other field mapper, this can result in merge conflict/serialization + // exceptions. See + // (https://github.com/opensearch-project/OpenSearch/blob/2.16.0/server/src/main/java/org/opensearch/index/mapper/ParametrizedFieldMapper.java#L322-L324). + // So, what we do is store a "resolvedKnnMethodContext" to ensure we track the resolved KNNMethodContext. 
+ // A similar approach was taken for https://github.com/opendistro-for-elasticsearch/k-NN/issues/288 + @Setter + private KNNMethodContext resolvedKnnMethodContext; + private final String mode; + private final String compressionLevel; + private final String modelId; + + /** + * Initialize the parameters from the builder + * + * @param builder The builder to initialize from + */ + public OriginalMappingParameters(KNNVectorFieldMapper.Builder builder) { + this.vectorDataType = builder.vectorDataType.get(); + this.knnMethodContext = builder.knnMethodContext.get(); + this.resolvedKnnMethodContext = builder.knnMethodContext.get(); + this.dimension = builder.dimension.get(); + this.mode = builder.mode.get(); + this.compressionLevel = builder.compressionLevel.get(); + this.modelId = builder.modelId.get(); + } + + /** + * Determine if the mapping used the legacy mechanism to set up the index. The legacy mechanism is used if + * the index is created only by specifying the dimension. If this is the case, the constructed parameters + * need to be collected from the index settings + * + * @return true if the mapping used the legacy mechanism, false otherwise + */ + public boolean isLegacyMapping() { + if (knnMethodContext != null) { + return false; + } + + if (modelId != null) { + return false; + } + + return Strings.isEmpty(mode) && Strings.isEmpty(compressionLevel); + } +} diff --git a/src/main/java/org/opensearch/knn/index/util/IndexUtil.java b/src/main/java/org/opensearch/knn/index/util/IndexUtil.java index f0d57ddeb7..b548a9fd79 100644 --- a/src/main/java/org/opensearch/knn/index/util/IndexUtil.java +++ b/src/main/java/org/opensearch/knn/index/util/IndexUtil.java @@ -52,6 +52,7 @@ public class IndexUtil { private static final Version MINIMAL_SUPPORTED_VERSION_FOR_METHOD_PARAMETERS = Version.V_2_16_0; private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_VECTOR_DATA_TYPE = Version.V_2_16_0; private static final Version MINIMAL_RESCORE_FEATURE = Version.V_2_17_0; 
+ private static final Version MINIMAL_MODE_AND_COMPRESSION_FEATURE = Version.V_2_17_0; // public so neural search can access it public static final Map minimalRequiredVersionMap = initializeMinimalRequiredVersionMap(); public static final Set VECTOR_DATA_TYPES_NOT_SUPPORTING_ENCODERS = Set.of(VectorDataType.BINARY, VectorDataType.BYTE); @@ -390,6 +391,7 @@ private static Map initializeMinimalRequiredVersionMap() { put(KNNConstants.METHOD_PARAMETER, MINIMAL_SUPPORTED_VERSION_FOR_METHOD_PARAMETERS); put(KNNConstants.MODEL_VECTOR_DATA_TYPE_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_VECTOR_DATA_TYPE); put(RESCORE_PARAMETER, MINIMAL_RESCORE_FEATURE); + put(KNNConstants.MINIMAL_MODE_AND_COMPRESSION_FEATURE, MINIMAL_MODE_AND_COMPRESSION_FEATURE); } }; diff --git a/src/main/java/org/opensearch/knn/indices/ModelDao.java b/src/main/java/org/opensearch/knn/indices/ModelDao.java index e955966990..326d595a47 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelDao.java +++ b/src/main/java/org/opensearch/knn/indices/ModelDao.java @@ -51,6 +51,8 @@ import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.common.exception.DeleteModelException; import org.opensearch.knn.index.engine.MethodComponentContext; +import org.opensearch.knn.index.mapper.CompressionLevel; +import org.opensearch.knn.index.mapper.Mode; import org.opensearch.knn.plugin.transport.DeleteModelResponse; import org.opensearch.knn.plugin.transport.GetModelResponse; import org.opensearch.knn.plugin.transport.RemoveModelFromCacheAction; @@ -293,7 +295,12 @@ private void putInternal(Model model, ActionListener listener, Do put(KNNConstants.MODEL_ERROR, modelMetadata.getError()); put(KNNConstants.MODEL_NODE_ASSIGNMENT, modelMetadata.getNodeAssignment()); put(KNNConstants.VECTOR_DATA_TYPE_FIELD, modelMetadata.getVectorDataType()); - + if (Mode.isConfigured(modelMetadata.getMode())) { + put(KNNConstants.MODE_PARAMETER, modelMetadata.getMode().getName()); + } + if 
(CompressionLevel.isConfigured(modelMetadata.getCompressionLevel())) { + put(KNNConstants.COMPRESSION_LEVEL_PARAMETER, modelMetadata.getCompressionLevel().getName()); + } MethodComponentContext methodComponentContext = modelMetadata.getMethodComponentContext(); if (!methodComponentContext.getName().isEmpty()) { XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); diff --git a/src/main/java/org/opensearch/knn/indices/ModelMetadata.java b/src/main/java/org/opensearch/knn/indices/ModelMetadata.java index 60301e244a..620e520ba8 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelMetadata.java +++ b/src/main/java/org/opensearch/knn/indices/ModelMetadata.java @@ -11,6 +11,7 @@ package org.opensearch.knn.indices; +import lombok.Getter; import lombok.extern.log4j.Log4j2; import org.apache.commons.lang.builder.EqualsBuilder; import org.apache.commons.lang.builder.HashCodeBuilder; @@ -23,6 +24,8 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.mapper.CompressionLevel; +import org.opensearch.knn.index.mapper.Mode; import org.opensearch.knn.index.util.IndexUtil; import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.SpaceType; @@ -51,7 +54,11 @@ public class ModelMetadata implements Writeable, ToXContentObject { final private String trainingNodeAssignment; final private VectorDataType vectorDataType; private MethodComponentContext methodComponentContext; + @Getter + private final Mode mode; private String error; + @Getter + private final CompressionLevel compressionLevel; /** * Constructor @@ -89,6 +96,15 @@ public ModelMetadata(StreamInput in) throws IOException { } else { this.vectorDataType = VectorDataType.DEFAULT; } + + if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(in.getVersion(), KNNConstants.MINIMAL_MODE_AND_COMPRESSION_FEATURE)) { + this.mode = 
Mode.fromName(in.readOptionalString()); + this.compressionLevel = CompressionLevel.fromName(in.readOptionalString()); + } else { + this.mode = Mode.NOT_CONFIGURED; + this.compressionLevel = CompressionLevel.NOT_CONFIGURED; + } + } /** @@ -115,7 +131,9 @@ public ModelMetadata( String error, String trainingNodeAssignment, MethodComponentContext methodComponentContext, - VectorDataType vectorDataType + VectorDataType vectorDataType, + Mode mode, + CompressionLevel compressionLevel ) { this.knnEngine = Objects.requireNonNull(knnEngine, "knnEngine must not be null"); this.spaceType = Objects.requireNonNull(spaceType, "spaceType must not be null"); @@ -139,6 +157,8 @@ public ModelMetadata( this.trainingNodeAssignment = Objects.requireNonNull(trainingNodeAssignment, "node assignment must not be null"); this.methodComponentContext = Objects.requireNonNull(methodComponentContext, "method context must not be null"); this.vectorDataType = Objects.requireNonNull(vectorDataType, "vector data type must not be null"); + this.mode = Objects.requireNonNull(mode, "Mode must not be null"); + this.compressionLevel = Objects.requireNonNull(compressionLevel, "Compression level must not be null"); } /** @@ -257,7 +277,9 @@ public String toString() { error, trainingNodeAssignment, methodComponentContext.toClusterStateString(), - vectorDataType.getValue() + vectorDataType.getValue(), + mode.getName(), + compressionLevel.getName() ); } @@ -276,6 +298,8 @@ public boolean equals(Object obj) { equalsBuilder.append(getDescription(), other.getDescription()); equalsBuilder.append(getError(), other.getError()); equalsBuilder.append(getVectorDataType(), other.getVectorDataType()); + equalsBuilder.append(getMode(), other.getMode()); + equalsBuilder.append(getCompressionLevel(), other.getCompressionLevel()); return equalsBuilder.isEquals(); } @@ -291,6 +315,8 @@ public int hashCode() { .append(getError()) .append(getMethodComponentContext()) .append(getVectorDataType()) + .append(getMode()) + 
.append(getCompressionLevel()) .toHashCode(); } @@ -304,13 +330,14 @@ public static ModelMetadata fromString(String modelMetadataString) { String[] modelMetadataArray = modelMetadataString.split(DELIMITER, -1); int length = modelMetadataArray.length; - if (length < 7 || length > 10) { + if (length < 7 || length > 12) { throw new IllegalArgumentException( "Illegal format for model metadata. Must be of the form " + "\",,,,,,\" or " + "\",,,,,,,\" or " + "\",,,,,,,,\" or " - + "\",,,,,,,,,\"." + + "\",,,,,,,,,\" or " + + "\",,,,,,,,,,,\"." ); } @@ -326,6 +353,10 @@ public static ModelMetadata fromString(String modelMetadataString) { ? MethodComponentContext.fromClusterStateString(modelMetadataArray[8]) : MethodComponentContext.EMPTY; VectorDataType vectorDataType = length > 9 ? VectorDataType.get(modelMetadataArray[9]) : VectorDataType.DEFAULT; + Mode mode = length > 10 ? Mode.fromName(modelMetadataArray[10]) : Mode.NOT_CONFIGURED; + CompressionLevel compressionLevel = length > 11 + ? CompressionLevel.fromName(modelMetadataArray[11]) + : CompressionLevel.NOT_CONFIGURED; log.debug(getLogMessage(length)); @@ -339,7 +370,9 @@ public static ModelMetadata fromString(String modelMetadataString) { error, trainingNodeAssignment, methodComponentContext, - vectorDataType + vectorDataType, + mode, + compressionLevel ); } @@ -353,6 +386,9 @@ private static String getLogMessage(int length) { return "Model metadata contains training node assignment and method context."; case 10: return "Model metadata contains training node assignment, method context and vector data type."; + case 11: + case 12: + return "Model metadata contains mode and compression level"; default: throw new IllegalArgumentException("Unexpected metadata array length: " + length); } @@ -385,6 +421,8 @@ public static ModelMetadata getMetadataFromSourceMap(final Map m Object trainingNodeAssignment = modelSourceMap.get(KNNConstants.MODEL_NODE_ASSIGNMENT); Object methodComponentContext = 
modelSourceMap.get(KNNConstants.MODEL_METHOD_COMPONENT_CONTEXT); Object vectorDataType = modelSourceMap.get(KNNConstants.VECTOR_DATA_TYPE_FIELD); + Object mode = modelSourceMap.get(KNNConstants.MODE_PARAMETER); + Object compressionLevel = modelSourceMap.get(KNNConstants.COMPRESSION_LEVEL_PARAMETER); if (trainingNodeAssignment == null) { trainingNodeAssignment = ""; @@ -419,7 +457,9 @@ public static ModelMetadata getMetadataFromSourceMap(final Map m objectToString(error), objectToString(trainingNodeAssignment), (MethodComponentContext) methodComponentContext, - VectorDataType.get(objectToString(vectorDataType)) + VectorDataType.get(objectToString(vectorDataType)), + Mode.fromName(objectToString(mode)), + CompressionLevel.fromName(objectToString(compressionLevel)) ); return modelMetadata; } @@ -442,6 +482,10 @@ public void writeTo(StreamOutput out) throws IOException { if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), KNNConstants.MODEL_VECTOR_DATA_TYPE_KEY)) { out.writeString(vectorDataType.getValue()); } + if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), KNNConstants.MINIMAL_MODE_AND_COMPRESSION_FEATURE)) { + out.writeOptionalString(mode.getName()); + out.writeOptionalString(compressionLevel.getName()); + } } @Override @@ -465,6 +509,14 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (IndexUtil.isClusterOnOrAfterMinRequiredVersion(KNNConstants.MODEL_VECTOR_DATA_TYPE_KEY)) { builder.field(KNNConstants.VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue()); } + if (IndexUtil.isClusterOnOrAfterMinRequiredVersion(KNNConstants.MINIMAL_MODE_AND_COMPRESSION_FEATURE)) { + if (Mode.isConfigured(mode)) { + builder.field(KNNConstants.MODE_PARAMETER, mode.getName()); + } + if (CompressionLevel.isConfigured(compressionLevel)) { + builder.field(KNNConstants.COMPRESSION_LEVEL_PARAMETER, compressionLevel.getName()); + } + } return builder; } } diff --git 
a/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java b/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java index 58bcd1ebf0..c9038f0c75 100644 --- a/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java +++ b/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java @@ -15,9 +15,12 @@ import org.opensearch.client.node.NodeClient; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.mapper.NumberFieldMapper; +import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.mapper.CompressionLevel; +import org.opensearch.knn.index.mapper.Mode; import org.opensearch.knn.indices.ModelUtil; import org.opensearch.knn.plugin.KNNPlugin; import org.opensearch.knn.plugin.transport.TrainingJobRouterAction; @@ -91,6 +94,9 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr int maximumVectorCount = DEFAULT_NOT_SET_INT_VALUE; int searchSize = DEFAULT_NOT_SET_INT_VALUE; + String compressionLevel = null; + String mode = null; + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { String fieldName = parser.currentName(); parser.nextToken(); @@ -115,6 +121,10 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr ModelUtil.blockCommasInModelDescription(description); } else if (VECTOR_DATA_TYPE_FIELD.equals(fieldName) && ensureNotSet(fieldName, vectorDataType)) { vectorDataType = VectorDataType.get(parser.text()); + } else if (KNNConstants.MODE_PARAMETER.equals(fieldName) && ensureNotSet(fieldName, mode)) { + mode = parser.text(); + } else if (KNNConstants.COMPRESSION_LEVEL_PARAMETER.equals(fieldName) && ensureNotSet(fieldName, compressionLevel)) { + compressionLevel = parser.text(); } else { throw new IllegalArgumentException("Unable to parse token. 
\"" + fieldName + "\" is not a valid " + "parameter."); } @@ -143,7 +153,9 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr trainingField, preferredNodeId, description, - vectorDataType + vectorDataType, + Mode.fromName(mode), + CompressionLevel.fromName(compressionLevel) ); if (maximumVectorCount != DEFAULT_NOT_SET_INT_VALUE) { diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java index 3634d13f08..fdc82526de 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java @@ -22,6 +22,8 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.engine.KNNMethodConfigContext; +import org.opensearch.knn.index.mapper.CompressionLevel; +import org.opensearch.knn.index.mapper.Mode; import org.opensearch.knn.index.util.IndexUtil; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.VectorDataType; @@ -50,6 +52,8 @@ public class TrainingModelRequest extends ActionRequest { private int maximumVectorCount; private int searchSize; private int trainingDataSizeInKB; + private final Mode mode; + private final CompressionLevel compressionLevel; /** * Constructor. 
@@ -70,7 +74,9 @@ public TrainingModelRequest( String trainingField, String preferredNodeId, String description, - VectorDataType vectorDataType + VectorDataType vectorDataType, + Mode mode, + CompressionLevel compressionLevel ) { super(); this.modelId = modelId; @@ -94,6 +100,8 @@ public TrainingModelRequest( .dimension(dimension) .versionCreated(Version.CURRENT) .build(); + /* Normalize null to NOT_CONFIGURED so writeTo() never dereferences a null enum. */ this.mode = mode == null ? Mode.NOT_CONFIGURED : mode; + this.compressionLevel = compressionLevel == null ? CompressionLevel.NOT_CONFIGURED : compressionLevel; } /** @@ -119,6 +127,14 @@ public TrainingModelRequest(StreamInput in) throws IOException { } else { this.vectorDataType = VectorDataType.DEFAULT; } + if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(in.getVersion(), KNNConstants.MINIMAL_MODE_AND_COMPRESSION_FEATURE)) { + this.mode = Mode.fromName(in.readOptionalString()); + this.compressionLevel = CompressionLevel.fromName(in.readOptionalString()); + } else { + this.mode = Mode.NOT_CONFIGURED; + this.compressionLevel = CompressionLevel.NOT_CONFIGURED; + } + this.knnMethodConfigContext = KNNMethodConfigContext.builder() .vectorDataType(vectorDataType) .dimension(dimension) @@ -271,5 +287,9 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeString(VectorDataType.DEFAULT.getValue()); } + if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), KNNConstants.MINIMAL_MODE_AND_COMPRESSION_FEATURE)) { + out.writeOptionalString(mode.getName()); + out.writeOptionalString(compressionLevel.getName()); + } } } diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java index 963142c1f3..aa1db8c6a3 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java @@ -78,7 +78,9 @@ protected void doExecute(Task task, TrainingModelRequest request, ActionListener modelAnonymousEntryContext, 
request.getKnnMethodConfigContext(), request.getDescription(), - clusterService.localNode().getEphemeralId() + clusterService.localNode().getEphemeralId(), + request.getMode(), + request.getCompressionLevel() ); KNNCounter.TRAINING_REQUESTS.increment(); diff --git a/src/main/java/org/opensearch/knn/training/TrainingJob.java b/src/main/java/org/opensearch/knn/training/TrainingJob.java index e30d860db6..63df79bde2 100644 --- a/src/main/java/org/opensearch/knn/training/TrainingJob.java +++ b/src/main/java/org/opensearch/knn/training/TrainingJob.java @@ -19,6 +19,8 @@ import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.engine.KNNMethodConfigContext; +import org.opensearch.knn.index.mapper.CompressionLevel; +import org.opensearch.knn.index.mapper.Mode; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.memory.NativeMemoryAllocation; @@ -70,7 +72,9 @@ public TrainingJob( NativeMemoryEntryContext.AnonymousEntryContext modelAnonymousEntryContext, KNNMethodConfigContext knnMethodConfigContext, String description, - String nodeAssignment + String nodeAssignment, + Mode mode, + CompressionLevel compressionLevel ) { // Generate random base64 string if one is not provided this.modelId = StringUtils.isNotBlank(modelId) ? 
modelId : UUIDs.randomBase64UUID(); @@ -90,7 +94,9 @@ public TrainingJob( "", nodeAssignment, knnMethodContext.getMethodComponentContext(), - knnMethodConfigContext.getVectorDataType() + knnMethodConfigContext.getVectorDataType(), + mode, + compressionLevel ), null, this.modelId diff --git a/src/main/resources/mappings/model-index.json b/src/main/resources/mappings/model-index.json index cd2a508395..e7879cced4 100644 --- a/src/main/resources/mappings/model-index.json +++ b/src/main/resources/mappings/model-index.json @@ -32,6 +32,12 @@ }, "method_component_context": { "type": "keyword" + }, + "mode": { + "type": "keyword" + }, + "compression_level": { + "type": "keyword" } } } diff --git a/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java b/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java index f0e60ca98a..28ef41e047 100644 --- a/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java @@ -18,6 +18,8 @@ import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.knn.KNNSingleNodeTestCase; import org.opensearch.knn.index.engine.MethodComponentContext; +import org.opensearch.knn.index.mapper.CompressionLevel; +import org.opensearch.knn.index.mapper.Mode; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.indices.Model; @@ -65,7 +67,9 @@ public void testCreateIndexFromModel() throws IOException, InterruptedException "", "test-node", MethodComponentContext.EMPTY, - VectorDataType.FLOAT + VectorDataType.FLOAT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); Model model = new Model(modelMetadata, modelBlob, modelId); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java index f49587bc54..786061af87 
100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java @@ -28,6 +28,8 @@ import org.opensearch.knn.index.engine.KNNMethodConfigContext; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.mapper.CompressionLevel; +import org.opensearch.knn.index.mapper.Mode; import org.opensearch.knn.index.vectorvalues.TestVectorValues; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.engine.MethodComponentContext; @@ -469,7 +471,9 @@ public void testAddKNNBinaryField_fromModel_faiss() throws IOException, Executio "", "", MethodComponentContext.EMPTY, - VectorDataType.FLOAT + VectorDataType.FLOAT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), modelBytes, modelId diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java index 2a4e26a82e..174441df83 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java @@ -29,7 +29,9 @@ import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.VectorField; +import org.opensearch.knn.index.mapper.CompressionLevel; import org.opensearch.knn.index.mapper.KNNVectorFieldType; +import org.opensearch.knn.index.mapper.Mode; import org.opensearch.knn.index.query.BaseQueryFactory; import org.opensearch.knn.index.query.KNNQueryFactory; import org.opensearch.knn.jni.JNIService; @@ -243,7 +245,9 @@ public void testBuildFromModelTemplate(Codec codec) throws IOException, Executio "", "", MethodComponentContext.EMPTY, - VectorDataType.FLOAT + VectorDataType.FLOAT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); Model 
mockModel = new Model(modelMetadata1, modelBlob, modelId); diff --git a/src/test/java/org/opensearch/knn/index/mapper/CompressionLevelTests.java b/src/test/java/org/opensearch/knn/index/mapper/CompressionLevelTests.java new file mode 100644 index 0000000000..07475109ad --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/mapper/CompressionLevelTests.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +import org.opensearch.core.common.Strings; +import org.opensearch.knn.KNNTestCase; + +public class CompressionLevelTests extends KNNTestCase { + + public void testFromName() { + assertEquals(CompressionLevel.NOT_CONFIGURED, CompressionLevel.fromName(null)); + assertEquals(CompressionLevel.NOT_CONFIGURED, CompressionLevel.fromName("")); + assertEquals(CompressionLevel.x1, CompressionLevel.fromName("1x")); + assertEquals(CompressionLevel.x32, CompressionLevel.fromName("32x")); + expectThrows(IllegalArgumentException.class, () -> CompressionLevel.fromName("x1")); + } + + public void testGetName() { + assertTrue(Strings.isEmpty(CompressionLevel.NOT_CONFIGURED.getName())); + assertEquals("4x", CompressionLevel.x4.getName()); + assertEquals("16x", CompressionLevel.x16.getName()); + } + + public void testNumBitsForFloat32() { + assertEquals(1, CompressionLevel.x32.numBitsForFloat32()); + assertEquals(2, CompressionLevel.x16.numBitsForFloat32()); + assertEquals(4, CompressionLevel.x8.numBitsForFloat32()); + assertEquals(8, CompressionLevel.x4.numBitsForFloat32()); + assertEquals(16, CompressionLevel.x2.numBitsForFloat32()); + assertEquals(32, CompressionLevel.x1.numBitsForFloat32()); + assertEquals(32, CompressionLevel.NOT_CONFIGURED.numBitsForFloat32()); + } + + public void testIsConfigured() { + assertFalse(CompressionLevel.isConfigured(CompressionLevel.NOT_CONFIGURED)); + assertFalse(CompressionLevel.isConfigured(null)); + 
assertTrue(CompressionLevel.isConfigured(CompressionLevel.x1)); + } +} diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index 369f38cf95..f04c1a4f6a 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@ -32,6 +32,7 @@ import org.opensearch.index.mapper.ContentPath; import org.opensearch.index.mapper.FieldMapper; import org.opensearch.index.mapper.Mapper; +import org.opensearch.index.mapper.MapperParsingException; import org.opensearch.index.mapper.MapperService; import org.opensearch.index.mapper.ParseContext; import org.opensearch.knn.KNNTestCase; @@ -65,6 +66,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import static org.opensearch.Version.CURRENT; +import static org.opensearch.knn.common.KNNConstants.COMPRESSION_LEVEL_PARAMETER; import static org.opensearch.knn.common.KNNConstants.DIMENSION; import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; @@ -77,6 +79,7 @@ import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; +import static org.opensearch.knn.common.KNNConstants.MODE_PARAMETER; import static org.opensearch.knn.common.KNNConstants.NAME; import static org.opensearch.knn.common.KNNConstants.NMSLIB_NAME; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; @@ -109,11 +112,27 @@ public class KNNVectorFieldMapperTests extends KNNTestCase { public void testBuilder_getParameters() { String fieldName = "test-field-name"; ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new 
KNNVectorFieldMapper.Builder(fieldName, modelDao, CURRENT, null, null); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder( + fieldName, + modelDao, + CURRENT, + null, + new OriginalMappingParameters(VectorDataType.DEFAULT, TEST_DIMENSION, null, null, null, null) + ); - assertEquals(7, builder.getParameters().size()); + assertEquals(9, builder.getParameters().size()); List actualParams = builder.getParameters().stream().map(a -> a.name).collect(Collectors.toList()); - List expectedParams = Arrays.asList("store", "doc_values", DIMENSION, VECTOR_DATA_TYPE_FIELD, "meta", KNN_METHOD, MODEL_ID); + List expectedParams = Arrays.asList( + "store", + "doc_values", + DIMENSION, + VECTOR_DATA_TYPE_FIELD, + "meta", + KNN_METHOD, + MODEL_ID, + MODE_PARAMETER, + COMPRESSION_LEVEL_PARAMETER + ); assertEquals(expectedParams, actualParams); } @@ -200,12 +219,15 @@ public void testBuilder_build_fromModel() { "", "", MethodComponentContext.EMPTY, - VectorDataType.FLOAT + VectorDataType.FLOAT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); builder.modelId.setValue(modelId); Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); when(modelDao.getMetadata(modelId)).thenReturn(mockedModelMetadata); + builder.setOriginalParameters(new OriginalMappingParameters(builder)); KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext); assertTrue(knnVectorFieldMapper instanceof ModelFieldMapper); assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getModelId().isPresent()); @@ -396,6 +418,78 @@ public void testTypeParser_parse_fromKnnMethodContext_invalidDimension() throws ); } + @SneakyThrows + public void testTypeParser_parse_compressionAndModeParameter() { + String fieldName = "test-field-name-vec"; + String indexName = "test-index-name-vec"; + + Settings settings = Settings.builder().put(settings(CURRENT).build()).put(KNN_INDEX, true).build(); + + ModelDao modelDao = 
mock(ModelDao.class); + KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao); + + XContentBuilder xContentBuilder1 = XContentFactory.jsonBuilder() + .startObject() + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION, 10) + .field(VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue()) + .field(MODE_PARAMETER, Mode.ON_DISK.getName()) + .field(COMPRESSION_LEVEL_PARAMETER, CompressionLevel.x16.getName()) + .endObject(); + + Mapper.Builder builder = typeParser.parse( + fieldName, + xContentBuilderToMap(xContentBuilder1), + buildParserContext(indexName, settings) + ); + + assertTrue(builder instanceof KNNVectorFieldMapper.Builder); + assertEquals(Mode.ON_DISK.getName(), ((KNNVectorFieldMapper.Builder) builder).mode.get()); + assertEquals(CompressionLevel.x16.getName(), ((KNNVectorFieldMapper.Builder) builder).compressionLevel.get()); + + XContentBuilder xContentBuilder2 = XContentFactory.jsonBuilder() + .startObject() + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION, 10) + .field(VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue()) + .field(MODE_PARAMETER, "invalid") + .field(COMPRESSION_LEVEL_PARAMETER, CompressionLevel.x16.getName()) + .endObject(); + + expectThrows( + MapperParsingException.class, + () -> typeParser.parse(fieldName, xContentBuilderToMap(xContentBuilder2), buildParserContext(indexName, settings)) + ); + + XContentBuilder xContentBuilder3 = XContentFactory.jsonBuilder() + .startObject() + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION, 10) + .field(VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue()) + .field(COMPRESSION_LEVEL_PARAMETER, "invalid") + .endObject(); + + expectThrows( + MapperParsingException.class, + () -> typeParser.parse(fieldName, xContentBuilderToMap(xContentBuilder3), buildParserContext(indexName, settings)) + ); + + XContentBuilder xContentBuilder4 = XContentFactory.jsonBuilder() + .startObject() + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) 
+ .field(DIMENSION, 10) + .field(VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue()) + .field(MODEL_ID, "test") + .field(MODE_PARAMETER, Mode.ON_DISK.getName()) + .field(COMPRESSION_LEVEL_PARAMETER, CompressionLevel.x16.getName()) + .endObject(); + + expectThrows( + MapperParsingException.class, + () -> typeParser.parse(fieldName, xContentBuilderToMap(xContentBuilder4), buildParserContext(indexName, settings)) + ); + } + // Validate TypeParser parsing invalid vector data_type which throws exception @SneakyThrows public void testTypeParser_parse_invalidVectorDataType() { @@ -717,7 +811,9 @@ public void testKNNVectorFieldMapper_merge_fromModel() throws IOException { "", "", MethodComponentContext.EMPTY, - VectorDataType.FLOAT + VectorDataType.FLOAT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); when(mockModelDao.getMetadata(modelId)).thenReturn(mockModelMetadata); @@ -796,18 +892,27 @@ public void testMethodFieldMapperParseCreateField_validInput_thenDifferentFieldT when(parseContext.parser()).thenReturn(createXContentParser(dataType)); utilMockedStatic.when(() -> KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Mockito.any())).thenReturn(true); + + OriginalMappingParameters originalMappingParameters = new OriginalMappingParameters( + dataType, + dimension, + knnMethodContext, + Mode.NOT_CONFIGURED.getName(), + CompressionLevel.NOT_CONFIGURED.getName(), + null + ); + originalMappingParameters.setResolvedKnnMethodContext(knnMethodContext); MethodFieldMapper methodFieldMapper = MethodFieldMapper.createFieldMapper( TEST_FIELD_NAME, TEST_FIELD_NAME, Collections.emptyMap(), - knnMethodContext, knnMethodConfigContext, - knnMethodContext, FieldMapper.MultiFields.empty(), FieldMapper.CopyTo.empty(), new Explicit<>(true, true), false, - false + false, + originalMappingParameters ); methodFieldMapper.parseCreateField(parseContext, dimension, dataType); @@ -840,14 +945,13 @@ public void testMethodFieldMapperParseCreateField_validInput_thenDifferentFieldT 
TEST_FIELD_NAME, TEST_FIELD_NAME, Collections.emptyMap(), - knnMethodContext, knnMethodConfigContext, - knnMethodContext, FieldMapper.MultiFields.empty(), FieldMapper.CopyTo.empty(), new Explicit<>(true, true), false, - false + false, + originalMappingParameters ); methodFieldMapper.parseCreateField(parseContext, dimension, dataType); @@ -889,19 +993,29 @@ public void testModelFieldMapperParseCreateField_validInput_thenDifferentFieldTy when(parseContext.parser()).thenReturn(createXContentParser(dataType)); utilMockedStatic.when(() -> KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Mockito.any())).thenReturn(true); + + OriginalMappingParameters originalMappingParameters = new OriginalMappingParameters( + VectorDataType.DEFAULT, + -1, + null, + Mode.NOT_CONFIGURED.getName(), + CompressionLevel.NOT_CONFIGURED.getName(), + MODEL_ID + ); + ModelFieldMapper modelFieldMapper = ModelFieldMapper.createFieldMapper( TEST_FIELD_NAME, TEST_FIELD_NAME, Collections.emptyMap(), dataType, - MODEL_ID, FieldMapper.MultiFields.empty(), FieldMapper.CopyTo.empty(), new Explicit<>(true, true), false, false, modelDao, - CURRENT + CURRENT, + originalMappingParameters ); modelFieldMapper.parseCreateField(parseContext); @@ -935,14 +1049,14 @@ public void testModelFieldMapperParseCreateField_validInput_thenDifferentFieldTy TEST_FIELD_NAME, Collections.emptyMap(), dataType, - MODEL_ID, FieldMapper.MultiFields.empty(), FieldMapper.CopyTo.empty(), new Explicit<>(true, true), false, false, modelDao, - CURRENT + CURRENT, + originalMappingParameters ); modelFieldMapper.parseCreateField(parseContext); @@ -971,12 +1085,23 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { .versionCreated(CURRENT) .dimension(TEST_DIMENSION) .build(); + + OriginalMappingParameters originalMappingParameters = new OriginalMappingParameters( + VectorDataType.FLOAT, + TEST_DIMENSION, + getDefaultKNNMethodContext(), + Mode.NOT_CONFIGURED.getName(), + CompressionLevel.NOT_CONFIGURED.getName(), 
+ null + ); + originalMappingParameters.setResolvedKnnMethodContext(originalMappingParameters.getKnnMethodContext()); + LuceneFieldMapper luceneFieldMapper = LuceneFieldMapper.createFieldMapper( TEST_FIELD_NAME, Collections.emptyMap(), - getDefaultKNNMethodContext(), knnMethodConfigContext, - inputBuilder.build() + inputBuilder.build(), + originalMappingParameters ); luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION, VectorDataType.FLOAT); @@ -1020,12 +1145,21 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { .build(); MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.DEFAULT, methodComponentContext); + originalMappingParameters = new OriginalMappingParameters( + VectorDataType.FLOAT, + TEST_DIMENSION, + knnMethodContext, + Mode.NOT_CONFIGURED.getName(), + CompressionLevel.NOT_CONFIGURED.getName(), + null + ); + originalMappingParameters.setResolvedKnnMethodContext(originalMappingParameters.getKnnMethodContext()); luceneFieldMapper = LuceneFieldMapper.createFieldMapper( TEST_FIELD_NAME, Collections.emptyMap(), - knnMethodContext, knnMethodConfigContext, - inputBuilder.build() + inputBuilder.build(), + originalMappingParameters ); luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION, VectorDataType.FLOAT); @@ -1051,17 +1185,27 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { when(parseContext.doc()).thenReturn(document); when(parseContext.path()).thenReturn(contentPath); + OriginalMappingParameters originalMappingParameters = new OriginalMappingParameters( + VectorDataType.BYTE, + TEST_DIMENSION, + getDefaultByteKNNMethodContext(), + Mode.NOT_CONFIGURED.getName(), + CompressionLevel.NOT_CONFIGURED.getName(), + null + ); + originalMappingParameters.setResolvedKnnMethodContext(originalMappingParameters.getKnnMethodContext()); + 
LuceneFieldMapper luceneFieldMapper = Mockito.spy( LuceneFieldMapper.createFieldMapper( TEST_FIELD_NAME, Collections.emptyMap(), - getDefaultByteKNNMethodContext(), KNNMethodConfigContext.builder() .vectorDataType(VectorDataType.BYTE) .versionCreated(CURRENT) .dimension(TEST_DIMENSION) .build(), - inputBuilder.build() + inputBuilder.build(), + originalMappingParameters ) ); doReturn(Optional.of(TEST_BYTE_VECTOR)).when(luceneFieldMapper) @@ -1100,17 +1244,18 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { when(parseContext.path()).thenReturn(contentPath); inputBuilder.hasDocValues(false); + luceneFieldMapper = Mockito.spy( LuceneFieldMapper.createFieldMapper( TEST_FIELD_NAME, Collections.emptyMap(), - getDefaultByteKNNMethodContext(), KNNMethodConfigContext.builder() .vectorDataType(VectorDataType.BYTE) .versionCreated(CURRENT) .dimension(TEST_DIMENSION) .build(), - inputBuilder.build() + inputBuilder.build(), + originalMappingParameters ) ); doReturn(Optional.of(TEST_BYTE_VECTOR)).when(luceneFieldMapper) @@ -1185,7 +1330,7 @@ private void testTypeParserWithBinaryDataType( buildParserContext(indexName, settings) ); - assertEquals(spaceType, builder.getResolvedKNNMethodContext().getSpaceType()); + assertEquals(spaceType, builder.getOriginalParameters().getResolvedKnnMethodContext().getSpaceType()); } else { Exception ex = expectThrows(Exception.class, () -> { typeParser.parse(fieldName, xContentBuilderToMap(xContentBuilder), buildParserContext(indexName, settings)); @@ -1233,6 +1378,7 @@ public void testBuilder_whenBinaryWithLegacyKNNDisabled_thenValid() { // Setup settings Settings settings = Settings.builder().put(settings(CURRENT).build()).put(KNN_INDEX, false).build(); + builder.setOriginalParameters(new OriginalMappingParameters(builder)); Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext); 
assertTrue(knnVectorFieldMapper instanceof FlatVectorFieldMapper); diff --git a/src/test/java/org/opensearch/knn/index/mapper/ModeTests.java b/src/test/java/org/opensearch/knn/index/mapper/ModeTests.java new file mode 100644 index 0000000000..2035bba803 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/mapper/ModeTests.java @@ -0,0 +1,33 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +import org.opensearch.core.common.Strings; +import org.opensearch.knn.KNNTestCase; + +public class ModeTests extends KNNTestCase { + + public void testFromName() { + assertEquals(Mode.NOT_CONFIGURED, Mode.fromName(null)); + assertEquals(Mode.NOT_CONFIGURED, Mode.fromName("")); + assertEquals(Mode.ON_DISK, Mode.fromName("on_disk")); + assertEquals(Mode.IN_MEMORY, Mode.fromName("in_memory")); + expectThrows(IllegalArgumentException.class, () -> Mode.fromName("on_disk2")); + } + + public void testGetName() { + assertTrue(Strings.isEmpty(Mode.NOT_CONFIGURED.getName())); + assertEquals("on_disk", Mode.ON_DISK.getName()); + assertEquals("in_memory", Mode.IN_MEMORY.getName()); + } + + public void testIsConfigured() { + assertFalse(Mode.isConfigured(Mode.NOT_CONFIGURED)); + assertFalse(Mode.isConfigured(null)); + assertTrue(Mode.isConfigured(Mode.ON_DISK)); + } + +} diff --git a/src/test/java/org/opensearch/knn/index/mapper/OriginalMappingParametersTests.java b/src/test/java/org/opensearch/knn/index/mapper/OriginalMappingParametersTests.java new file mode 100644 index 0000000000..2822a882e7 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/mapper/OriginalMappingParametersTests.java @@ -0,0 +1,38 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import 
org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.knn.index.engine.MethodComponentContext; + +import java.util.Collections; + +public class OriginalMappingParametersTests extends KNNTestCase { + + public void testIsLegacy() { + assertTrue(new OriginalMappingParameters(VectorDataType.DEFAULT, 123, null, null, null, null).isLegacyMapping()); + assertFalse(new OriginalMappingParameters(VectorDataType.DEFAULT, 123, null, null, null, "model-id").isLegacyMapping()); + assertFalse(new OriginalMappingParameters(VectorDataType.DEFAULT, 123, null, Mode.ON_DISK.getName(), null, null).isLegacyMapping()); + assertFalse( + new OriginalMappingParameters(VectorDataType.DEFAULT, 123, null, null, CompressionLevel.x2.getName(), null).isLegacyMapping() + ); + assertFalse( + new OriginalMappingParameters( + VectorDataType.DEFAULT, + 123, + new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.L2, new MethodComponentContext(null, Collections.emptyMap())), + null, + null, + null + ).isLegacyMapping() + ); + } + +} diff --git a/src/test/java/org/opensearch/knn/indices/ModelCacheTests.java b/src/test/java/org/opensearch/knn/indices/ModelCacheTests.java index 88f78e716b..91bb7d3d91 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelCacheTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelCacheTests.java @@ -21,6 +21,8 @@ import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.mapper.CompressionLevel; +import org.opensearch.knn.index.mapper.Mode; import java.time.ZoneOffset; import java.time.ZonedDateTime; @@ -47,7 +49,9 @@ public void testGet_normal() throws ExecutionException, InterruptedException { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), "hello".getBytes(), modelId @@ -85,7 
+89,9 @@ public void testGet_modelDoesNotFitInCache() throws ExecutionException, Interrup "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), new byte[BYTES_PER_KILOBYTES + 1], modelId @@ -144,7 +150,9 @@ public void testGetTotalWeight() throws ExecutionException, InterruptedException "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), new byte[size1], modelId1 @@ -161,7 +169,9 @@ public void testGetTotalWeight() throws ExecutionException, InterruptedException "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), new byte[size2], modelId2 @@ -206,7 +216,9 @@ public void testRemove_normal() throws ExecutionException, InterruptedException "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), new byte[size1], modelId1 @@ -223,7 +235,9 @@ public void testRemove_normal() throws ExecutionException, InterruptedException "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), new byte[size2], modelId2 @@ -273,7 +287,9 @@ public void testRebuild_normal() throws ExecutionException, InterruptedException "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), "hello".getBytes(), modelId @@ -320,7 +336,9 @@ public void testRebuild_afterSettingUpdate() throws ExecutionException, Interrup "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), new byte[modelSize], modelId @@ -390,7 +408,9 @@ public void testContains() throws 
ExecutionException, InterruptedException { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), new byte[modelSize1], modelId1 @@ -433,7 +453,9 @@ public void testRemoveAll() throws ExecutionException, InterruptedException { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), new byte[modelSize1], modelId1 @@ -452,7 +474,9 @@ public void testRemoveAll() throws ExecutionException, InterruptedException { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), new byte[modelSize2], modelId2 @@ -499,7 +523,9 @@ public void testModelCacheEvictionDueToSize() throws ExecutionException, Interru "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), new byte[BYTES_PER_KILOBYTES * 2], modelId diff --git a/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java index d9dab081c7..560ea59b2f 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java @@ -37,6 +37,8 @@ import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.mapper.CompressionLevel; +import org.opensearch.knn.index.mapper.Mode; import org.opensearch.knn.plugin.transport.DeleteModelResponse; import org.opensearch.knn.plugin.transport.GetModelResponse; import org.opensearch.knn.plugin.transport.RemoveModelFromCacheAction; @@ -141,7 +143,9 @@ public void testModelIndexHealth() throws InterruptedException, ExecutionExcepti "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + 
VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), modelBlob, modelId @@ -162,7 +166,9 @@ public void testModelIndexHealth() throws InterruptedException, ExecutionExcepti "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), modelBlob, modelId @@ -191,7 +197,9 @@ public void testPut_withId() throws InterruptedException, IOException { "", "", new MethodComponentContext("test", Collections.emptyMap()), - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), modelBlob, modelId @@ -253,7 +261,9 @@ public void testPut_withoutModel() throws InterruptedException, IOException { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), modelBlob, modelId @@ -316,7 +326,9 @@ public void testPut_invalid_badState() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), modelBlob, "any-id" @@ -354,7 +366,9 @@ public void testUpdate() throws IOException, InterruptedException { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), null, modelId @@ -394,7 +408,9 @@ public void testUpdate() throws IOException, InterruptedException { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), modelBlob, modelId @@ -446,7 +462,9 @@ public void testGet() throws IOException, InterruptedException, ExecutionExcepti "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), modelBlob, modelId @@ -466,7 +484,9 @@ public void testGet() throws IOException, 
InterruptedException, ExecutionExcepti "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), null, modelId @@ -504,7 +524,9 @@ public void testGetMetadata() throws IOException, InterruptedException { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); Model model = new Model(modelMetadata, modelBlob, modelId); @@ -582,7 +604,9 @@ public void testDelete() throws IOException, InterruptedException { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), modelBlob, modelId @@ -617,7 +641,9 @@ public void testDelete() throws IOException, InterruptedException { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), modelBlob, modelId1 @@ -686,7 +712,9 @@ public void testDeleteModelInTrainingWithStepListeners() throws IOException, Exe "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), modelBlob, modelId @@ -729,7 +757,9 @@ public void testDeleteWithStepListeners() throws IOException, InterruptedExcepti "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), modelBlob, modelId diff --git a/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java b/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java index 04fa502622..6f0b49285d 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java @@ -21,6 +21,8 @@ import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import 
org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.mapper.CompressionLevel; +import org.opensearch.knn.index.mapper.Mode; import java.io.IOException; import java.time.ZoneId; @@ -47,14 +49,33 @@ public void testStreams() throws IOException { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); BytesStreamOutput streamOutput = new BytesStreamOutput(); modelMetadata.writeTo(streamOutput); - ModelMetadata modelMetadataCopy = new ModelMetadata(streamOutput.bytes().streamInput()); + assertEquals(modelMetadata, modelMetadataCopy); + modelMetadata = new ModelMetadata( + knnEngine, + spaceType, + dimension, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "", + "", + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT, + Mode.ON_DISK, + CompressionLevel.x16 + ); + streamOutput = new BytesStreamOutput(); + modelMetadata.writeTo(streamOutput); + modelMetadataCopy = new ModelMetadata(streamOutput.bytes().streamInput()); assertEquals(modelMetadata, modelMetadataCopy); } @@ -70,7 +91,9 @@ public void testGetKnnEngine() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); assertEquals(knnEngine, modelMetadata.getKnnEngine()); @@ -88,7 +111,9 @@ public void testGetSpaceType() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); assertEquals(spaceType, modelMetadata.getSpaceType()); @@ -106,7 +131,9 @@ public void testGetDimension() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); assertEquals(dimension, modelMetadata.getDimension()); @@ -124,7 +151,9 @@ public void testGetState() { "", "", MethodComponentContext.EMPTY, - 
VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); assertEquals(modelState, modelMetadata.getState()); @@ -142,7 +171,9 @@ public void testGetTimestamp() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); assertEquals(timeValue, modelMetadata.getTimestamp()); @@ -160,7 +191,9 @@ public void testDescription() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); assertEquals(description, modelMetadata.getDescription()); @@ -178,7 +211,9 @@ public void testGetError() { error, "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); assertEquals(error, modelMetadata.getError()); @@ -196,7 +231,9 @@ public void testGetVectorDataType() { "", "", MethodComponentContext.EMPTY, - vectorDataType + vectorDataType, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); assertEquals(vectorDataType, modelMetadata.getVectorDataType()); @@ -214,7 +251,9 @@ public void testSetState() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); assertEquals(modelState, modelMetadata.getState()); @@ -236,7 +275,9 @@ public void testSetError() { error, "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); assertEquals(error, modelMetadata.getError()); @@ -275,7 +316,9 @@ public void testToString() { + "," + methodComponentContext.toClusterStateString() + "," - + VectorDataType.DEFAULT.getValue(); + + VectorDataType.DEFAULT.getValue() + + "," + + ","; ModelMetadata modelMetadata = new ModelMetadata( knnEngine, @@ -287,7 +330,50 @@ public void testToString() { error, 
nodeAssignment, MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED + ); + + assertEquals(expected, modelMetadata.toString()); + + expected = knnEngine.getName() + + "," + + spaceType.getValue() + + "," + + dimension + + "," + + modelState.getName() + + "," + + timestamp + + "," + + description + + "," + + error + + "," + + nodeAssignment + + "," + + methodComponentContext.toClusterStateString() + + "," + + VectorDataType.DEFAULT.getValue() + + "," + + Mode.ON_DISK.getName() + + "," + + CompressionLevel.x32.getName(); + + modelMetadata = new ModelMetadata( + knnEngine, + spaceType, + dimension, + modelState, + timestamp, + description, + error, + nodeAssignment, + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT, + Mode.ON_DISK, + CompressionLevel.x32 ); assertEquals(expected, modelMetadata.toString()); @@ -308,7 +394,9 @@ public void testEquals() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); ModelMetadata modelMetadata2 = new ModelMetadata( KNNEngine.FAISS, @@ -320,7 +408,9 @@ public void testEquals() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); ModelMetadata modelMetadata3 = new ModelMetadata( @@ -333,7 +423,9 @@ public void testEquals() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); ModelMetadata modelMetadata4 = new ModelMetadata( KNNEngine.FAISS, @@ -345,7 +437,9 @@ public void testEquals() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); ModelMetadata modelMetadata5 = new ModelMetadata( KNNEngine.FAISS, @@ -357,7 +451,9 @@ public void testEquals() { "", "", 
MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); ModelMetadata modelMetadata6 = new ModelMetadata( KNNEngine.FAISS, @@ -369,7 +465,9 @@ public void testEquals() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); ModelMetadata modelMetadata7 = new ModelMetadata( KNNEngine.FAISS, @@ -381,7 +479,9 @@ public void testEquals() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); ModelMetadata modelMetadata8 = new ModelMetadata( KNNEngine.FAISS, @@ -393,7 +493,9 @@ public void testEquals() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); ModelMetadata modelMetadata9 = new ModelMetadata( KNNEngine.FAISS, @@ -405,7 +507,9 @@ public void testEquals() { "diff error", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); ModelMetadata modelMetadata10 = new ModelMetadata( @@ -418,7 +522,9 @@ public void testEquals() { "", "", new MethodComponentContext("test", Collections.emptyMap()), - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); assertEquals(modelMetadata1, modelMetadata1); @@ -449,7 +555,9 @@ public void testHashCode() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); ModelMetadata modelMetadata2 = new ModelMetadata( KNNEngine.FAISS, @@ -461,7 +569,9 @@ public void testHashCode() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); 
ModelMetadata modelMetadata3 = new ModelMetadata( @@ -474,7 +584,9 @@ public void testHashCode() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); ModelMetadata modelMetadata4 = new ModelMetadata( KNNEngine.FAISS, @@ -486,7 +598,9 @@ public void testHashCode() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); ModelMetadata modelMetadata5 = new ModelMetadata( KNNEngine.FAISS, @@ -498,7 +612,9 @@ public void testHashCode() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); ModelMetadata modelMetadata6 = new ModelMetadata( KNNEngine.FAISS, @@ -510,7 +626,9 @@ public void testHashCode() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); ModelMetadata modelMetadata7 = new ModelMetadata( KNNEngine.FAISS, @@ -522,7 +640,9 @@ public void testHashCode() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); ModelMetadata modelMetadata8 = new ModelMetadata( KNNEngine.FAISS, @@ -534,7 +654,9 @@ public void testHashCode() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); ModelMetadata modelMetadata9 = new ModelMetadata( KNNEngine.FAISS, @@ -546,7 +668,9 @@ public void testHashCode() { "diff error", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); ModelMetadata modelMetadata10 = new ModelMetadata( @@ -559,7 +683,9 @@ public void testHashCode() { "", "", new MethodComponentContext("test", 
Collections.emptyMap()), - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); assertEquals(modelMetadata1.hashCode(), modelMetadata1.hashCode()); @@ -622,6 +748,52 @@ public void testFromString() { + "," + VectorDataType.DEFAULT.getValue(); + String stringRep3 = knnEngine.getName() + + "," + + spaceType.getValue() + + "," + + dimension + + "," + + modelState.getName() + + "," + + timestamp + + "," + + description + + "," + + error + + "," + + nodeAssignment + + "," + + methodComponentContext.toClusterStateString() + + "," + + VectorDataType.DEFAULT.getValue() + + "," + + ","; + + String stringRep4 = knnEngine.getName() + + "," + + spaceType.getValue() + + "," + + dimension + + "," + + modelState.getName() + + "," + + timestamp + + "," + + description + + "," + + error + + "," + + nodeAssignment + + "," + + methodComponentContext.toClusterStateString() + + "," + + VectorDataType.DEFAULT.getValue() + + "," + + Mode.ON_DISK.getName() + + "," + + CompressionLevel.x32.getName(); + ModelMetadata expected1 = new ModelMetadata( knnEngine, spaceType, @@ -632,7 +804,9 @@ public void testFromString() { error, nodeAssignment, MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); ModelMetadata expected2 = new ModelMetadata( @@ -645,14 +819,35 @@ public void testFromString() { error, "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED + ); + + ModelMetadata expected3 = new ModelMetadata( + knnEngine, + spaceType, + dimension, + modelState, + timestamp, + description, + error, + "", + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT, + Mode.ON_DISK, + CompressionLevel.x32 ); ModelMetadata fromString1 = ModelMetadata.fromString(stringRep1); ModelMetadata fromString2 = ModelMetadata.fromString(stringRep2); + ModelMetadata fromString3 = 
ModelMetadata.fromString(stringRep3); + ModelMetadata fromString4 = ModelMetadata.fromString(stringRep4); assertEquals(expected1, fromString1); assertEquals(expected2, fromString2); + assertEquals(expected2, fromString3); + assertEquals(expected3, fromString4); expectThrows(IllegalArgumentException.class, () -> ModelMetadata.fromString("invalid")); } @@ -679,7 +874,9 @@ public void testFromResponseMap() throws IOException { error, nodeAssignment, methodComponentContext, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); ModelMetadata expected2 = new ModelMetadata( knnEngine, @@ -691,7 +888,39 @@ public void testFromResponseMap() throws IOException { error, "", emptyMethodComponentContext, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED + ); + + ModelMetadata expected3 = new ModelMetadata( + knnEngine, + spaceType, + dimension, + modelState, + timestamp, + description, + error, + "", + emptyMethodComponentContext, + VectorDataType.DEFAULT, + Mode.ON_DISK, + CompressionLevel.NOT_CONFIGURED + ); + + ModelMetadata expected4 = new ModelMetadata( + knnEngine, + spaceType, + dimension, + modelState, + timestamp, + description, + error, + "", + emptyMethodComponentContext, + VectorDataType.DEFAULT, + Mode.ON_DISK, + CompressionLevel.x16 ); Map metadataAsMap = new HashMap<>(); metadataAsMap.put(KNNConstants.KNN_ENGINE, knnEngine.getName()); @@ -714,6 +943,14 @@ public void testFromResponseMap() throws IOException { metadataAsMap.put(KNNConstants.MODEL_NODE_ASSIGNMENT, null); metadataAsMap.put(KNNConstants.MODEL_METHOD_COMPONENT_CONTEXT, null); assertEquals(expected2, fromMap); + + metadataAsMap.put(KNNConstants.MODE_PARAMETER, Mode.ON_DISK.getName()); + fromMap = ModelMetadata.getMetadataFromSourceMap(metadataAsMap); + assertEquals(expected3, fromMap); + + metadataAsMap.put(KNNConstants.COMPRESSION_LEVEL_PARAMETER, CompressionLevel.x16.getName()); + 
fromMap = ModelMetadata.getMetadataFromSourceMap(metadataAsMap); + assertEquals(expected4, fromMap); } public void testBlockCommasInDescription() { @@ -739,7 +976,9 @@ public void testBlockCommasInDescription() { error, nodeAssignment, methodComponentContext, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ) ); assertEquals("Model description cannot contain any commas: ','", e.getMessage()); diff --git a/src/test/java/org/opensearch/knn/indices/ModelTests.java b/src/test/java/org/opensearch/knn/indices/ModelTests.java index 45e8b05f10..4e666872ff 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelTests.java @@ -17,6 +17,8 @@ import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.mapper.CompressionLevel; +import org.opensearch.knn.index.mapper.Mode; import java.time.ZoneOffset; import java.time.ZonedDateTime; @@ -43,7 +45,9 @@ public void testInvalidConstructor() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), null, "test-model" @@ -65,7 +69,9 @@ public void testInvalidDimension() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), new byte[16], "test-model" @@ -84,7 +90,9 @@ public void testInvalidDimension() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), new byte[16], "test-model" @@ -103,7 +111,9 @@ public void testInvalidDimension() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), new byte[16], "test-model" @@ -123,7 
+133,9 @@ public void testGetModelMetadata() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); Model model = new Model(modelMetadata, new byte[16], "test-model"); assertEquals(modelMetadata, model.getModelMetadata()); @@ -142,7 +154,9 @@ public void testGetModelBlob() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), modelBlob, "test-model" @@ -163,7 +177,9 @@ public void testGetLength() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), new byte[size], "test-model" @@ -181,7 +197,9 @@ public void testGetLength() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), null, "test-model" @@ -202,7 +220,9 @@ public void testSetModelBlob() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), blob1, "test-model" @@ -229,7 +249,9 @@ public void testEquals() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), new byte[16], "test-model-1" @@ -245,7 +267,9 @@ public void testEquals() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), new byte[16], "test-model-1" @@ -261,7 +285,9 @@ public void testEquals() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), new byte[16], "test-model-2" @@ -287,7 +313,9 @@ public void testHashCode() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT 
+ VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), new byte[16], "test-model-1" @@ -303,7 +331,9 @@ public void testHashCode() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), new byte[16], "test-model-1" @@ -319,7 +349,9 @@ public void testHashCode() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), new byte[16], "test-model-2" @@ -351,7 +383,9 @@ public void testModelFromSourceMap() { error, nodeAssignment, MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); Map modelAsMap = new HashMap<>(); modelAsMap.put(KNNConstants.MODEL_ID, modelID); @@ -365,6 +399,8 @@ public void testModelFromSourceMap() { modelAsMap.put(KNNConstants.MODEL_NODE_ASSIGNMENT, nodeAssignment); modelAsMap.put(KNNConstants.MODEL_BLOB_PARAMETER, "aGVsbG8="); modelAsMap.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue()); + modelAsMap.put(KNNConstants.MODE_PARAMETER, Mode.NOT_CONFIGURED.getName()); + modelAsMap.put(KNNConstants.COMPRESSION_LEVEL_PARAMETER, CompressionLevel.NOT_CONFIGURED.getName()); byte[] blob1 = "hello".getBytes(); Model expected = new Model(metadata, blob1, modelID); diff --git a/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java b/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java new file mode 100644 index 0000000000..bf2e7f0c0f --- /dev/null +++ b/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java @@ -0,0 +1,155 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.integ; + +import lombok.SneakyThrows; +import org.opensearch.client.Response; +import org.opensearch.client.ResponseException; +import 
org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.knn.KNNRestTestCase; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.mapper.CompressionLevel; +import org.opensearch.knn.index.mapper.Mode; + +import java.io.IOException; + +import static org.opensearch.knn.common.KNNConstants.COMPRESSION_LEVEL_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.FAISS_NAME; +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.METHOD_IVF; +import static org.opensearch.knn.common.KNNConstants.MODEL_DESCRIPTION; +import static org.opensearch.knn.common.KNNConstants.MODE_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.NAME; +import static org.opensearch.knn.common.KNNConstants.PARAMETERS; +import static org.opensearch.knn.common.KNNConstants.TRAIN_FIELD_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.TRAIN_INDEX_PARAMETER; + +public class ModeAndCompressionIT extends KNNRestTestCase { + + private static final int DIMENSION = 10; + + public void testIndexCreation() throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", "knn_vector") + .field("dimension", DIMENSION) + .field(MODE_PARAMETER, "on_disk") + .field(COMPRESSION_LEVEL_PARAMETER, "16x") + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(KNN_ENGINE, FAISS_NAME) + .startObject(PARAMETERS) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + String mapping = builder.toString(); + createKnnIndex(INDEX_NAME + "1", mapping); + + builder = XContentFactory.jsonBuilder() + .startObject() + 
.startObject("properties") + .startObject(FIELD_NAME) + .field("type", "knn_vector") + .field("dimension", DIMENSION) + .field(MODE_PARAMETER, "in_memory") + .field(COMPRESSION_LEVEL_PARAMETER, "32x") + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(KNN_ENGINE, FAISS_NAME) + .startObject(PARAMETERS) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + mapping = builder.toString(); + createKnnIndex(INDEX_NAME + "2", mapping); + + builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", "knn_vector") + .field("dimension", DIMENSION) + .field(MODE_PARAMETER, "invalid") + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(KNN_ENGINE, FAISS_NAME) + .startObject(PARAMETERS) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + String finalMapping = builder.toString(); + expectThrows(ResponseException.class, () -> createKnnIndex(INDEX_NAME + "3", finalMapping)); + } + + @SneakyThrows + public void testTraining() { + String trainingIndexName = "training-index"; + String trainingFieldName = "training-field"; + String modelDescription = "test model"; + int dimension = 20; + int trainingDataCount = 256; + createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension); + bulkIngestRandomVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension); + + String modelId1 = "test-model-1"; + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(TRAIN_INDEX_PARAMETER, trainingIndexName) + .field(TRAIN_FIELD_PARAMETER, trainingFieldName) + .field(KNNConstants.DIMENSION, dimension) + .startObject(KNN_METHOD) + .field(NAME, METHOD_IVF) + .field(KNN_ENGINE, FAISS_NAME) + .endObject() + .field(MODEL_DESCRIPTION, modelDescription) + .field(COMPRESSION_LEVEL_PARAMETER, CompressionLevel.x16.getName()) + .field(MODE_PARAMETER, Mode.ON_DISK.getName()) + .endObject(); + Response trainResponse 
= trainModel(modelId1, xContentBuilder); + assertEquals(RestStatus.OK, RestStatus.fromCode(trainResponse.getStatusLine().getStatusCode())); + assertTrainingSucceeds(modelId1, 360, 1000); + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", "knn_vector") + .field("model_id", modelId1) + .endObject() + .endObject() + .endObject(); + String mapping = builder.toString(); + createKnnIndex(INDEX_NAME + "1", mapping); + deleteKNNIndex(INDEX_NAME + "1"); + deleteModel(modelId1); + String modelId2 = "test-model-2"; + XContentBuilder xContentBuilder2 = XContentFactory.jsonBuilder() + .startObject() + .field(TRAIN_INDEX_PARAMETER, trainingIndexName) + .field(TRAIN_FIELD_PARAMETER, trainingFieldName) + .field(KNNConstants.DIMENSION, dimension) + .startObject(KNN_METHOD) + .field(NAME, METHOD_IVF) + .field(KNN_ENGINE, FAISS_NAME) + .endObject() + .field(MODEL_DESCRIPTION, modelDescription) + .field(COMPRESSION_LEVEL_PARAMETER, CompressionLevel.x16.getName()) + .field(MODE_PARAMETER, "invalid") + .endObject(); + expectThrows(ResponseException.class, () -> trainModel(modelId2, xContentBuilder2)); + } +} diff --git a/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java b/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java index 71e8f15a69..7010dbf43f 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java @@ -18,6 +18,8 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentType; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.mapper.CompressionLevel; +import org.opensearch.knn.index.mapper.Mode; import org.opensearch.knn.index.util.KNNClusterUtil; import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.SpaceType; @@ 
-46,7 +48,9 @@ private ModelMetadata getModelMetadata(ModelState state) { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); } diff --git a/src/test/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheTransportActionTests.java index 8fdccdac0d..6252a29ac4 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheTransportActionTests.java @@ -21,6 +21,8 @@ import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.mapper.CompressionLevel; +import org.opensearch.knn.index.mapper.Mode; import org.opensearch.knn.indices.Model; import org.opensearch.knn.indices.ModelCache; import org.opensearch.knn.indices.ModelDao; @@ -80,7 +82,9 @@ public void testNodeOperation_modelInCache() throws ExecutionException, Interrup "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), new byte[128], modelId diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java index 8cff4dfa14..151626ef5b 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java @@ -24,6 +24,8 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.mapper.CompressionLevel; +import 
org.opensearch.knn.index.mapper.Mode; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.transport.TransportService; @@ -308,7 +310,9 @@ public void testTrainingIndexSize() { "training-field", null, "description", - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); // Mock client to return the right number of docs @@ -355,7 +359,9 @@ public void testTrainIndexSize_whenDataTypeIsBinary() { "training-field", null, "description", - VectorDataType.BINARY + VectorDataType.BINARY, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); // Mock client to return the right number of docs @@ -403,7 +409,9 @@ public void testTrainIndexSize_whenDataTypeIsByte() { "training-field", null, "description", - VectorDataType.BYTE + VectorDataType.BYTE, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); // Mock client to return the right number of docs diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java index d7920d9877..10f35457da 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java @@ -26,9 +26,11 @@ import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.mapper.CompressionLevel; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.mapper.Mode; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelState; @@ -64,7 +66,9 @@ public void testStreams() throws 
IOException { trainingField, preferredNode, description, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); BytesStreamOutput streamOutput = new BytesStreamOutput(); @@ -78,6 +82,8 @@ public void testStreams() throws IOException { assertEquals(original1.getTrainingField(), copy1.getTrainingField()); assertEquals(original1.getPreferredNodeId(), copy1.getPreferredNodeId()); assertEquals(original1.getVectorDataType(), copy1.getVectorDataType()); + assertEquals(original1.getMode(), copy1.getMode()); + assertEquals(original1.getCompressionLevel(), copy1.getCompressionLevel()); // Also, check when preferred node and model id and description are null TrainingModelRequest original2 = new TrainingModelRequest( @@ -88,7 +94,9 @@ public void testStreams() throws IOException { trainingField, null, null, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); streamOutput = new BytesStreamOutput(); @@ -102,6 +110,33 @@ public void testStreams() throws IOException { assertEquals(original2.getTrainingField(), copy2.getTrainingField()); assertEquals(original2.getPreferredNodeId(), copy2.getPreferredNodeId()); assertEquals(original2.getVectorDataType(), copy2.getVectorDataType()); + assertEquals(original2.getMode(), copy2.getMode()); + assertEquals(original2.getCompressionLevel(), copy2.getCompressionLevel()); + + TrainingModelRequest original3 = new TrainingModelRequest( + null, + knnMethodContext, + dimension, + trainingIndex, + trainingField, + null, + null, + VectorDataType.DEFAULT, + Mode.ON_DISK, + CompressionLevel.x32 + ); + + streamOutput = new BytesStreamOutput(); + original3.writeTo(streamOutput); + TrainingModelRequest copy3 = new TrainingModelRequest(streamOutput.bytes().streamInput()); + + assertEquals(original3.getModelId(), copy3.getModelId()); + assertEquals(original3.getKnnMethodContext(), copy3.getKnnMethodContext()); + 
assertEquals(original3.getDimension(), copy3.getDimension()); + assertEquals(original3.getTrainingIndex(), copy3.getTrainingIndex()); + assertEquals(original3.getTrainingField(), copy3.getTrainingField()); + assertEquals(original3.getMode(), copy3.getMode()); + assertEquals(original3.getCompressionLevel(), copy3.getCompressionLevel()); } public void testGetters() { @@ -124,7 +159,9 @@ public void testGetters() { trainingField, preferredNode, description, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); trainingModelRequest.setMaximumVectorCount(maxVectorCount); @@ -164,7 +201,9 @@ public void testValidation_invalid_modelIdAlreadyExists() { trainingField, null, null, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); // Mock the model dao to return metadata for modelId to recognize it is a duplicate @@ -179,7 +218,9 @@ public void testValidation_invalid_modelIdAlreadyExists() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); @@ -221,7 +262,9 @@ public void testValidation_blocked_modelId() { trainingField, null, null, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); // Mock the model dao to return true to recognize that the modelId is in graveyard @@ -269,7 +312,9 @@ public void testValidation_invalid_invalidMethodContext() { trainingField, null, null, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); // Mock the model dao to return null so that no exception is produced @@ -313,7 +358,9 @@ public void testValidation_invalid_trainingIndexDoesNotExist() { trainingField, null, null, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + 
CompressionLevel.NOT_CONFIGURED ); // Mock the model dao to return null so that no exception is produced @@ -360,7 +407,9 @@ public void testValidation_invalid_trainingFieldDoesNotExist() { trainingField, null, null, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); // Mock the model dao to return null so that no exception is produced @@ -412,7 +461,9 @@ public void testValidation_invalid_trainingFieldNotKnnVector() { trainingField, null, null, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); // Mock the model dao to return null so that no exception is produced @@ -469,7 +520,9 @@ public void testValidation_invalid_dimensionDoesNotMatch() { trainingField, null, null, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); // Mock the model dao to return null so that no exception is produced @@ -528,7 +581,9 @@ public void testValidation_invalid_preferredNodeDoesNotExist() { trainingField, preferredNode, null, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); // Mock the model dao to return metadata for modelId to recognize it is a duplicate @@ -595,7 +650,9 @@ public void testValidation_invalid_descriptionToLong() { trainingField, null, description, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); // Mock the model dao to return metadata for modelId to recognize it is a duplicate @@ -641,7 +698,9 @@ public void testValidation_valid_trainingIndexBuiltFromMethod() { trainingField, null, null, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); // Mock the model dao to return metadata for modelId to recognize it is a duplicate @@ -680,7 +739,9 @@ public void testValidation_valid_trainingIndexBuiltFromModel() { 
trainingField, null, null, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); // Mock the model dao to return metadata for modelId to recognize it is a duplicate diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelTransportActionTests.java index aea0e0b16c..345f4feb5a 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelTransportActionTests.java @@ -19,6 +19,8 @@ import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.mapper.CompressionLevel; +import org.opensearch.knn.index.mapper.Mode; import org.opensearch.knn.indices.ModelDao; import java.io.IOException; @@ -74,7 +76,9 @@ public void testDoExecute() throws InterruptedException, ExecutionException, IOE trainingFieldName, null, "test-detector", - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); trainingModelRequest.setTrainingDataSizeInKB(estimateVectorSetSizeInKB(trainingDataCount, dimension, VectorDataType.DEFAULT)); diff --git a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java index bc6e098f34..45203dae68 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java @@ -19,6 +19,8 @@ import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; +import 
org.opensearch.knn.index.mapper.CompressionLevel; +import org.opensearch.knn.index.mapper.Mode; import org.opensearch.knn.indices.Model; import org.opensearch.knn.indices.ModelGraveyard; import org.opensearch.knn.indices.ModelMetadata; @@ -212,7 +214,9 @@ public void testClusterManagerOperation_GetIndicesUsingModel() throws IOExceptio "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), modelBlob, modelId diff --git a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequestTests.java b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequestTests.java index 238fc5e45c..d0a83ccc5a 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequestTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequestTests.java @@ -17,6 +17,8 @@ import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.mapper.CompressionLevel; +import org.opensearch.knn.index.mapper.Mode; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelState; @@ -44,7 +46,9 @@ public void testStreams() throws IOException { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); UpdateModelMetadataRequest updateModelMetadataRequest = new UpdateModelMetadataRequest(modelId, isRemoveRequest, modelMetadata); @@ -70,7 +74,9 @@ public void testValidate() { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); UpdateModelMetadataRequest updateModelMetadataRequest1 = new UpdateModelMetadataRequest("test", true, null); @@ -111,7 +117,9 @@ public void testGetModelMetadata() { "", "", 
MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); UpdateModelMetadataRequest updateModelMetadataRequest = new UpdateModelMetadataRequest("test", true, modelMetadata); diff --git a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java index e5dcb2257e..d317fa893c 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java @@ -21,6 +21,8 @@ import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.mapper.CompressionLevel; +import org.opensearch.knn.index.mapper.Mode; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelState; import org.opensearch.threadpool.ThreadPool; @@ -70,7 +72,9 @@ public void testClusterManagerOperation() throws InterruptedException { "", "", MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); // Get update transport action diff --git a/src/test/java/org/opensearch/knn/training/TrainingJobTests.java b/src/test/java/org/opensearch/knn/training/TrainingJobTests.java index adecca43a6..32794a33b8 100644 --- a/src/test/java/org/opensearch/knn/training/TrainingJobTests.java +++ b/src/test/java/org/opensearch/knn/training/TrainingJobTests.java @@ -20,6 +20,8 @@ import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.mapper.CompressionLevel; +import org.opensearch.knn.index.mapper.Mode; import 
org.opensearch.knn.index.memory.NativeMemoryAllocation; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.memory.NativeMemoryEntryContext; @@ -69,7 +71,9 @@ public void testGetModelId() { mock(NativeMemoryEntryContext.AnonymousEntryContext.class), KNNMethodConfigContext.builder().vectorDataType(VectorDataType.DEFAULT).dimension(10).versionCreated(Version.CURRENT).build(), "", - "test-node" + "test-node", + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); assertEquals(modelId, trainingJob.getModelId()); @@ -102,7 +106,9 @@ public void testGetModel() { .versionCreated(Version.CURRENT) .build(), description, - nodeAssignment + nodeAssignment, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); Model model = new Model( @@ -116,7 +122,9 @@ public void testGetModel() { error, nodeAssignment, MethodComponentContext.EMPTY, - VectorDataType.DEFAULT + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ), null, modelID @@ -195,7 +203,9 @@ public void testRun_success() throws IOException, ExecutionException { modelContext, knnMethodConfigContext, "", - "test-node" + "test-node", + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); trainingJob.run(); @@ -278,7 +288,9 @@ public void testRun_failure_onGetTrainingDataAllocation() throws ExecutionExcept modelContext, knnMethodConfigContext, "", - "test-node" + "test-node", + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); trainingJob.run(); @@ -350,7 +362,9 @@ public void testRun_failure_onGetModelAnonymousAllocation() throws ExecutionExce modelContext, knnMethodConfigContext, "", - "test-node" + "test-node", + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); trainingJob.run(); @@ -421,7 +435,9 @@ public void testRun_failure_closedTrainingDataAllocation() throws ExecutionExcep mock(NativeMemoryEntryContext.AnonymousEntryContext.class), knnMethodConfigContext, "", - "test-node" + "test-node", + Mode.NOT_CONFIGURED, 
+ CompressionLevel.NOT_CONFIGURED ); trainingJob.run(); @@ -499,7 +515,9 @@ public void testRun_failure_notEnoughTrainingData() throws ExecutionException { modelContext, knnMethodConfigContext, "", - "test-node" + "test-node", + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED ); trainingJob.run();