Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.engine;

import org.opensearch.common.ValidationException;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.mapper.CompressionLevel;
import org.opensearch.knn.index.mapper.Mode;

import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;

/**
* Abstract {@link MethodResolver} with helpful utilitiy functions that can be shared across different
* implementations
*/
public abstract class AbstractMethodResolver implements MethodResolver {

/**
* Utility method to get the compression level from the context
*
* @param resolvedKnnMethodContext Resolved method context. Should have an encoder set in the params if available
* @return {@link CompressionLevel} Compression level that is configured with the {@link KNNMethodContext}
*/
protected CompressionLevel resolveCompressionLevelFromMethodContext(
KNNMethodContext resolvedKnnMethodContext,
KNNMethodConfigContext knnMethodConfigContext,
Map<String, Encoder> encoderMap
) {
// If the context is null, the compression is not configured or the encoder is not defined, return not configured
// because the method context does not contain this info
if (isEncoderSpecified(resolvedKnnMethodContext) == false) {
return CompressionLevel.x1;
}
Encoder encoder = encoderMap.get(getEncoderName(resolvedKnnMethodContext));
if (encoder == null) {
return CompressionLevel.NOT_CONFIGURED;
}
return encoder.calculateCompressionLevel(getEncoderComponentContext(resolvedKnnMethodContext), knnMethodConfigContext);
}

protected void resolveMethodParams(
MethodComponentContext methodComponentContext,
KNNMethodConfigContext knnMethodConfigContext,
MethodComponent methodComponent
) {
Map<String, Object> resolvedParams = MethodComponent.getParameterMapWithDefaultsAdded(
methodComponentContext,
methodComponent,
knnMethodConfigContext
);
methodComponentContext.getParameters().putAll(resolvedParams);
}

protected KNNMethodContext initResolvedKNNMethodContext(
KNNMethodContext originalMethodContext,
KNNEngine knnEngine,
SpaceType spaceType,
String methodName
) {
if (originalMethodContext == null) {
return new KNNMethodContext(knnEngine, spaceType, new MethodComponentContext(methodName, new HashMap<>()));
}
return new KNNMethodContext(originalMethodContext);
}

protected String getEncoderName(KNNMethodContext knnMethodContext) {
if (isEncoderSpecified(knnMethodContext) == false) {
return null;
}

MethodComponentContext methodComponentContext = getEncoderComponentContext(knnMethodContext);
if (methodComponentContext == null) {
return null;
}

return methodComponentContext.getName();
}

protected MethodComponentContext getEncoderComponentContext(KNNMethodContext knnMethodContext) {
if (isEncoderSpecified(knnMethodContext) == false) {
return null;
}

return (MethodComponentContext) knnMethodContext.getMethodComponentContext().getParameters().get(METHOD_ENCODER_PARAMETER);
}

/**
* Determine if the encoder parameter is specified
*
* @param knnMethodContext {@link KNNMethodContext}
* @return true is the encoder is specified in the structure; false otherwise
*/
protected boolean isEncoderSpecified(KNNMethodContext knnMethodContext) {
return knnMethodContext != null
&& knnMethodContext.getMethodComponentContext().getParameters() != null
&& knnMethodContext.getMethodComponentContext().getParameters().containsKey(METHOD_ENCODER_PARAMETER);
}

protected boolean shouldEncoderBeResolved(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) {
// The encoder should not be resolved if:
// 1. The encoder is specified
// 2. The compression is x1
// 3. The compression is not specified and the mode is not disk-based
if (isEncoderSpecified(knnMethodContext)) {
return false;
}

if (knnMethodConfigContext.getCompressionLevel() == CompressionLevel.x1) {
return false;
}

if (CompressionLevel.isConfigured(knnMethodConfigContext.getCompressionLevel()) == false
&& Mode.ON_DISK != knnMethodConfigContext.getMode()) {
return false;
}

if (VectorDataType.FLOAT != knnMethodConfigContext.getVectorDataType()) {
return false;
}

return true;
}

protected ValidationException validateNotTrainingContext(
boolean shouldRequireTraining,
KNNEngine knnEngine,
ValidationException validationException
) {
if (shouldRequireTraining) {
validationException = validationException == null ? new ValidationException() : validationException;
validationException.addValidationError(
String.format(Locale.ROOT, "Cannot use \"%s\" engine from training context", knnEngine.getName())
);
}

return validationException;
}

protected ValidationException validateCompressionSupported(
CompressionLevel compressionLevel,
Set<CompressionLevel> supportedCompressionLevels,
KNNEngine knnEngine,
ValidationException validationException
) {
if (CompressionLevel.isConfigured(compressionLevel) && supportedCompressionLevels.contains(compressionLevel) == false) {
validationException = validationException == null ? new ValidationException() : validationException;
validationException.addValidationError(
String.format(Locale.ROOT, "\"%s\" does not support \"%s\" compression", knnEngine.getName(), compressionLevel.getName())
);
}
return validationException;
}

protected ValidationException validateCompressionNotx1WhenOnDisk(
KNNMethodConfigContext knnMethodConfigContext,
ValidationException validationException
) {
if (knnMethodConfigContext.getCompressionLevel() == CompressionLevel.x1 && knnMethodConfigContext.getMode() == Mode.ON_DISK) {
validationException = validationException == null ? new ValidationException() : validationException;
validationException.addValidationError(
String.format(Locale.ROOT, "Cannot specify \"x1\" compression level when using \"%s\" mode", Mode.ON_DISK.getName())
);
}
return validationException;
}

protected void validateCompressionConflicts(CompressionLevel originalCompressionLevel, CompressionLevel resolvedCompressionLevel) {
if (CompressionLevel.isConfigured(originalCompressionLevel)
&& CompressionLevel.isConfigured(resolvedCompressionLevel)
&& resolvedCompressionLevel != originalCompressionLevel) {
ValidationException validationException = new ValidationException();
validationException.addValidationError("Cannot specify an encoder that conflicts with the provided compression level");
throw validationException;
}
}
}
12 changes: 12 additions & 0 deletions src/main/java/org/opensearch/knn/index/engine/Encoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.knn.index.engine;

import org.opensearch.knn.index.mapper.CompressionLevel;

/**
* Interface representing an encoder. An encoder generally refers to a vector quantizer.
*/
Expand All @@ -24,4 +26,14 @@ default String getName() {
* @return Method component associated with the encoder
*/
MethodComponent getMethodComponent();

/**
* Calculate the compression level for the give params. Assume float32 vectors are used. All parameters should
* be resolved in the encoderContext passed in.
*
* @param encoderContext Context for the encoder to extract params from
* @return Compression level this encoder produces. If the encoder does not support this calculation yet, it will
* return {@link CompressionLevel#NOT_CONFIGURED}
*/
CompressionLevel calculateCompressionLevel(MethodComponentContext encoderContext, KNNMethodConfigContext knnMethodConfigContext);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.engine;

import org.opensearch.knn.index.mapper.CompressionLevel;
import org.opensearch.knn.index.mapper.Mode;

/**
* Figures out what {@link KNNEngine} to use based on configuration details
*/
public final class EngineResolver {

public static final EngineResolver INSTANCE = new EngineResolver();

private EngineResolver() {}

/**
* Based on the provided {@link Mode} and {@link CompressionLevel}, resolve to a {@link KNNEngine}.
*
* @param knnMethodConfigContext configuration context
* @param knnMethodContext KNNMethodContext
* @param requiresTraining whether config requires training
* @return {@link KNNEngine}
*/
public KNNEngine resolveEngine(
KNNMethodConfigContext knnMethodConfigContext,
KNNMethodContext knnMethodContext,
boolean requiresTraining
) {
// User configuration gets precedence
if (knnMethodContext != null && knnMethodContext.isEngineConfigured()) {
return knnMethodContext.getKnnEngine();
}

// Faiss is the only engine that supports training, so we default to faiss here for now
if (requiresTraining) {
return KNNEngine.FAISS;
}

Mode mode = knnMethodConfigContext.getMode();
CompressionLevel compressionLevel = knnMethodConfigContext.getCompressionLevel();
// If both mode and compression are not specified, we can just default
if (Mode.isConfigured(mode) == false && CompressionLevel.isConfigured(compressionLevel) == false) {
return KNNEngine.DEFAULT;
}

// For 1x, we need to default to faiss if mode is provided and use nmslib otherwise
if (CompressionLevel.isConfigured(compressionLevel) == false || compressionLevel == CompressionLevel.x1) {
return mode == Mode.ON_DISK ? KNNEngine.FAISS : KNNEngine.DEFAULT;
}

// Lucene is only engine that supports 4x - so we have to default to it here.
if (compressionLevel == CompressionLevel.x4) {
return KNNEngine.LUCENE;
}

return KNNEngine.FAISS;
}
}
10 changes: 10 additions & 0 deletions src/main/java/org/opensearch/knn/index/engine/KNNEngine.java
Original file line number Diff line number Diff line change
Expand Up @@ -201,4 +201,14 @@ public void setInitialized(Boolean isInitialized) {
public List<String> mmapFileExtensions() {
return knnLibrary.mmapFileExtensions();
}

@Override
public ResolvedMethodContext resolveMethod(
KNNMethodContext knnMethodContext,
KNNMethodConfigContext knnMethodConfigContext,
boolean shouldRequireTraining,
final SpaceType spaceType
) {
return knnLibrary.resolveMethod(knnMethodContext, knnMethodConfigContext, shouldRequireTraining, spaceType);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
/**
* KNNLibrary is an interface that helps the plugin communicate with k-NN libraries
*/
public interface KNNLibrary {
public interface KNNLibrary extends MethodResolver {

/**
* Gets the version of the library that is being used. In general, this can be used for ensuring compatibility of
Expand Down
Loading