diff --git a/docs/changelog/140477.yaml b/docs/changelog/140477.yaml new file mode 100644 index 0000000000000..c9112ce326c32 --- /dev/null +++ b/docs/changelog/140477.yaml @@ -0,0 +1,5 @@ +pr: 140477 +summary: "[Inference API] Add Mixedbread Rerank support to the Inference Plugin" +area: Inference +type: enhancement +issues: [] diff --git a/server/src/main/resources/transport/definitions/referable/inference_mixedbread_added.csv b/server/src/main/resources/transport/definitions/referable/inference_mixedbread_added.csv new file mode 100644 index 0000000000000..e270b67a7c44f --- /dev/null +++ b/server/src/main/resources/transport/definitions/referable/inference_mixedbread_added.csv @@ -0,0 +1 @@ +9283000 diff --git a/server/src/main/resources/transport/upper_bounds/9.4.csv b/server/src/main/resources/transport/upper_bounds/9.4.csv index d1b7b2100f561..074433fd98de2 100644 --- a/server/src/main/resources/transport/upper_bounds/9.4.csv +++ b/server/src/main/resources/transport/upper_bounds/9.4.csv @@ -1 +1 @@ -esql_batch_page,9282000 +inference_mixedbread_added,9283000 diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index a249114dcc3f1..52a6c7cec9bc1 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -81,7 +81,8 @@ public void testGetServicesWithoutTaskType() throws IOException { "text_embedding_test_service", "voyageai", "watsonxai", - "amazon_sagemaker" + "amazon_sagemaker", + "mixedbread" ).toArray() ) ); @@ -145,6 +146,7 @@ public void testGetServicesWithRerankTaskType() throws IOException { 
"elasticsearch", "googlevertexai", "jinaai", + "mixedbread", "nvidia", "openshift_ai", "test_reranking_service", diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 3171118ea25c9..0dacb49fd1f66 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -115,6 +115,8 @@ import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.mixedbread.rerank.MixedbreadRerankServiceSettings; +import org.elasticsearch.xpack.inference.services.mixedbread.rerank.MixedbreadRerankTaskSettings; import org.elasticsearch.xpack.inference.services.nvidia.completion.NvidiaChatCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.nvidia.embeddings.NvidiaEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.nvidia.embeddings.NvidiaEmbeddingsTaskSettings; @@ -187,6 +189,7 @@ public static List getNamedWriteables() { addAi21NamedWriteables(namedWriteables); addOpenShiftAiNamedWriteables(namedWriteables); addNvidiaNamedWriteables(namedWriteables); + addMixedbreadNamedWriteables(namedWriteables); addUnifiedNamedWriteables(namedWriteables); @@ -942,4 +945,17 @@ private static void addElasticNamedWriteables(List ) ); } + + private static void addMixedbreadNamedWriteables(List namedWriteables) { + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + 
MixedbreadRerankServiceSettings.NAME, + MixedbreadRerankServiceSettings::new + ) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry(TaskSettings.class, MixedbreadRerankTaskSettings.NAME, MixedbreadRerankTaskSettings::new) + ); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index d8f8e84a38030..6e8fbd16bb0b6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -172,6 +172,7 @@ import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService; import org.elasticsearch.xpack.inference.services.llama.LlamaService; import org.elasticsearch.xpack.inference.services.mistral.MistralService; +import org.elasticsearch.xpack.inference.services.mixedbread.MixedbreadService; import org.elasticsearch.xpack.inference.services.nvidia.NvidiaService; import org.elasticsearch.xpack.inference.services.openai.OpenAiService; import org.elasticsearch.xpack.inference.services.openshiftai.OpenShiftAiService; @@ -561,6 +562,7 @@ public List getInferenceServiceFactories() { context -> new GoogleAiStudioService(httpFactory.get(), serviceComponents.get(), context), context -> new GoogleVertexAiService(httpFactory.get(), serviceComponents.get(), context), context -> new MistralService(httpFactory.get(), serviceComponents.get(), context), + context -> new MixedbreadService(httpFactory.get(), serviceComponents.get(), context), context -> new AnthropicService(httpFactory.get(), serviceComponents.get(), context), context -> new AmazonBedrockService(httpFactory.get(), amazonBedrockFactory.get(), serviceComponents.get(), context), context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get(), context), diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java index 7a28084dace09..ed1e9e80eb7fc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java @@ -118,4 +118,8 @@ public static RestStatus toRestStatus(int statusCode) { return code == null ? RestStatus.BAD_REQUEST : code; } + + protected static String resourceNotFoundError(Request request) { + return format("Resource not found at [%s]", request.getURI()); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioResponseHandler.java index 32a436e9e97cd..c71c21d078427 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioResponseHandler.java @@ -23,8 +23,6 @@ import java.io.IOException; import java.util.concurrent.Flow; -import static org.elasticsearch.core.Strings.format; - public class GoogleAiStudioResponseHandler extends BaseResponseHandler { static final String GOOGLE_AI_STUDIO_UNAVAILABLE = "The Google AI Studio service may be temporarily overloaded or down"; @@ -82,10 +80,6 @@ protected void checkForFailureStatusCode(Request request, HttpResult result) thr } } - private static String resourceNotFoundError(Request request) { - return format("Resource not found at [%s]", request.getURI()); - } - @Override public 
InferenceServiceResults parseResult(Request request, Flow.Publisher flow) { var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser()); @@ -94,5 +88,4 @@ public InferenceServiceResults parseResult(Request request, Flow.Publisher flow) { var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxResponseHandler.java index 82e05749967e3..19156d905ab66 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxResponseHandler.java @@ -14,8 +14,6 @@ import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.ibmwatsonx.response.IbmWatsonxErrorResponseEntity; -import static org.elasticsearch.core.Strings.format; - public class IbmWatsonxResponseHandler extends BaseResponseHandler { public IbmWatsonxResponseHandler(String requestType, ResponseParser parseFunction) { super(requestType, parseFunction, IbmWatsonxErrorResponseEntity::fromResponse); @@ -53,8 +51,4 @@ protected void checkForFailureStatusCode(Request request, HttpResult result) thr throw new RetryException(false, buildError(UNSUCCESSFUL, request, result)); } } - - private static String resourceNotFoundError(Request request) { - return format("Resource not found at [%s]", request.getURI()); - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/MixedbreadModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/MixedbreadModel.java new file mode 100644 index 
0000000000000..4275e84610cc7 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/MixedbreadModel.java @@ -0,0 +1,67 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.mixedbread; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel; +import org.elasticsearch.xpack.inference.services.mixedbread.action.MixedbreadActionVisitor; +import org.elasticsearch.xpack.inference.services.mixedbread.rerank.MixedbreadRerankServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.net.URI; +import java.util.Map; +import java.util.Objects; + +/** + * Abstract class representing a Mixedbread model for inference. + * This class extends RateLimitGroupingModel and provides common functionality for Mixedbread models. 
+ */ +public abstract class MixedbreadModel extends RateLimitGroupingModel { + private final URI uri; + + public MixedbreadModel(ModelConfigurations configurations, ModelSecrets secrets, @Nullable ApiKeySecrets apiKeySecrets, URI uri) { + super(configurations, secrets); + this.uri = uri; + } + + protected MixedbreadModel(MixedbreadModel model, TaskSettings taskSettings) { + super(model, taskSettings); + uri = model.uri(); + } + + public abstract ExecutableAction accept(MixedbreadActionVisitor creator, Map taskSettings); + + public URI uri() { + return uri; + } + + @Override + public RateLimitSettings rateLimitSettings() { + return getServiceSettings().rateLimitSettings(); + } + + @Override + public MixedbreadRerankServiceSettings getServiceSettings() { + return (MixedbreadRerankServiceSettings) super.getServiceSettings(); + } + + @Override + public DefaultSecretSettings getSecretSettings() { + return (DefaultSecretSettings) super.getSecretSettings(); + } + + public int rateLimitGroupingHash() { + return Objects.hash(getServiceSettings().modelId(), uri, getSecretSettings()); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/MixedbreadService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/MixedbreadService.java new file mode 100644 index 0000000000000..d4adfe2acb016 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/MixedbreadService.java @@ -0,0 +1,330 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.mixedbread; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.util.LazyInitializable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.RerankingInferenceService; +import org.elasticsearch.inference.SettingsConfiguration; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ModelCreator; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.mixedbread.action.MixedbreadActionCreator; +import 
org.elasticsearch.xpack.inference.services.mixedbread.rerank.MixedbreadRerankModelCreator; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.util.EnumSet; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; + +/** + * Mixedbread inference service for reranking tasks. + * This service uses the Mixedbread REST API to perform document reranking. + */ +public class MixedbreadService extends SenderService implements RerankingInferenceService { + public static final String NAME = "mixedbread"; + public static final String SERVICE_NAME = "Mixedbread"; + + private static final EnumSet SUPPORTED_TASK_TYPES = EnumSet.of(TaskType.RERANK); + + /** + * {@link #rerankerWindowSize(String modelId)} method returns the size in words, not in tokens, so we'll need to translate + * tokens to words by multiplying by 0.75 and rounding down + + * The context window size for v1 models is 512 tokens / 300 words + * For v2 models it is from 8k / 5500 words to 32k / 22000 words + * tokens to words conversion reference + */ + private static final int DEFAULT_RERANKER_INPUT_SIZE_WORDS = 22000; + + private static final Map RERANKERS_INPUT_SIZE = Map.of( + "mixedbread-ai/mxbai-rerank-xsmall-v1", + 300, + "mixedbread-ai/mxbai-rerank-base-v1", + 300, + "mixedbread-ai/mxbai-rerank-large-v1", + 300 + ); + + private static final Map> MODEL_CREATORS = Map.of( + TaskType.RERANK, + new MixedbreadRerankModelCreator() + ); + + /** + * Constructor for creating an {@link MixedbreadService} with specified HTTP request sender factory and service components. 
+ * + * @param factory the factory to create HTTP request senders + * @param serviceComponents the components required for the inference service + * @param context the context for the inference service factory + */ + public MixedbreadService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public MixedbreadService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); + } + + @Override + public String name() { + return NAME; + } + + @Override + public void parseRequestConfig( + String inferenceEntityId, + TaskType taskType, + Map config, + ActionListener parsedModelListener + ) { + try { + Map serviceSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = ServiceUtils.removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + MixedbreadModel model = createModel( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + null, + serviceSettingsMap, + ConfigurationParseContext.REQUEST + ); + + ServiceUtils.throwIfNotEmptyMap(config, NAME); + ServiceUtils.throwIfNotEmptyMap(serviceSettingsMap, NAME); + ServiceUtils.throwIfNotEmptyMap(taskSettingsMap, NAME); + + parsedModelListener.onResponse(model); + } catch (Exception e) { + parsedModelListener.onFailure(e); + } + } + + private MixedbreadModel parsePersistedConfigWithSecrets( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + ChunkingSettings chunkingSettings, + @Nullable Map secretSettings + ) { + return createModel( + inferenceEntityId, + taskType, + serviceSettings, + taskSettings, + chunkingSettings, + secretSettings, + ConfigurationParseContext.PERSISTENT + ); + } + + /** + * Creates an {@link 
MixedbreadModel} based on the provided parameters. + * + * @param inferenceId the unique identifier for the inference entity + * @param taskType the type of task this model is designed for + * @param serviceSettings the settings for the inference service + * @param taskSettings the task-specific settings, if applicable + * @param chunkingSettings the settings for chunking, if applicable + * @param secretSettings the secret settings for the model, such as API keys or tokens + * @param context the context for parsing configuration settings + * @return a new instance of {@link MixedbreadModel} based on the provided parameters + */ + protected MixedbreadModel createModel( + String inferenceId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + ChunkingSettings chunkingSettings, + @Nullable Map secretSettings, + ConfigurationParseContext context + ) { + return retrieveModelCreatorFromMapOrThrow(MODEL_CREATORS, inferenceId, taskType, NAME, context).createFromMaps( + inferenceId, + taskType, + NAME, + serviceSettings, + taskSettings, + chunkingSettings, + secretSettings, + context + ); + } + + @Override + public MixedbreadModel parsePersistedConfigWithSecrets( + String inferenceEntityId, + TaskType taskType, + Map config, + Map secrets + ) { + Map serviceSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = ServiceUtils.removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + Map secretSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); + + return parsePersistedConfigWithSecrets(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, null, secretSettingsMap); + } + + @Override + public Model buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSecrets secrets) { + return retrieveModelCreatorFromMapOrThrow( + MODEL_CREATORS, + config.getInferenceEntityId(), + config.getTaskType(), + config.getService(), + 
ConfigurationParseContext.PERSISTENT + ).createFromModelConfigurationsAndSecrets(config, secrets); + } + + @Override + public MixedbreadModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { + Map serviceSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = ServiceUtils.removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + return parsePersistedConfigWithSecrets(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, null, null); + } + + @Override + public InferenceServiceConfiguration getConfiguration() { + return Configuration.get(); + } + + @Override + public EnumSet supportedTaskTypes() { + return SUPPORTED_TASK_TYPES; + } + + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + + @Override + protected void doChunkedInfer( + Model model, + List inputs, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener> listener + ) { + throw new UnsupportedOperationException(Strings.format("%s service does not support chunked inference", NAME)); + } + + @Override + protected boolean supportsChunkedInfer() { + return false; + } + + @Override + public void doInfer( + Model model, + InferenceInputs inputs, + Map taskSettings, + TimeValue timeout, + ActionListener listener + ) { + if (model instanceof MixedbreadModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + MixedbreadModel mixedbreadModel = (MixedbreadModel) model; + var actionCreator = new MixedbreadActionCreator(getSender(), getServiceComponents()); + + var action = mixedbreadModel.accept(actionCreator, taskSettings); + action.execute(inputs, timeout, listener); + } + + @Override + protected void validateInputType(InputType inputType, Model model, ValidationException 
validationException) {} + + @Override + public TransportVersion getMinimalSupportedVersion() { + return MixedbreadUtils.INFERENCE_MIXEDBREAD_ADDED; + } + + @Override + public int rerankerWindowSize(String modelId) { + Integer inputSize = RERANKERS_INPUT_SIZE.get(modelId); + return inputSize != null ? inputSize : DEFAULT_RERANKER_INPUT_SIZE_WORDS; + } + + public static class Configuration { + public static InferenceServiceConfiguration get() { + return configuration.getOrCompute(); + } + + private static final LazyInitializable configuration = new LazyInitializable<>( + () -> { + var configurationMap = new HashMap(); + + configurationMap.put( + MODEL_ID, + new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDescription("The model ID to use for Mixedbread requests.") + .setLabel("Model ID") + .setRequired(true) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ); + + configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES)); + configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES)); + + return new InferenceServiceConfiguration.Builder().setService(NAME) + .setName(SERVICE_NAME) + .setTaskTypes(SUPPORTED_TASK_TYPES) + .setConfigurations(configurationMap) + .build(); + } + ); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/MixedbreadUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/MixedbreadUtils.java new file mode 100644 index 0000000000000..977cc2eeb0203 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/MixedbreadUtils.java @@ -0,0 +1,56 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.mixedbread; + +import org.apache.http.client.utils.URIBuilder; +import org.elasticsearch.TransportVersion; + +/** + * Constants and transport version checks shared by the Mixedbread inference service. + */ +public final class MixedbreadUtils { + public static final String HOST = "api.mixedbread.com"; + public static final String VERSION_1 = "v1"; + public static final String RERANK_PATH = "reranking"; + public static final URIBuilder DEFAULT_URI_BUILDER = new URIBuilder().setScheme("https").setHost(MixedbreadUtils.HOST); + + // common service settings fields + public static final String MODEL_FIELD = "model"; + + public static final String INPUT_FIELD = "input"; + + // rerank query and document fields + public static final String QUERY_FIELD = "query"; + + public static final String DOCUMENTS_FIELD = "documents"; + + // rerank task settings fields + public static final String RETURN_DOCUMENTS_FIELD = "return_input"; + public static final String TOP_K_FIELD = "top_k"; + + /** + * TransportVersion indicating when Mixedbread features were added. + */ + public static final TransportVersion INFERENCE_MIXEDBREAD_ADDED = TransportVersion.fromName("inference_mixedbread_added"); + + /** + * Checks if the given TransportVersion supports Mixedbread features. + * + * @param version the TransportVersion to check + * @return true if Mixedbread features are supported, false otherwise + */ + public static boolean supportsMixedbread(TransportVersion version) { + return version.supports(INFERENCE_MIXEDBREAD_ADDED); + } + + /** + * Private constructor to prevent instantiation.
+ */ + private MixedbreadUtils() {} + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/action/MixedbreadActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/action/MixedbreadActionCreator.java new file mode 100644 index 0000000000000..e54774ab2b576 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/action/MixedbreadActionCreator.java @@ -0,0 +1,68 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.mixedbread.action; + +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.mixedbread.request.rerank.MixedbreadRerankRequest; +import org.elasticsearch.xpack.inference.services.mixedbread.rerank.MixedbreadRerankModel; +import org.elasticsearch.xpack.inference.services.mixedbread.rerank.MixedbreadResponseHandler; +import org.elasticsearch.xpack.inference.services.mixedbread.response.MixedbreadRerankResponseEntity; + +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; + +public class 
MixedbreadActionCreator implements MixedbreadActionVisitor { + private static final String RERANK_ERROR_PREFIX = "Mixedbread rerank"; + + private static final ResponseHandler RERANK_HANDLER = new MixedbreadResponseHandler( + "mixedbread rerank", + (request, response) -> MixedbreadRerankResponseEntity.fromResponse(response) + ); + + private final Sender sender; + private final ServiceComponents serviceComponents; + + /** + * Constructs a new MixedbreadActionCreator with the specified sender and service components. + * + * @param sender the sender to use for executing actions + * @param serviceComponents the service components providing necessary services + */ + public MixedbreadActionCreator(Sender sender, ServiceComponents serviceComponents) { + this.sender = Objects.requireNonNull(sender); + this.serviceComponents = Objects.requireNonNull(serviceComponents); + } + + @Override + public ExecutableAction create(MixedbreadRerankModel model, Map taskSettings) { + var overriddenModel = MixedbreadRerankModel.of(model, taskSettings); + var manager = new GenericRequestManager<>( + serviceComponents.threadPool(), + overriddenModel, + RERANK_HANDLER, + inputs -> new MixedbreadRerankRequest( + inputs.getQuery(), + inputs.getChunks(), + inputs.getReturnDocuments(), + inputs.getTopN(), + overriddenModel + ), + QueryAndDocsInputs.class + ); + var errorMessage = constructFailedToSendRequestMessage(RERANK_ERROR_PREFIX); + return new SenderExecutableAction(sender, manager, errorMessage); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/action/MixedbreadActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/action/MixedbreadActionVisitor.java new file mode 100644 index 0000000000000..9b2604d5d6176 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/action/MixedbreadActionVisitor.java @@ -0,0 +1,30 @@ +/* + *
Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.mixedbread.action; + +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.mixedbread.rerank.MixedbreadRerankModel; + +import java.util.Map; + +/** + * Interface for creating {@link ExecutableAction} instances for Mixedbread models. + *

+ * This interface is used to create {@link ExecutableAction} instances for Mixedbread models + * {@link MixedbreadRerankModel}. + */ +public interface MixedbreadActionVisitor { + + /** + * Creates an {@link ExecutableAction} for the given {@link MixedbreadRerankModel}. + * + * @param model The model to create the action for. + * @return An {@link ExecutableAction} for the given model. + */ + ExecutableAction create(MixedbreadRerankModel model, Map taskSettings); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/request/rerank/MixedbreadRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/request/rerank/MixedbreadRerankRequest.java new file mode 100644 index 0000000000000..5bc17ee35d859 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/request/rerank/MixedbreadRerankRequest.java @@ -0,0 +1,92 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.mixedbread.request.rerank; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.mixedbread.rerank.MixedbreadRerankModel; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; + +public class MixedbreadRerankRequest implements Request { + private final MixedbreadRerankModel model; + private final String query; + private final List input; + private final Boolean returnDocuments; + private final Integer topN; + + public MixedbreadRerankRequest( + String query, + List input, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + MixedbreadRerankModel model + ) { + this.input = Objects.requireNonNull(input); + this.query = Objects.requireNonNull(query); + this.returnDocuments = returnDocuments; + this.topN = topN; + this.model = Objects.requireNonNull(model); + } + + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(model.uri()); + + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString( + new MixedbreadRerankRequestEntity( + model.getServiceSettings().modelId(), + query, + input, + topN, + returnDocuments, + model.getTaskSettings() + ) + ).getBytes(StandardCharsets.UTF_8) + ); + httpPost.setEntity(byteEntity); + + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); + httpPost.setHeader(createAuthBearerHeader(model.getSecretSettings().apiKey())); + + return new 
HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public URI getURI() { + return model.uri(); + } + + @Override + public Request truncate() { + // no truncation + return this; + } + + @Override + public boolean[] getTruncationInfo() { + // no truncation + return null; + } + + @Override + public String getInferenceEntityId() { + return model.getInferenceEntityId(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/request/rerank/MixedbreadRerankRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/request/rerank/MixedbreadRerankRequestEntity.java new file mode 100644 index 0000000000000..55616712b2cfd --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/request/rerank/MixedbreadRerankRequestEntity.java @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.mixedbread.request.rerank; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.mixedbread.MixedbreadUtils; +import org.elasticsearch.xpack.inference.services.mixedbread.rerank.MixedbreadRerankTaskSettings; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public record MixedbreadRerankRequestEntity( + String model, + String query, + List input, + @Nullable Integer topN, + @Nullable Boolean returnDocuments, + MixedbreadRerankTaskSettings taskSettings +) implements ToXContentObject { + + public MixedbreadRerankRequestEntity { + Objects.requireNonNull(model); + Objects.requireNonNull(query); + Objects.requireNonNull(input); + Objects.requireNonNull(taskSettings); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + builder.field(MixedbreadUtils.MODEL_FIELD, model); + builder.field(MixedbreadUtils.QUERY_FIELD, query); + builder.field(MixedbreadUtils.INPUT_FIELD, input); + + if (topN != null) { + builder.field(MixedbreadUtils.TOP_K_FIELD, topN); + } else if (taskSettings.getTopN() != null) { + builder.field(MixedbreadUtils.TOP_K_FIELD, taskSettings.getTopN()); + } + + if (returnDocuments != null) { + builder.field(MixedbreadUtils.RETURN_DOCUMENTS_FIELD, returnDocuments); + } else if (taskSettings.getReturnDocuments() != null) { + builder.field(MixedbreadUtils.RETURN_DOCUMENTS_FIELD, taskSettings.getReturnDocuments()); + } + builder.endObject(); + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadRerankModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadRerankModel.java new file mode 
100644 index 0000000000000..a29b522bf9262 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadRerankModel.java @@ -0,0 +1,114 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.mixedbread.rerank; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.mixedbread.MixedbreadModel; +import org.elasticsearch.xpack.inference.services.mixedbread.MixedbreadService; +import org.elasticsearch.xpack.inference.services.mixedbread.MixedbreadUtils; +import org.elasticsearch.xpack.inference.services.mixedbread.action.MixedbreadActionVisitor; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.buildUri; + +public class MixedbreadRerankModel extends MixedbreadModel { + public static MixedbreadRerankModel of(MixedbreadRerankModel model, Map taskSettings) { + var requestTaskSettings = MixedbreadRerankTaskSettings.fromMap(taskSettings); + if (requestTaskSettings.isEmpty() || requestTaskSettings.equals(model.getTaskSettings())) { + return model; + } + return new MixedbreadRerankModel(model, MixedbreadRerankTaskSettings.of(model.getTaskSettings(), requestTaskSettings)); + } + + public 
MixedbreadRerankModel( + String inferenceId, + Map serviceSettings, + Map taskSettings, + @Nullable Map secrets, + ConfigurationParseContext context + ) { + this( + inferenceId, + MixedbreadRerankServiceSettings.fromMap(serviceSettings, context), + MixedbreadRerankTaskSettings.fromMap(taskSettings), + DefaultSecretSettings.fromMap(secrets), + null + ); + } + + // should only be used for testing + MixedbreadRerankModel( + String inferenceId, + MixedbreadRerankServiceSettings serviceSettings, + MixedbreadRerankTaskSettings taskSettings, + @Nullable DefaultSecretSettings secretSettings, + @Nullable String uri + ) { + super( + new ModelConfigurations(inferenceId, TaskType.RERANK, MixedbreadService.NAME, serviceSettings, taskSettings), + new ModelSecrets(secretSettings), + secretSettings, + Objects.requireNonNullElse( + ServiceUtils.createOptionalUri(uri), + buildUri( + MixedbreadService.SERVICE_NAME, + MixedbreadUtils.DEFAULT_URI_BUILDER.setPathSegments(MixedbreadUtils.VERSION_1, MixedbreadUtils.RERANK_PATH)::build + ) + ) + ); + } + + /** + * Constructor for creating an {@link MixedbreadRerankModel} from model configurations and secrets. 
+ * + * @param modelConfigurations the configurations for the model + * @param modelSecrets the secret settings for the model + */ + public MixedbreadRerankModel(ModelConfigurations modelConfigurations, ModelSecrets modelSecrets) { + super(modelConfigurations, modelSecrets, (DefaultSecretSettings) modelSecrets.getSecretSettings(), null); // NOTE(review): passes a null URI, whereas the settings-based constructor falls back to the default rerank endpoint - confirm MixedbreadModel supplies a default + } + + public MixedbreadRerankModel(MixedbreadRerankModel model, MixedbreadRerankTaskSettings taskSettings) { + super(model, taskSettings); + } + + @Override + public MixedbreadRerankServiceSettings getServiceSettings() { + return super.getServiceSettings(); + } + + @Override + public MixedbreadRerankTaskSettings getTaskSettings() { + return (MixedbreadRerankTaskSettings) super.getTaskSettings(); + } + + @Override + public DefaultSecretSettings getSecretSettings() { + return super.getSecretSettings(); + } + + /** + * Accepts a visitor to create an executable action by delegating to {@link MixedbreadActionVisitor#create}. + * @param visitor Interface for creating {@link ExecutableAction} instances for Mixedbread models. + * @param taskSettings Settings in the request to override the model's defaults + * @return the rerank action + */ + @Override + public ExecutableAction accept(MixedbreadActionVisitor visitor, Map taskSettings) { + return visitor.create(this, taskSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadRerankModelCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadRerankModelCreator.java new file mode 100644 index 0000000000000..3f931e21a307c --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadRerankModelCreator.java @@ -0,0 +1,43 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.mixedbread.rerank; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ModelCreator; + +import java.util.Map; + +/** + * Creates {@link MixedbreadRerankModel} instances from config maps + * or {@link ModelConfigurations} and {@link ModelSecrets} objects. + */ +public class MixedbreadRerankModelCreator implements ModelCreator { + @Override + public MixedbreadRerankModel createFromMaps( + String inferenceId, + TaskType taskType, + String service, + Map serviceSettings, + @Nullable Map taskSettings, + @Nullable ChunkingSettings chunkingSettings, + @Nullable Map secretSettings, + ConfigurationParseContext context + ) { + return new MixedbreadRerankModel(inferenceId, serviceSettings, taskSettings, secretSettings, context); + } + + @Override + public MixedbreadRerankModel createFromModelConfigurationsAndSecrets(ModelConfigurations config, ModelSecrets secrets) { + return new MixedbreadRerankModel(config, secrets); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadRerankServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadRerankServiceSettings.java new file mode 100644 index 0000000000000..7fc9547d334ac --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadRerankServiceSettings.java @@ -0,0 +1,134 @@ +/* + * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.mixedbread.rerank; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.mixedbread.MixedbreadService; +import org.elasticsearch.xpack.inference.services.mixedbread.MixedbreadUtils; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; + +public class MixedbreadRerankServiceSettings extends FilteredXContentObject implements ServiceSettings { + + public static final String NAME = "mixedbread_rerank_service_settings"; + + /** + * Free subscription tier 100 req / min + * Rate Limiting. 
+ */ + public static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(100); + + public static MixedbreadRerankServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + + String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + MixedbreadService.NAME, + context + ); + + validationException.throwIfValidationErrorsExist(); + + return new MixedbreadRerankServiceSettings(modelId, rateLimitSettings); + } + + private final String modelId; + + private final RateLimitSettings rateLimitSettings; + + public MixedbreadRerankServiceSettings(String modelId, @Nullable RateLimitSettings rateLimitSettings) { + this.modelId = Objects.requireNonNull(modelId); + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + } + + public MixedbreadRerankServiceSettings(StreamInput in) throws IOException { + this.modelId = in.readString(); + this.rateLimitSettings = new RateLimitSettings(in); + } + + @Override + public String modelId() { + return modelId; + } + + public RateLimitSettings rateLimitSettings() { + return rateLimitSettings; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + assert false : "should never be called when supportsVersion is used"; + return MixedbreadUtils.INFERENCE_MIXEDBREAD_ADDED; + } + + @Override + public boolean supportsVersion(TransportVersion version) { + return MixedbreadUtils.supportsMixedbread(version); + } + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + builder.field(MODEL_ID, modelId); + rateLimitSettings.toXContent(builder, params); + + 
return builder; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + toXContentFragmentOfExposedFields(builder, params); + + builder.endObject(); + + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + rateLimitSettings.writeTo(out); + } + + @Override + public boolean equals(Object object) { + if (this == object) return true; + if (object == null || getClass() != object.getClass()) return false; + MixedbreadRerankServiceSettings that = (MixedbreadRerankServiceSettings) object; + return Objects.equals(modelId, that.modelId()) && Objects.equals(rateLimitSettings, that.rateLimitSettings()); + } + + @Override + public int hashCode() { + return Objects.hash(modelId, rateLimitSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadRerankTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadRerankTaskSettings.java new file mode 100644 index 0000000000000..a9c337e02c6a7 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadRerankTaskSettings.java @@ -0,0 +1,157 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.mixedbread.rerank; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.mixedbread.MixedbreadUtils; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; + +public class MixedbreadRerankTaskSettings implements TaskSettings { + public static final String NAME = "mixedbread_rerank_task_settings"; + public static final String RETURN_DOCUMENTS = "return_documents"; + public static final String TOP_N = "top_n"; + + public static final MixedbreadRerankTaskSettings EMPTY_SETTINGS = new MixedbreadRerankTaskSettings(null, null); + + public static MixedbreadRerankTaskSettings fromMap(Map map) { + ValidationException validationException = new ValidationException(); + + if (map == null || map.isEmpty()) { + return EMPTY_SETTINGS; + } + + Boolean returnDocuments = extractOptionalBoolean(map, RETURN_DOCUMENTS, validationException); + Integer topN = extractOptionalPositiveInteger(map, TOP_N, ModelConfigurations.TASK_SETTINGS, validationException); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + if (returnDocuments == null && topN == null) { + return EMPTY_SETTINGS; + } + + return new MixedbreadRerankTaskSettings(topN, returnDocuments); + } + + /** + * Creates a new {@link MixedbreadRerankTaskSettings} + * by 
preferring non-null fields from the request settings over the original settings. + * + * @param originalSettings the settings stored as part of the inference entity configuration + * @param requestTaskSettings the settings passed in within the task_settings field of the request + * @return a constructed {@link MixedbreadRerankTaskSettings} + */ + public static MixedbreadRerankTaskSettings of( + MixedbreadRerankTaskSettings originalSettings, + MixedbreadRerankTaskSettings requestTaskSettings + ) { + if (requestTaskSettings.isEmpty() || originalSettings.equals(requestTaskSettings)) { + return originalSettings; + } + return new MixedbreadRerankTaskSettings( + requestTaskSettings.getTopN() != null ? requestTaskSettings.getTopN() : originalSettings.getTopN(), + requestTaskSettings.getReturnDocuments() != null + ? requestTaskSettings.getReturnDocuments() + : originalSettings.getReturnDocuments() + ); + } + + private final Integer topN; + private final Boolean returnDocuments; + + public MixedbreadRerankTaskSettings(StreamInput in) throws IOException { + this(in.readOptionalVInt(), in.readOptionalBoolean()); + } + + public MixedbreadRerankTaskSettings(@Nullable Integer topN, @Nullable Boolean doReturnDocuments) { + this.topN = topN; + this.returnDocuments = doReturnDocuments; + } + + @Override + public boolean isEmpty() { + return topN == null && returnDocuments == null; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (topN != null) { + builder.field(TOP_N, topN); + } + if (returnDocuments != null) { + builder.field(RETURN_DOCUMENTS, returnDocuments); + } + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + assert false : "should never be called when supportsVersion is used"; + return MixedbreadUtils.INFERENCE_MIXEDBREAD_ADDED; + } + + 
@Override + public boolean supportsVersion(TransportVersion version) { + return MixedbreadUtils.supportsMixedbread(version); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalVInt(topN); + out.writeOptionalBoolean(returnDocuments); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + MixedbreadRerankTaskSettings that = (MixedbreadRerankTaskSettings) o; + return Objects.equals(returnDocuments, that.returnDocuments) && Objects.equals(topN, that.topN); + } + + @Override + public int hashCode() { + return Objects.hash(returnDocuments, topN); + } + + public Integer getTopN() { + return topN; + } + + public Boolean getReturnDocuments() { + return returnDocuments; + } + + @Override + public MixedbreadRerankTaskSettings updatedTaskSettings(Map newSettings) { + MixedbreadRerankTaskSettings updatedSettings = MixedbreadRerankTaskSettings.fromMap(new HashMap<>(newSettings)); + return MixedbreadRerankTaskSettings.of(this, updatedSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadResponseHandler.java new file mode 100644 index 0000000000000..92ab4d66e52ba --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadResponseHandler.java @@ -0,0 +1,64 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.mixedbread.rerank; + +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; + +public class MixedbreadResponseHandler extends BaseResponseHandler { + private static final String FORBIDDEN = "Valid credentials but insufficient permissions for this resource."; + private static final String PAYMENT_ERROR_MESSAGE = "Insufficient balance. Top up your account to continue."; + private static final String SERVICE_UNAVAILABLE = "Service temporarily down for maintenance or overloaded. Retry later."; + private static final String UNPROCESSABLE_ENTITY = "Request format is correct but cannot be processed."; + + public MixedbreadResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, ErrorResponse::fromResponse); + } + + /** + * Validates the status code throws an RetryException if not in the range [200, 300). 
+ * + * @param request The http request + * @param result The http response and body + * @throws RetryException Throws if status code is {@code >= 400 } + */ + @Override + protected void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException { + if (result.isSuccessfulResponse()) { + return; + } + + // handle error codes + int statusCode = result.response().getStatusLine().getStatusCode(); + if (statusCode == 500) { + throw new RetryException(true, buildError(SERVER_ERROR, request, result)); + } else if (statusCode == 503) { + throw new RetryException(true, buildError(SERVICE_UNAVAILABLE, request, result)); + } else if (statusCode == 429) { + throw new RetryException(true, buildError(RATE_LIMIT, request, result)); + } else if (statusCode == 422) { + throw new RetryException(false, buildError(UNPROCESSABLE_ENTITY, request, result)); // not retryable: resending the same payload would be rejected again + } else if (statusCode == 404) { + throw new RetryException(false, buildError(resourceNotFoundError(request), request, result)); + } else if (statusCode == 403) { + throw new RetryException(false, buildError(FORBIDDEN, request, result)); + } else if (statusCode == 402) { + throw new RetryException(false, buildError(PAYMENT_ERROR_MESSAGE, request, result)); + } else if (statusCode == 401) { + throw new RetryException(false, buildError(AUTHENTICATION, request, result)); + } else if (statusCode == 400) { + throw new RetryException(false, buildError(BAD_REQUEST, request, result)); + } else { + throw new RetryException(false, buildError(UNSUCCESSFUL, request, result)); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/response/MixedbreadRerankResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/response/MixedbreadRerankResponseEntity.java new file mode 100644 index 0000000000000..f125d37af53f2 --- /dev/null +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/response/MixedbreadRerankResponseEntity.java @@ -0,0 +1,129 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.mixedbread.response; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public class MixedbreadRerankResponseEntity { + + /** + * Parses the Mixedbread rerank response. + + * For a request like: + *

+     *{
+     *   "model": "mixedbread-ai/mxbai-rerank-xsmall-v1",
+     *   "query": "Who is the author of To Kill a Mockingbird?",
+     *   "input": [
+     *         "To Kill a Mockingbird is a novel by Harper Lee",
+     *         "The novel Moby-Dick was written by Herman Melville",
+     *         "Harper Lee, an American novelist",
+     *         "Jane Austen was an English novelist",
+     *         "The Harry Potter series written by British author J.K. Rowling",
+     *         "The Great Gatsby, a novel written by American author F. Scott Fitzgerald"
+     *     ],
+     *   "top_k": 3,
+     *   "return_input": false
+     * }
+     * 
+ *

+ * The response will look like (without whitespace): + *

+     *{
+     *     "usage": {
+     *         "prompt_tokens": 162,
+     *         "total_tokens": 162,
+     *         "completion_tokens": 0
+     *     },
+     *     "model": "mixedbread-ai/mxbai-rerank-xsmall-v1",
+     *     "data": [
+     *         {
+     *             "index": 0,
+     *             "score": 0.98291015625,
+     *             "input": null,
+     *             "object": "rank_result"
+     *         },
+     *         {
+     *             "index": 2,
+     *             "score": 0.61962890625,
+     *             "input": null,
+     *             "object": "rank_result"
+     *         },
+     *         {
+     *             "index": 3,
+     *             "score": 0.3642578125,
+     *             "input": null,
+     *             "object": "rank_result"
+     *         }
+     *     ],
+     *     "object": "list",
+     *     "top_k": 3,
+     *     "return_input": false
+     * }
+     * 
+ + * Parses the response from a Mixedbread rerank request and returns the results. + + * @param response the http response from Mixedbread + * @return the parsed response + * @throws IOException if there is an error parsing the response + */ + public static InferenceServiceResults fromResponse(HttpResult response) throws IOException { + try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())) { + return Response.PARSER.apply(p, null).toRankedDocsResults(); + } + } + + private record Response(List results) { + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + Response.class.getSimpleName(), + true, + args -> new Response((List) args[0]) + ); + + static { + PARSER.declareObjectArray(constructorArg(), ResultItem.PARSER::apply, new ParseField("data")); + } + + public RankedDocsResults toRankedDocsResults() { + List rankedDocs = results.stream() + .map(item -> new RankedDocsResults.RankedDoc(item.index(), item.relevanceScore(), item.document())) + .toList(); + return new RankedDocsResults(rankedDocs); + } + } + + private record ResultItem(int index, float relevanceScore, @Nullable String document) { + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + ResultItem.class.getSimpleName(), + true, + args -> new ResultItem((Integer) args[0], (Float) args[1], (String) args[2]) + ); + + static { + PARSER.declareInt(constructorArg(), new ParseField("index")); + PARSER.declareFloat(constructorArg(), new ParseField("score")); + PARSER.declareStringOrNull(optionalConstructorArg(), new ParseField("input")); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiResponseHandler.java index 9f7f32c366bd1..5151fd511a3a5 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiResponseHandler.java @@ -24,7 +24,6 @@ import java.util.concurrent.Flow; import java.util.function.Function; -import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.inference.external.http.retry.ResponseHandlerUtils.getFirstHeaderOrUnknown; public class OpenAiResponseHandler extends BaseResponseHandler { @@ -105,10 +104,6 @@ protected RetryException buildExceptionHandlingContentTooLarge(Request request, return new ContentTooLargeException(buildError(CONTENT_TOO_LARGE, request, result)); } - private static String resourceNotFoundError(Request request) { - return format("Resource not found at [%s]", request.getURI()); - } - protected RetryException buildExceptionHandling429(Request request, HttpResult result) { return new RetryException(true, buildError(buildRateLimitErrorMessage(result), request, result)); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java index dc00aa84e6880..ad5e9b1dff810 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java @@ -93,6 +93,7 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsP public void testParseRequestConfig_CreatesACompletionModel() throws Exception { var parseRequestConfigTestConfig = testConfiguration.commonConfig(); + Assume.assumeTrue(testConfiguration.commonConfig().supportedTaskTypes().contains(TaskType.COMPLETION)); try (var service = 
parseRequestConfigTestConfig.createService(threadPool, clientManager)) { var config = getRequestConfigMap( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/MixedbreadServiceParameterizedModelCreationTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/MixedbreadServiceParameterizedModelCreationTests.java new file mode 100644 index 0000000000000..7af72c6db4589 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/MixedbreadServiceParameterizedModelCreationTests.java @@ -0,0 +1,16 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.mixedbread; + +import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceParameterizedModelCreationTests; + +public class MixedbreadServiceParameterizedModelCreationTests extends AbstractInferenceServiceParameterizedModelCreationTests { + public MixedbreadServiceParameterizedModelCreationTests(TestCase testCase) { + super(MixedbreadServiceTests.createTestConfiguration(), testCase); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/MixedbreadServiceParameterizedParsingTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/MixedbreadServiceParameterizedParsingTests.java new file mode 100644 index 0000000000000..1b90a6a0dc0c5 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/MixedbreadServiceParameterizedParsingTests.java @@ -0,0 +1,16 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.mixedbread; + +import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceParameterizedParsingTests; + +public class MixedbreadServiceParameterizedParsingTests extends AbstractInferenceServiceParameterizedParsingTests { + public MixedbreadServiceParameterizedParsingTests(TestCase testCase) { + super(MixedbreadServiceTests.createTestConfiguration(), testCase); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/MixedbreadServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/MixedbreadServiceTests.java new file mode 100644 index 0000000000000..a780e77af11e6 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/MixedbreadServiceTests.java @@ -0,0 +1,1013 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.mixedbread; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.RerankingInferenceService; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceTests; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import 
org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.mixedbread.rerank.MixedbreadRerankModel; +import org.elasticsearch.xpack.inference.services.mixedbread.rerank.MixedbreadRerankModelTests; +import org.elasticsearch.xpack.inference.services.mixedbread.rerank.MixedbreadRerankServiceSettings; +import org.elasticsearch.xpack.inference.services.mixedbread.rerank.MixedbreadRerankServiceSettingsTests; +import org.elasticsearch.xpack.inference.services.mixedbread.rerank.MixedbreadRerankTaskSettings; +import org.elasticsearch.xpack.inference.services.mixedbread.rerank.MixedbreadRerankTaskSettingsTests; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.hamcrest.CoreMatchers; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; +import static org.elasticsearch.inference.TaskType.COMPLETION; +import static org.elasticsearch.inference.TaskType.RERANK; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; +import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; +import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static 
org.elasticsearch.xpack.inference.services.mixedbread.rerank.MixedbreadRerankServiceSettingsTests.getServiceSettingsMap; +import static org.elasticsearch.xpack.inference.services.mixedbread.rerank.MixedbreadRerankTaskSettingsTests.getTaskSettingsMap; +import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.Mockito.mock; + +public class MixedbreadServiceTests extends AbstractInferenceServiceTests { + public static final String UNKNOWN_SETTINGS_EXCEPTION = + "Configuration contains settings [{extra_key=value}] unknown to the [mixedbread] service"; + public static final Boolean RETURN_DOCUMENTS_TRUE = true; + public static final Boolean RETURN_DOCUMENTS_FALSE = false; + public static final String DEFAULT_RERANK_URL = "https://api.mixedbread.com/v1/reranking"; + + private static final String INFERENCE_ID_VALUE = "id"; + private static final String MODEL_NAME_VALUE = "modelName"; + private static final String API_KEY = "secret"; + private static final String QUERY_VALUE = "query"; + private static final Integer TOP_N = 3; + private static final Integer REQUESTS_PER_MINUTE = 3; + private static final Boolean STREAM = false; + private static final List INPUT = List.of("candidate1", "candidate2", "candidate3"); + + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + public MixedbreadServiceTests() { + super(createTestConfiguration()); + } + + public static TestConfiguration createTestConfiguration() { + return new TestConfiguration.Builder(new CommonConfig(RERANK, COMPLETION, EnumSet.of(RERANK)) { + + @Override + protected SenderService createService(ThreadPool 
threadPool, HttpClientManager clientManager) { + return MixedbreadServiceTests.createService(threadPool, clientManager); + } + + @Override + protected Map createServiceSettingsMap(TaskType taskType) { + return MixedbreadRerankServiceSettingsTests.getServiceSettingsMap(MODEL_NAME_VALUE, null); + } + + @Override + protected ModelConfigurations createModelConfigurations(TaskType taskType) { + return switch (taskType) { + case RERANK -> new ModelConfigurations( + INFERENCE_ID_VALUE, + taskType, + MixedbreadService.NAME, + MixedbreadRerankServiceSettings.fromMap( + createServiceSettingsMap(taskType, ConfigurationParseContext.PERSISTENT), + ConfigurationParseContext.PERSISTENT + ), + MixedbreadRerankTaskSettings.EMPTY_SETTINGS + ); + // Completion is not supported, but in order to test unsupported task types it is included here + case COMPLETION -> new ModelConfigurations( + INFERENCE_ID_VALUE, + taskType, + MixedbreadService.NAME, + mock(ServiceSettings.class), + mock(TaskSettings.class) + ); + default -> throw new IllegalStateException("Unexpected value: " + taskType); + }; + } + + @Override + protected ModelSecrets createModelSecrets() { + return new ModelSecrets(DefaultSecretSettings.fromMap(createSecretSettingsMap())); + } + + @Override + protected Map createServiceSettingsMap(TaskType taskType, ConfigurationParseContext parseContext) { + return MixedbreadServiceTests.createServiceSettingsMap(taskType, parseContext); + } + + @Override + protected Map createTaskSettingsMap(TaskType taskType) { + if (taskType.equals(RERANK)) { + return MixedbreadRerankTaskSettingsTests.getTaskSettingsMap(null, null); + } + return createTaskSettingsMap(); + } + + @Override + protected Map createTaskSettingsMap() { + return new HashMap<>(); + } + + @Override + protected Map createSecretSettingsMap() { + return MixedbreadServiceTests.createSecretSettingsMap(); + } + + @Override + protected void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) { + 
MixedbreadServiceTests.assertModel(model, taskType, modelIncludesSecrets); + } + + @Override + protected EnumSet supportedStreamingTasks() { + return EnumSet.noneOf(TaskType.class); + } + + @Override + protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) { + assertThat(rerankingInferenceService.rerankerWindowSize(MODEL_NAME_VALUE), Matchers.is(22000)); + } + }).build(); + } + + private static void assertModel(Model model, TaskType taskType, boolean modelIncludesSecrets) { + if (Objects.requireNonNull(taskType) == RERANK) { + assertRerankModel(model, modelIncludesSecrets); + } else { + fail("unexpected task type [" + taskType + "]"); + } + } + + private static MixedbreadModel assertCommonModelFields(Model model, boolean modelIncludesSecrets) { + assertThat(model, instanceOf(MixedbreadModel.class)); + + var mixedbreadModel = (MixedbreadModel) model; + assertThat(mixedbreadModel.getServiceSettings().modelId(), Matchers.is(MODEL_NAME_VALUE)); + if (modelIncludesSecrets) { + assertThat(mixedbreadModel.getSecretSettings().apiKey(), Matchers.is(new SecureString(API_KEY.toCharArray()))); + } + return mixedbreadModel; + } + + private static void assertRerankModel(Model model, boolean modelIncludesSecrets) { + var mixedbreadModel = assertCommonModelFields(model, modelIncludesSecrets); + assertThat(mixedbreadModel.getTaskSettings(), Matchers.is(MixedbreadRerankTaskSettings.EMPTY_SETTINGS)); + assertThat(mixedbreadModel.getTaskType(), Matchers.is(RERANK)); + } + + public static SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + return new MixedbreadService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty()); + } + + private static Map createServiceSettingsMap(TaskType taskType, ConfigurationParseContext parseContext) { + return 
MixedbreadRerankServiceSettingsTests.getServiceSettingsMap(MODEL_NAME_VALUE, null); + } + + private static Map createSecretSettingsMap() { + return new HashMap<>(Map.of("api_key", API_KEY)); + } + + public void testBuildModelFromConfigAndSecrets_UnsupportedTaskType() throws IOException { + var modelConfigurations = new ModelConfigurations( + INFERENCE_ID_VALUE, + TaskType.COMPLETION, + MixedbreadService.NAME, + mock(ServiceSettings.class) + ); + try (var inferenceService = createInferenceService()) { + var thrownException = expectThrows( + ElasticsearchStatusException.class, + () -> inferenceService.buildModelFromConfigAndSecrets(modelConfigurations, mock(ModelSecrets.class)) + ); + assertThat( + thrownException.getMessage(), + CoreMatchers.is( + org.elasticsearch.core.Strings.format( + """ + Failed to parse stored model [%s] for [%s] service, error: [The [%s] service does not support task type [%s]]. \ + Please delete and add the service again""", + INFERENCE_ID_VALUE, + MixedbreadService.NAME, + MixedbreadService.NAME, + TaskType.COMPLETION + ) + ) + ); + } + } + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityExecutors()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testParseRequestConfig_createsRerankModel() throws IOException { + try (var service = createMixedbreadService()) { + var modelName = randomAlphanumericOfLength(8); + var requestsPerMinute = randomNonNegativeInt(); + var topN = randomNonNegativeInt(); + var returnDocuments = randomBoolean(); + var apiKey = randomAlphanumericOfLength(8); + + var modelListener = new PlainActionFuture(); + + service.parseRequestConfig( + INFERENCE_ID_VALUE, + TaskType.RERANK, + getRequestConfigMap( + 
getServiceSettingsMap(modelName, requestsPerMinute), + getTaskSettingsMap(topN, returnDocuments), + getSecretSettingsMap(apiKey) + ), + modelListener + ); + + var rerankModel = (MixedbreadRerankModel) modelListener.actionGet(); + + assertThat(rerankModel.getSecretSettings().apiKey().toString(), is(apiKey)); + assertRerankModelSettings( + rerankModel, + modelName, + new RateLimitSettings(requestsPerMinute), + apiKey, + new MixedbreadRerankTaskSettings(topN, returnDocuments) + ); + } + } + + public void testParseRequestConfig_onlyRequiredSettings_createsRerankModel() throws IOException { + try (var service = createMixedbreadService()) { + var modelName = randomAlphanumericOfLength(8); + var apiKey = randomAlphanumericOfLength(8); + + var modelListener = new PlainActionFuture(); + + service.parseRequestConfig( + INFERENCE_ID_VALUE, + TaskType.RERANK, + getRequestConfigMap(getServiceSettingsMap(modelName), Map.of(), getSecretSettingsMap(apiKey)), + modelListener + ); + + var rerankModel = (MixedbreadRerankModel) modelListener.actionGet(); + + assertThat(rerankModel.getSecretSettings().apiKey().toString(), is(apiKey)); + assertRerankModelSettings( + modelListener.actionGet(), + modelName, + MixedbreadRerankServiceSettings.DEFAULT_RATE_LIMIT_SETTINGS, + apiKey, + MixedbreadRerankTaskSettings.EMPTY_SETTINGS + ); + + } + } + + public void testParsePersistedConfigWithSecrets_createsRerankModel() throws IOException { + try (var service = createMixedbreadService()) { + var modelName = randomAlphanumericOfLength(8); + var requestsPerMinute = randomNonNegativeInt(); + var topN = randomNonNegativeInt(); + var returnDocuments = randomBoolean(); + var apiKey = randomAlphanumericOfLength(8); + + var persistedConfig = getPersistedConfigMap( + getServiceSettingsMap(modelName, requestsPerMinute), + getTaskSettingsMap(topN, returnDocuments), + getSecretSettingsMap(apiKey) + ); + + var model = service.parsePersistedConfigWithSecrets( + INFERENCE_ID_VALUE, + TaskType.RERANK, + 
persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model.getSecretSettings().apiKey().toString(), is(apiKey)); + assertRerankModelSettings( + model, + modelName, + new RateLimitSettings(requestsPerMinute), + apiKey, + new MixedbreadRerankTaskSettings(topN, returnDocuments) + ); + } + } + + public void testParsePersistedConfigWithSecrets_onlyRequiredSettings_createsRerankModel() throws IOException { + try (var service = createMixedbreadService()) { + var modelName = randomAlphanumericOfLength(8); + var apiKey = randomAlphanumericOfLength(8); + + var persistedConfig = getPersistedConfigMap(getServiceSettingsMap(modelName, null), Map.of(), getSecretSettingsMap(apiKey)); + + var model = service.parsePersistedConfigWithSecrets( + INFERENCE_ID_VALUE, + TaskType.RERANK, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model.getSecretSettings().apiKey().toString(), is(apiKey)); + assertRerankModelSettings( + model, + modelName, + MixedbreadRerankServiceSettings.DEFAULT_RATE_LIMIT_SETTINGS, + apiKey, + MixedbreadRerankTaskSettings.EMPTY_SETTINGS + ); + } + } + + public void testParsePersistedConfig_createsRerankModel() throws IOException { + try (var service = createMixedbreadService()) { + var modelName = randomAlphanumericOfLength(8); + var requestsPerMinute = randomNonNegativeInt(); + var topN = randomNonNegativeInt(); + var returnDocuments = randomBoolean(); + + var persistedConfig = getPersistedConfigMap( + getServiceSettingsMap(modelName, requestsPerMinute), + getTaskSettingsMap(topN, returnDocuments), + null + ); + + var model = service.parsePersistedConfig(INFERENCE_ID_VALUE, TaskType.RERANK, persistedConfig.config()); + + assertRerankModelSettings( + model, + modelName, + new RateLimitSettings(requestsPerMinute), + "", + new MixedbreadRerankTaskSettings(topN, returnDocuments) + ); + } + } + + public void testParseRequestConfig_NoModelId_ThrowsException() throws IOException { + try (var service = 
createMixedbreadService()) { + ActionListener modelListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ValidationException.class)); + assertThat( + exception.getMessage(), + Matchers.is("Validation Failed: 1: [service_settings] does not contain the required setting [model_id];") + ); + } + ); + + service.parseRequestConfig( + INFERENCE_ID_VALUE, + TaskType.RERANK, + getRequestConfigMap( + getServiceSettingsMap(null, REQUESTS_PER_MINUTE), + getTaskSettingsMap(TOP_N, RETURN_DOCUMENTS_TRUE), + getSecretSettingsMap(API_KEY) + ), + modelListener + ); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInRerankSecretSettingsMap() throws IOException { + try (var service = createMixedbreadService()) { + var secretSettings = getSecretSettingsMap(API_KEY); + secretSettings.put("extra_key", "value"); + + var config = getRequestConfigMap( + getServiceSettingsMap(MODEL_NAME_VALUE, REQUESTS_PER_MINUTE), + getTaskSettingsMap(TOP_N, RETURN_DOCUMENTS_TRUE), + secretSettings + ); + + assertThrowsExceptionWhenAnExtraKeyExists(service, config); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInRerankServiceSettingsMap() throws IOException { + try (var service = createMixedbreadService()) { + var serviceSettings = getServiceSettingsMap(MODEL_NAME_VALUE, REQUESTS_PER_MINUTE); + serviceSettings.put("extra_key", "value"); + + var config = getRequestConfigMap( + serviceSettings, + getTaskSettingsMap(TOP_N, RETURN_DOCUMENTS_TRUE), + getSecretSettingsMap(API_KEY) + ); + + assertThrowsExceptionWhenAnExtraKeyExists(service, config); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInRerankTaskSettingsMap() throws IOException { + try (var service = createMixedbreadService()) { + var taskSettings = getTaskSettingsMap(TOP_N, RETURN_DOCUMENTS_TRUE); + taskSettings.put("extra_key", "value"); + + var config = getRequestConfigMap( + 
getServiceSettingsMap(MODEL_NAME_VALUE, REQUESTS_PER_MINUTE), + taskSettings, + getSecretSettingsMap(API_KEY) + ); + + assertThrowsExceptionWhenAnExtraKeyExists(service, config); + } + } + + private static void assertThrowsExceptionWhenAnExtraKeyExists(MixedbreadService service, Map config) { + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat(exception.getMessage(), Matchers.is(UNKNOWN_SETTINGS_EXCEPTION)); + } + ); + + service.parseRequestConfig(INFERENCE_ID_VALUE, RERANK, config, modelVerificationListener); + } + + public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { + var serviceSettings = getServiceSettingsMap(MODEL_NAME_VALUE, REQUESTS_PER_MINUTE); + var taskSettings = getTaskSettingsMap(TOP_N, RETURN_DOCUMENTS_TRUE); + var secretSettings = getSecretSettingsMap(API_KEY); + + try (var service = createMixedbreadService()) { + secretSettings.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings); + + var model = service.parsePersistedConfig(INFERENCE_ID_VALUE, RERANK, persistedConfig.config()); + + assertThat(model, CoreMatchers.instanceOf(MixedbreadRerankModel.class)); + + var rerankModel = (MixedbreadRerankModel) model; + assertThat(rerankModel.getServiceSettings().modelId(), is(MODEL_NAME_VALUE)); + assertThat(rerankModel.getTaskSettings(), is(new MixedbreadRerankTaskSettings(TOP_N, RETURN_DOCUMENTS_TRUE))); + } + } + + public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + var serviceSettings = getServiceSettingsMap(MODEL_NAME_VALUE, REQUESTS_PER_MINUTE); + var taskSettings = getTaskSettingsMap(TOP_N, RETURN_DOCUMENTS_TRUE); + var secretSettings = getSecretSettingsMap(API_KEY); + + try (var service = 
createMixedbreadService()) { + serviceSettings.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings); + + var model = service.parsePersistedConfig(INFERENCE_ID_VALUE, RERANK, persistedConfig.config()); + + assertThat(model, CoreMatchers.instanceOf(MixedbreadRerankModel.class)); + + var rerankModel = (MixedbreadRerankModel) model; + assertThat(rerankModel.getServiceSettings().modelId(), is(MODEL_NAME_VALUE)); + assertThat(rerankModel.getTaskSettings(), is(new MixedbreadRerankTaskSettings(TOP_N, RETURN_DOCUMENTS_TRUE))); + } + } + + public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { + var serviceSettings = getServiceSettingsMap(MODEL_NAME_VALUE, REQUESTS_PER_MINUTE); + var taskSettings = getTaskSettingsMap(TOP_N, RETURN_DOCUMENTS_TRUE); + var secretSettings = getSecretSettingsMap(API_KEY); + + try (var service = createMixedbreadService()) { + taskSettings.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings); + + var model = service.parsePersistedConfig(INFERENCE_ID_VALUE, RERANK, persistedConfig.config()); + + assertThat(model, CoreMatchers.instanceOf(MixedbreadRerankModel.class)); + + var rerankModel = (MixedbreadRerankModel) model; + assertThat(rerankModel.getServiceSettings().modelId(), is(MODEL_NAME_VALUE)); + assertThat(rerankModel.getTaskSettings(), is(new MixedbreadRerankTaskSettings(TOP_N, RETURN_DOCUMENTS_TRUE))); + } + } + + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { + var serviceSettings = getServiceSettingsMap(MODEL_NAME_VALUE, REQUESTS_PER_MINUTE); + var taskSettings = getTaskSettingsMap(TOP_N, RETURN_DOCUMENTS_TRUE); + var secretSettings = getSecretSettingsMap(API_KEY); + + try (var service = createMixedbreadService()) { + secretSettings.put("extra_key", "value"); + + var persistedConfig 
= getPersistedConfigMap(serviceSettings, taskSettings, secretSettings); + + var model = service.parsePersistedConfigWithSecrets( + INFERENCE_ID_VALUE, + RERANK, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, CoreMatchers.instanceOf(MixedbreadRerankModel.class)); + + var rerankModel = (MixedbreadRerankModel) model; + assertThat(rerankModel.getServiceSettings().modelId(), is(MODEL_NAME_VALUE)); + assertThat(rerankModel.getTaskSettings(), is(new MixedbreadRerankTaskSettings(TOP_N, RETURN_DOCUMENTS_TRUE))); + assertThat(rerankModel.getSecretSettings().apiKey(), is(API_KEY)); + } + } + + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + var serviceSettings = getServiceSettingsMap(MODEL_NAME_VALUE, REQUESTS_PER_MINUTE); + var taskSettings = getTaskSettingsMap(TOP_N, RETURN_DOCUMENTS_TRUE); + var secretSettings = getSecretSettingsMap(API_KEY); + + try (var service = createMixedbreadService()) { + serviceSettings.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings); + + var model = service.parsePersistedConfigWithSecrets( + INFERENCE_ID_VALUE, + RERANK, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, CoreMatchers.instanceOf(MixedbreadRerankModel.class)); + + var rerankModel = (MixedbreadRerankModel) model; + assertThat(rerankModel.getServiceSettings().modelId(), is(MODEL_NAME_VALUE)); + assertThat(rerankModel.getTaskSettings(), is(new MixedbreadRerankTaskSettings(TOP_N, RETURN_DOCUMENTS_TRUE))); + assertThat(rerankModel.getSecretSettings().apiKey(), is(API_KEY)); + } + } + + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { + var serviceSettings = getServiceSettingsMap(MODEL_NAME_VALUE, REQUESTS_PER_MINUTE); + var taskSettings = getTaskSettingsMap(TOP_N, RETURN_DOCUMENTS_TRUE); + var 
secretSettings = getSecretSettingsMap(API_KEY); + + try (var service = createMixedbreadService()) { + taskSettings.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings); + + var model = service.parsePersistedConfigWithSecrets( + INFERENCE_ID_VALUE, + RERANK, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, CoreMatchers.instanceOf(MixedbreadRerankModel.class)); + + var rerankModel = (MixedbreadRerankModel) model; + assertThat(rerankModel.getServiceSettings().modelId(), is(MODEL_NAME_VALUE)); + assertThat(rerankModel.getTaskSettings(), is(new MixedbreadRerankTaskSettings(TOP_N, RETURN_DOCUMENTS_TRUE))); + assertThat(rerankModel.getSecretSettings().apiKey(), is(API_KEY)); + } + } + + public void testInfer_Rerank_UnauthorisedResponse() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new MixedbreadService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + + String responseJson = """ + { + "detail": "Unauthorized" + } + """; + webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson)); + + var model = MixedbreadRerankModelTests.createModel(MODEL_NAME_VALUE, API_KEY, TOP_N, RETURN_DOCUMENTS_FALSE, getUrl(webServer)); + + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + QUERY_VALUE, + null, + null, + List.of("candidate1", "candidate2"), + STREAM, + new HashMap<>(), + null, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TEST_REQUEST_TIMEOUT)); + assertThat(error.getMessage(), containsString("Received an authentication error status code for request")); + assertThat(error.getMessage(), containsString("Unauthorized")); + assertThat(webServer.requests(), hasSize(1)); + } + } + + public void 
testInfer_Rerank_NoReturnDocuments_NoTopN() throws IOException { + String responseJson = """ + { + "usage": { + "prompt_tokens": 162, + "total_tokens": 162, + "completion_tokens": 0 + }, + "model": "modelName", + "data": [ + { + "index": 0, + "score": 0.98291015625, + "object": "rank_result" + }, + { + "index": 2, + "score": 0.61962890625, + "object": "rank_result" + }, + { + "index": 3, + "score": 0.3642578125, + "object": "rank_result" + } + ], + "object": "list", + "return_input": false + } + """; + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new MixedbreadService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + var model = MixedbreadRerankModelTests.createModel(MODEL_NAME_VALUE, API_KEY, null, RETURN_DOCUMENTS_FALSE, getUrl(webServer)); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + QUERY_VALUE, + null, + null, + INPUT, + STREAM, + new HashMap<>(), + null, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TEST_REQUEST_TIMEOUT); + var resultAsMap = result.asMap(); + assertThat( + resultAsMap, + is( + Map.of( + "rerank", + List.of( + Map.of("ranked_doc", Map.of("index", 0, "relevance_score", 0.98291016F)), + Map.of("ranked_doc", Map.of("index", 2, "relevance_score", 0.6196289F)), + Map.of("ranked_doc", Map.of("index", 3, "relevance_score", 0.3642578F)) + ) + ) + ) + ); + + assertThat(webServer.requests(), hasSize(1)); + assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); + assertThat( + requestMap, + is(Map.of("query", QUERY_VALUE, "input", 
INPUT, "model", MODEL_NAME_VALUE, "return_input", RETURN_DOCUMENTS_FALSE)) + ); + } + } + + public void testInfer_Rerank_ReturnDocumentsNull_NoTopN() throws IOException { + String responseJson = """ + { + "usage": { + "prompt_tokens": 162, + "total_tokens": 162, + "completion_tokens": 0 + }, + "model": "modelName", + "data": [ + { + "index": 0, + "score": 0.98291015625, + "input": "candidate3", + "object": "rank_result" + }, + { + "index": 2, + "score": 0.61962890625, + "input": "candidate2", + "object": "rank_result" + }, + { + "index": 3, + "score": 0.3642578125, + "input": "candidate1", + "object": "rank_result" + } + ], + "object": "list" + } + """; + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new MixedbreadService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + var model = MixedbreadRerankModelTests.createModel(MODEL_NAME_VALUE, API_KEY, null, null, getUrl(webServer)); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + QUERY_VALUE, + null, + null, + INPUT, + STREAM, + new HashMap<>(), + null, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TEST_REQUEST_TIMEOUT); + var resultAsMap = result.asMap(); + assertThat( + resultAsMap, + is( + Map.of( + "rerank", + List.of( + Map.of("ranked_doc", Map.of("index", 0, "relevance_score", 0.98291015625F, "text", "candidate3")), + Map.of("ranked_doc", Map.of("index", 2, "relevance_score", 0.61962890625F, "text", "candidate2")), + Map.of("ranked_doc", Map.of("index", 3, "relevance_score", 0.3642578125F, "text", "candidate1")) + ) + ) + ) + ); + assertThat(webServer.requests(), hasSize(1)); + assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + 
assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); + assertThat(requestMap, is(Map.of("query", QUERY_VALUE, "input", INPUT, "model", MODEL_NAME_VALUE))); + + } + } + + public void testInfer_Rerank_ReturnDocuments_TopN() throws IOException { + String responseJson = """ + { + "usage": { + "prompt_tokens": 162, + "total_tokens": 162, + "completion_tokens": 0 + }, + "model": "modelName", + "data": [ + { + "index": 0, + "score": 0.98291015625, + "input": "candidate3", + "object": "rank_result" + }, + { + "index": 2, + "score": 0.61962890625, + "input": "candidate2", + "object": "rank_result" + }, + { + "index": 3, + "score": 0.3642578125, + "input": "candidate1", + "object": "rank_result" + } + ], + "object": "list", + "top_k": 3, + "return_input": true + } + """; + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new MixedbreadService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + var model = MixedbreadRerankModelTests.createModel(MODEL_NAME_VALUE, API_KEY, TOP_N, RETURN_DOCUMENTS_TRUE, getUrl(webServer)); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + QUERY_VALUE, + null, + null, + List.of("candidate1", "candidate2", "candidate3", "candidate4"), + STREAM, + new HashMap<>(), + null, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TEST_REQUEST_TIMEOUT); + var resultAsMap = result.asMap(); + assertThat( + resultAsMap, + is( + Map.of( + "rerank", + List.of( + Map.of("ranked_doc", Map.of("text", "candidate3", "index", 0, "relevance_score", 0.98291015625F)), + Map.of("ranked_doc", Map.of("text", "candidate2", "index", 2, "relevance_score", 0.61962890625F)), + 
Map.of("ranked_doc", Map.of("text", "candidate1", "index", 3, "relevance_score", 0.3642578125F)) + ) + ) + ) + ); + assertThat(webServer.requests(), hasSize(1)); + assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); + assertThat( + requestMap, + is( + Map.of( + "query", + QUERY_VALUE, + "input", + List.of("candidate1", "candidate2", "candidate3", "candidate4"), + "model", + MODEL_NAME_VALUE, + "return_input", + RETURN_DOCUMENTS_TRUE, + "top_k", + 3 + ) + ) + ); + + } + } + + public void testGetConfiguration() throws Exception { + try (var service = createMixedbreadService()) { + String content = XContentHelper.stripWhitespace(""" + { + "service": "mixedbread", + "name": "Mixedbread", + "task_types": ["rerank"], + "configurations": { + "api_key": { + "description": "API Key for the provider you're connecting to.", + "label": "API Key", + "required": true, + "sensitive": true, + "updatable": true, + "type": "str", + "supported_task_types": ["rerank"] + }, + "model_id": { + "description": "The model ID to use for Mixedbread requests.", + "label": "Model ID", + "required": true, + "sensitive": false, + "updatable": false, + "type": "str", + "supported_task_types": ["rerank"] + }, + "rate_limit.requests_per_minute": { + "description": "Minimize the number of rate limit errors.", + "label": "Rate Limit", + "required": false, + "sensitive": false, + "updatable": false, + "type": "int", + "supported_task_types": ["rerank"] + } + } + } + """); + InferenceServiceConfiguration configuration = InferenceServiceConfiguration.fromXContentBytes( + new BytesArray(content), + XContentType.JSON + ); + boolean humanReadable = true; + BytesReference originalBytes = toShuffledXContent(configuration, XContentType.JSON, 
ToXContent.EMPTY_PARAMS, humanReadable); + InferenceServiceConfiguration serviceConfiguration = service.getConfiguration(); + assertToXContentEquivalent( + originalBytes, + toXContent(serviceConfiguration, XContentType.JSON, humanReadable), + XContentType.JSON + ); + } + } + + private static void assertRerankModelSettings( + Model model, + String modelName, + RateLimitSettings rateLimitSettings, + String apiKey, + MixedbreadRerankTaskSettings taskSettings + ) { + assertThat(model, instanceOf(MixedbreadRerankModel.class)); + + var rerankModel = (MixedbreadRerankModel) model; + assertCommonModelSettings(rerankModel, DEFAULT_RERANK_URL, modelName, rateLimitSettings, apiKey); + + assertThat(rerankModel.getTaskSettings(), is(taskSettings)); + } + + /** Asserts the URL, model id, rate limit settings and API key of the given model against the expected values. */ + private static void assertCommonModelSettings( + T model, + String url, + String modelName, + RateLimitSettings rateLimitSettings, + String apiKey + ) { + assertThat(model.uri().toString(), is(url)); + assertThat(model.getServiceSettings().modelId(), is(modelName)); + assertThat(model.rateLimitSettings(), is(rateLimitSettings)); + /* The apiKey argument used to be ignored; verify the secret so a wrong or missing key fails the test. */ + /* NOTE(review): a null apiKey is taken to mean "no secret settings expected" - confirm against callers. */ + if (apiKey == null) { + assertNull(model.getSecretSettings()); + } else { + var expectedSecrets = new org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings( + new org.elasticsearch.common.settings.SecureString(apiKey.toCharArray()) + ); + assertThat(model.getSecretSettings(), is(expectedSecrets)); + } + } + + private MixedbreadService createMixedbreadService() { + return new MixedbreadService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty()); + } + + @Override + public InferenceService createInferenceService() { + return createMixedbreadService(); + } + + @Override + protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) { + assertThat(rerankingInferenceService.rerankerWindowSize("any model"), is(22000)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/action/MixedbreadActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/action/MixedbreadActionCreatorTests.java new file mode 100644 index 0000000000000..d45bbb4a88d4b --- /dev/null +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/action/MixedbreadActionCreatorTests.java @@ -0,0 +1,104 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.mixedbread.action; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.mixedbread.rerank.MixedbreadRerankModelTests; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; 
+import static org.mockito.Mockito.mock; + +public class MixedbreadActionCreatorTests extends ESTestCase { + private static final String EXPECTED_EXCEPTION = "Failed to send Mixedbread rerank request. Cause: failed"; + private static final QueryAndDocsInputs QUERY_AND_DOCS_INPUTS = new QueryAndDocsInputs( + "popular name", + List.of("Luke"), + false, + 3, + false + ); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityExecutors()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled() { + var sender = mock(Sender.class); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onFailure(new IllegalStateException("failed")); + + return Void.TYPE; + }).when(sender).send(any(), any(), any(), any()); + + var action = createAction("model", "secret", null, null, sender); + ElasticsearchException thrownException = executeActionWithException(action); + + ESTestCase.assertThat(thrownException.getMessage(), is(EXPECTED_EXCEPTION)); + } + + public void testExecute_ThrowsException() { + var sender = mock(Sender.class); + doThrow(new IllegalArgumentException("failed")).when(sender).send(any(), any(), any(), any()); + + var action = createAction("model", "secret", null, null, sender); + ElasticsearchException thrownException = executeActionWithException(action); + + ESTestCase.assertThat(thrownException.getMessage(), is(EXPECTED_EXCEPTION)); + } + + private static ElasticsearchException executeActionWithException(ExecutableAction action) { + 
PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(QUERY_AND_DOCS_INPUTS, ESTestCase.TEST_REQUEST_TIMEOUT, listener); + return expectThrows(ElasticsearchException.class, () -> listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT)); + } + + private ExecutableAction createAction(String modelName, String apiKey, Integer topN, Boolean returnDocuments, Sender sender) { + var actionCreator = new MixedbreadActionCreator(sender, createWithEmptySettings(threadPool)); + var model = MixedbreadRerankModelTests.createModel(modelName, apiKey, topN, returnDocuments, null); + return actionCreator.create(model, null); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/request/MixedbreadRerankRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/request/MixedbreadRerankRequestEntityTests.java new file mode 100644 index 0000000000000..d7ac7a9331a26 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/request/MixedbreadRerankRequestEntityTests.java @@ -0,0 +1,169 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.mixedbread.request; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.mixedbread.request.rerank.MixedbreadRerankRequestEntity; +import org.elasticsearch.xpack.inference.services.mixedbread.rerank.MixedbreadRerankTaskSettings; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString; + +public class MixedbreadRerankRequestEntityTests extends ESTestCase { + + public static final String MODEL = "model"; + public static final String QUERY = "query"; + + public void testXContent_SingleRequest_WritesAllFieldsIfDefined() throws IOException { + var entity = new MixedbreadRerankRequestEntity( + MODEL, + QUERY, + List.of("abc"), + 12, + Boolean.TRUE, + new MixedbreadRerankTaskSettings(8, Boolean.FALSE) + ); + + assertThat(getXContentResult(entity), equalToIgnoringWhitespaceInJsonString(""" + { + "model": "model", + "query": "query", + "input": [ + "abc" + ], + "top_k": 12, + "return_input": true + } + """)); + } + + public void testXContent_SingleRequest_WritesMinimalFields() throws IOException { + var entity = new MixedbreadRerankRequestEntity( + MODEL, + QUERY, + List.of("abc"), + null, + null, + new MixedbreadRerankTaskSettings(null, null) + ); + + assertThat(getXContentResult(entity), equalToIgnoringWhitespaceInJsonString(""" + { + "model": "model", + "query": "query", + "input": [ + "abc" + ] + } + """)); + } + + public void testXContent_MultipleRequests_WritesAllFieldsIfDefined() throws IOException { + var entity = new MixedbreadRerankRequestEntity( + MODEL, + QUERY, + List.of("abc", "def"), + 12, + Boolean.FALSE, + new MixedbreadRerankTaskSettings(8, Boolean.TRUE) + ); + + 
assertThat(getXContentResult(entity), equalToIgnoringWhitespaceInJsonString(""" + { + "model": "model", + "query": "query", + "input": [ + "abc", + "def" + ], + "top_k": 12, + "return_input": false + } + """)); + } + + public void testXContent_MultipleRequests_WritesMinimalFields() throws IOException { + var entity = new MixedbreadRerankRequestEntity( + MODEL, + QUERY, + List.of("abc", "def"), + null, + null, + new MixedbreadRerankTaskSettings(null, null) + ); + + assertThat(getXContentResult(entity), equalToIgnoringWhitespaceInJsonString(""" + { + "model": "model", + "query": "query", + "input": [ + "abc", + "def" + ] + } + """)); + } + + public void testXContent_SingleRequest_UsesTaskSettingsTopNIfRootIsNotDefined() throws IOException { + var entity = new MixedbreadRerankRequestEntity( + MODEL, + QUERY, + List.of("abc"), + null, + null, + new MixedbreadRerankTaskSettings(8, Boolean.FALSE) + ); + + assertThat(getXContentResult(entity), equalToIgnoringWhitespaceInJsonString(""" + { + "model": "model", + "query": "query", + "input": [ + "abc" + ], + "top_k": 8, + "return_input": false + } + """)); + } + + public void testXContent_SingleRequest_UsesTaskSettingsReturnDocumentsIfRootIsNotDefined() throws IOException { + var entity = new MixedbreadRerankRequestEntity( + MODEL, + QUERY, + List.of("abc"), + null, + null, + new MixedbreadRerankTaskSettings(8, Boolean.TRUE) + ); + + assertThat(getXContentResult(entity), equalToIgnoringWhitespaceInJsonString(""" + { + "model": "model", + "query": "query", + "input": [ + "abc" + ], + "top_k": 8, + "return_input": true + } + """)); + } + + private String getXContentResult(MixedbreadRerankRequestEntity entity) throws IOException { + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + return Strings.toString(builder); + } +} diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/request/MixedbreadRerankRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/request/MixedbreadRerankRequestTests.java new file mode 100644 index 0000000000000..977dc533de21d --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/request/MixedbreadRerankRequestTests.java @@ -0,0 +1,84 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.mixedbread.request; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.mixedbread.request.rerank.MixedbreadRerankRequest; +import org.elasticsearch.xpack.inference.services.mixedbread.rerank.MixedbreadRerankModelTests; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.services.mixedbread.MixedbreadServiceTests.RETURN_DOCUMENTS_FALSE; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; + +public class MixedbreadRerankRequestTests extends ESTestCase { + + private static final String API_KEY = "secret"; + public static final String INPUT = "input_value"; + public static final String MODEL = "model_id_value"; + public static final String QUERY = "query_value"; 
+ public static final int TOP_K = 1; + + public void testCreateRequest_WithMinimalFieldsSet() throws IOException { + var request = createRequest(QUERY, INPUT, MODEL, null, null); + var requestMap = getEntityAsMap(request); + assertThat(requestMap, aMapWithSize(3)); + assertThat(requestMap.get("input"), is(List.of(INPUT))); + assertThat(requestMap.get("query"), is(QUERY)); + assertThat(requestMap.get("model"), is(MODEL)); + } + + public void testCreateRequest_WithAllFieldsSet() throws IOException { + var request = createRequest(QUERY, INPUT, MODEL, TOP_K, RETURN_DOCUMENTS_FALSE); + Map requestMap = getEntityAsMap(request); + assertThat(requestMap, aMapWithSize(5)); + assertThat(requestMap.get("input"), is(List.of(INPUT))); + assertThat(requestMap.get("query"), is(QUERY)); + assertThat(requestMap.get("top_k"), is(TOP_K)); + assertThat(requestMap.get("return_input"), is(RETURN_DOCUMENTS_FALSE)); + assertThat(requestMap.get("model"), is(MODEL)); + } + + /** Verifies that truncate() hands back the very same request instance. */ + public void testTruncate_DoesNotTruncate() { + /* Use the shared MODEL constant like the sibling tests; the former "null" string literal read like an accidental stringified null. */ + var request = createRequest(QUERY, INPUT, MODEL, null, null); + var truncatedRequest = request.truncate(); + + assertThat(truncatedRequest, sameInstance(request)); + } + + private static MixedbreadRerankRequest createRequest( + String query, + String input, + @Nullable String modelId, + @Nullable Integer topN, + @Nullable Boolean returnDocuments + ) { + var rerankModel = MixedbreadRerankModelTests.createModel(modelId, API_KEY, null, null, null); + return new MixedbreadRerankRequest(query, List.of(input), returnDocuments, topN, rerankModel); + } + + private Map getEntityAsMap(MixedbreadRerankRequest request) throws IOException { + var httpRequest = request.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + 
assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer " + API_KEY)); + assertThat(httpPost.getURI(), is(sameInstance(request.getURI()))); + return entityAsMap(httpPost.getEntity().getContent()); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadRerankModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadRerankModelTests.java new file mode 100644 index 0000000000000..66fa58c414851 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadRerankModelTests.java @@ -0,0 +1,90 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.mixedbread.rerank; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.mixedbread.rerank.MixedbreadRerankTaskSettingsTests.getTaskSettingsMap; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; + +public class MixedbreadRerankModelTests extends ESTestCase { + + public static final String DEFAULT_URL = "https://api.mixedbread.com/v1/reranking"; + public static final String CUSTOM_URL = "https://custom.url.com/v1/rerank"; + public static final String MODEL_ID = "model_id_value"; + public static final String API_KEY = "secret"; + + public void testConstructor_usesDefaultUrlWhenNull() { + var model = createModel(MODEL_ID, API_KEY, null, null, null); 
+ assertThat(model.uri().toString(), is(DEFAULT_URL)); + } + + public void testConstructor_usesUrlWhenSpecified() { + var model = createModel(MODEL_ID, API_KEY, null, null, CUSTOM_URL); + assertThat(model.uri().toString(), is(CUSTOM_URL)); + } + + public void testOf_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAreEmpty() { + var model = createModel(MODEL_ID, API_KEY, 10, true, CUSTOM_URL); + var overriddenModel = MixedbreadRerankModel.of(model, Map.of()); + assertThat(overriddenModel, sameInstance(model)); + } + + public void testOf_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAreNull() { + var model = createModel(MODEL_ID, API_KEY, 10, true, CUSTOM_URL); + var overriddenModel = MixedbreadRerankModel.of(model, null); + assertThat(overriddenModel, sameInstance(model)); + } + + /** Verifies that request task settings identical to the stored ones leave the model instance untouched. */ + public void testOf_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAreEqual() { + /* NOTE(review): top_n is drawn from [1, Integer.MAX_VALUE] so the rare value 0 never reaches MixedbreadRerankModel.of; presumably request task-settings parsing rejects a non-positive top_n - confirm. */ + var topN = randomIntBetween(1, Integer.MAX_VALUE); + var returnDocuments = randomBoolean(); + var model = createModel(MODEL_ID, API_KEY, topN, returnDocuments, CUSTOM_URL); + var overriddenModel = MixedbreadRerankModel.of(model, getTaskSettingsMap(topN, returnDocuments)); + assertThat(overriddenModel, sameInstance(model)); + } + + public void testOf_SetsTopN_FromRequestTaskSettings_OverridingStoredTaskSettings() { + var model = createModel(MODEL_ID, API_KEY, 15, null, CUSTOM_URL); + var topNFromRequest = 10; + var overriddenModel = MixedbreadRerankModel.of(model, getTaskSettingsMap(topNFromRequest, null)); + var expectedModel = createModel(MODEL_ID, API_KEY, topNFromRequest, null, CUSTOM_URL); + assertThat(overriddenModel, is(expectedModel)); + } + + public void testOf_SetsReturnDocuments_FromRequestTaskSettings() { + var topN = 15; + var model = createModel(MODEL_ID, API_KEY, topN, true, CUSTOM_URL); + var returnDocumentsFromRequest = false; + var overriddenModel = MixedbreadRerankModel.of(model, getTaskSettingsMap(null, returnDocumentsFromRequest)); + var expectedModel = createModel(MODEL_ID, API_KEY, topN, 
returnDocumentsFromRequest, CUSTOM_URL); + assertThat(overriddenModel, is(expectedModel)); + } + + public static MixedbreadRerankModel createModel( + String model, + String apiKey, + @Nullable Integer topN, + @Nullable Boolean returnDocuments, + @Nullable String uri + ) { + return new MixedbreadRerankModel( + model, + new MixedbreadRerankServiceSettings(model, null), + new MixedbreadRerankTaskSettings(topN, returnDocuments), + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())), + uri + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadRerankServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadRerankServiceSettingsTests.java new file mode 100644 index 0000000000000..c9353d4a0bc61 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadRerankServiceSettingsTests.java @@ -0,0 +1,107 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */
+
+package org.elasticsearch.xpack.inference.services.mixedbread.rerank;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.ServiceFields;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString;
+import static org.elasticsearch.xpack.inference.services.settings.RateLimitSettings.REQUESTS_PER_MINUTE_FIELD;
+
+/**
+ * Wire-serialization and XContent tests for {@link MixedbreadRerankServiceSettings}.
+ * Also exposes static helpers that build random settings instances and raw service-settings maps.
+ */
+public class MixedbreadRerankServiceSettingsTests extends AbstractWireSerializingTestCase<MixedbreadRerankServiceSettings> {
+    private static final String MODEL = "model";
+    private static final RateLimitSettings RATE_LIMIT = new RateLimitSettings(2);
+
+    /**
+     * Creates settings with a random model id and, chosen at random, either a {@code null} or a random rate limit.
+     */
+    public static MixedbreadRerankServiceSettings createRandom() {
+        return createRandom(randomFrom(new RateLimitSettings[] { null, RateLimitSettingsTests.createRandom() }));
+    }
+
+    /**
+     * Creates settings with a random model id of length 10 and the supplied rate limit.
+     *
+     * @param rateLimitSettings rate limit handed to the constructor; may be {@code null}
+     */
+    public static MixedbreadRerankServiceSettings createRandom(@Nullable RateLimitSettings rateLimitSettings) {
+        return new MixedbreadRerankServiceSettings(randomAlphaOfLength(10), rateLimitSettings);
+    }
+
+    public void testToXContent_WritesAllValues() throws IOException {
+        var serviceSettings = new MixedbreadRerankServiceSettings(MODEL, RATE_LIMIT);
+        assertThat(getXContentResult(serviceSettings), equalToIgnoringWhitespaceInJsonString("""
+            {
+              "model_id":"model",
+              "rate_limit": {
+                "requests_per_minute": 2
+              }
+            }
+            """));
+    }
+
+    public void testToXContent_DoesNotWriteOptionalValues_DefaultRateLimit() throws IOException {
+        // A null rate limit is expected to be serialized as 100 requests per minute.
+        var serviceSettings = new MixedbreadRerankServiceSettings(MODEL, null);
+        assertThat(getXContentResult(serviceSettings), equalToIgnoringWhitespaceInJsonString("""
+            {
+              "model_id":"model",
+              "rate_limit": {
+                "requests_per_minute": 100
+              }
+            }
+            """));
+    }
+
+    /**
+     * Renders the given settings through {@code toXContent} into a JSON string.
+     */
+    private String getXContentResult(MixedbreadRerankServiceSettings serviceSettings) throws IOException {
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        serviceSettings.toXContent(builder, null);
+        return Strings.toString(builder);
+    }
+
+    @Override
+    protected Writeable.Reader<MixedbreadRerankServiceSettings> instanceReader() {
+        return MixedbreadRerankServiceSettings::new;
+    }
+
+    @Override
+    protected MixedbreadRerankServiceSettings createTestInstance() {
+        return createRandom();
+    }
+
+    /**
+     * Returns a copy of {@code instance} in which exactly one field (model id or rate limit) has a different value.
+     */
+    @Override
+    protected MixedbreadRerankServiceSettings mutateInstance(MixedbreadRerankServiceSettings instance) throws IOException {
+        var modelId = instance.modelId();
+        var rateLimitSettings = instance.rateLimitSettings();
+        switch (randomInt(1)) {
+            case 0 -> modelId = randomValueOtherThan(modelId, () -> randomAlphaOfLength(8));
+            case 1 -> rateLimitSettings = randomValueOtherThan(rateLimitSettings, RateLimitSettingsTests::createRandom);
+            default -> throw new AssertionError("Illegal randomisation branch");
+        }
+        return new MixedbreadRerankServiceSettings(modelId, rateLimitSettings);
+    }
+
+    /**
+     * Builds a service-settings map that contains only the model id.
+     */
+    public static Map<String, Object> getServiceSettingsMap(String model) {
+        return getServiceSettingsMap(model, null);
+    }
+
+    /**
+     * Builds a mutable service-settings map.
+     *
+     * @param model             value stored under {@link ServiceFields#MODEL_ID}
+     * @param requestsPerMinute when non-null, stored in a nested mutable map under {@link RateLimitSettings#FIELD_NAME}
+     */
+    public static Map<String, Object> getServiceSettingsMap(String model, @Nullable Integer requestsPerMinute) {
+        var map = new HashMap<String, Object>();
+
+        map.put(ServiceFields.MODEL_ID, model);
+
+        if (requestsPerMinute != null) {
+            map.put(RateLimitSettings.FIELD_NAME, new HashMap<>(Map.of(REQUESTS_PER_MINUTE_FIELD, requestsPerMinute)));
+        }
+
+        return map;
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadRerankTaskSettingsTests.java
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadRerankTaskSettingsTests.java
new file mode 100644
index 0000000000000..4de364cf95f6b
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/rerank/MixedbreadRerankTaskSettingsTests.java
@@ -0,0 +1,145 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.mixedbread.rerank;
+
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.services.mixedbread.MixedbreadServiceTests.RETURN_DOCUMENTS_FALSE;
+import static org.elasticsearch.xpack.inference.services.mixedbread.MixedbreadServiceTests.RETURN_DOCUMENTS_TRUE;
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.sameInstance;
+
+/**
+ * Wire-serialization, {@code fromMap} parsing and {@code updatedTaskSettings} tests for
+ * {@link MixedbreadRerankTaskSettings}. Also exposes static helpers for building random instances and raw task-settings maps.
+ */
+public class MixedbreadRerankTaskSettingsTests extends AbstractWireSerializingTestCase<MixedbreadRerankTaskSettings> {
+    private static final int TOP_N = 7;
+    private static final int TOP_N_UPDATE_VALUE = 8;
+
+    /**
+     * Creates task settings where each of {@code top_n} (1..10) and {@code return_documents} is randomly present or {@code null}.
+     */
+    public static MixedbreadRerankTaskSettings createRandom() {
+        var returnDocuments = randomOptionalBoolean();
+        var topNDocsOnly = randomBoolean() ? randomIntBetween(1, 10) : null;
+
+        return new MixedbreadRerankTaskSettings(topNDocsOnly, returnDocuments);
+    }
+
+    public void testFromMap_WithValidValues_ReturnsSettings() {
+        Map<String, Object> taskMap = Map.of(
+            MixedbreadRerankTaskSettings.RETURN_DOCUMENTS,
+            RETURN_DOCUMENTS_TRUE,
+            MixedbreadRerankTaskSettings.TOP_N,
+            TOP_N
+        );
+        var settings = MixedbreadRerankTaskSettings.fromMap(new HashMap<>(taskMap));
+        assertTrue(settings.getReturnDocuments());
+        assertEquals(TOP_N, settings.getTopN().intValue());
+    }
+
+    public void testFromMap_WithNullValues_ReturnsSettingsWithNulls() {
+        var settings = MixedbreadRerankTaskSettings.fromMap(Map.of());
+        assertNull(settings.getReturnDocuments());
+        assertNull(settings.getTopN());
+    }
+
+    public void testFromMap_WithInvalidReturnDocuments_ThrowsValidationException() {
+        Map<String, Object> taskMap = Map.of(
+            MixedbreadRerankTaskSettings.RETURN_DOCUMENTS,
+            "invalid",
+            MixedbreadRerankTaskSettings.TOP_N,
+            TOP_N
+        );
+        var thrownException = expectThrows(ValidationException.class, () -> MixedbreadRerankTaskSettings.fromMap(new HashMap<>(taskMap)));
+        assertThat(thrownException.getMessage(), containsString("field [return_documents] is not of the expected type"));
+    }
+
+    public void testFromMap_WithInvalidTopNDocsOnly_ThrowsValidationException() {
+        Map<String, Object> taskMap = Map.of(
+            MixedbreadRerankTaskSettings.RETURN_DOCUMENTS,
+            RETURN_DOCUMENTS_TRUE,
+            MixedbreadRerankTaskSettings.TOP_N,
+            "invalid"
+        );
+        var thrownException = expectThrows(ValidationException.class, () -> MixedbreadRerankTaskSettings.fromMap(new HashMap<>(taskMap)));
+        assertThat(thrownException.getMessage(), containsString("field [top_n] is not of the expected type"));
+    }
+
+    public void testUpdatedTaskSettings_WithEmptyMap_ReturnsSameSettings() {
+        var initialSettings = new MixedbreadRerankTaskSettings(TOP_N, RETURN_DOCUMENTS_TRUE);
+        MixedbreadRerankTaskSettings updatedSettings = initialSettings.updatedTaskSettings(Map.of());
+        assertThat(initialSettings, is(sameInstance(updatedSettings)));
+    }
+
+    public void testUpdatedTaskSettings_WithNewReturnDocuments_ReturnsUpdatedSettings() {
+        var initialSettings = new MixedbreadRerankTaskSettings(TOP_N, RETURN_DOCUMENTS_TRUE);
+        Map<String, Object> newSettings = Map.of(MixedbreadRerankTaskSettings.RETURN_DOCUMENTS, RETURN_DOCUMENTS_FALSE);
+        MixedbreadRerankTaskSettings updatedSettings = initialSettings.updatedTaskSettings(newSettings);
+        assertFalse(updatedSettings.getReturnDocuments());
+        assertEquals(initialSettings.getTopN(), updatedSettings.getTopN());
+    }
+
+    public void testUpdatedTaskSettings_WithNewTopNDocsOnly_ReturnsUpdatedSettings() {
+        var initialSettings = new MixedbreadRerankTaskSettings(TOP_N, RETURN_DOCUMENTS_TRUE);
+        Map<String, Object> newSettings = Map.of(MixedbreadRerankTaskSettings.TOP_N, TOP_N_UPDATE_VALUE);
+        MixedbreadRerankTaskSettings updatedSettings = initialSettings.updatedTaskSettings(newSettings);
+        assertEquals(TOP_N_UPDATE_VALUE, updatedSettings.getTopN().intValue());
+        assertEquals(initialSettings.getReturnDocuments(), updatedSettings.getReturnDocuments());
+    }
+
+    public void testUpdatedTaskSettings_WithMultipleNewValues_ReturnsUpdatedSettings() {
+        var initialSettings = new MixedbreadRerankTaskSettings(TOP_N, RETURN_DOCUMENTS_TRUE);
+        Map<String, Object> newSettings = Map.of(
+            MixedbreadRerankTaskSettings.RETURN_DOCUMENTS,
+            RETURN_DOCUMENTS_FALSE,
+            MixedbreadRerankTaskSettings.TOP_N,
+            TOP_N_UPDATE_VALUE
+        );
+        MixedbreadRerankTaskSettings updatedSettings = initialSettings.updatedTaskSettings(newSettings);
+        assertFalse(updatedSettings.getReturnDocuments());
+        assertEquals(TOP_N_UPDATE_VALUE, updatedSettings.getTopN().intValue());
+    }
+
+    @Override
+    protected Writeable.Reader<MixedbreadRerankTaskSettings> instanceReader() {
+        return MixedbreadRerankTaskSettings::new;
+    }
+
+    @Override
+    protected MixedbreadRerankTaskSettings createTestInstance() {
+        return createRandom();
+    }
+
+    /**
+     * Returns a copy of {@code instance} in which exactly one field ({@code top_n} or {@code return_documents}) has a different
+     * value.
+     */
+    @Override
+    protected MixedbreadRerankTaskSettings mutateInstance(MixedbreadRerankTaskSettings instance) throws IOException {
+        var topNDocsOnly = instance.getTopN();
+        var returnDocuments = instance.getReturnDocuments();
+        switch (randomInt(1)) {
+            case 0 -> topNDocsOnly = randomValueOtherThan(topNDocsOnly, () -> randomFrom(randomIntBetween(1, 10), null));
+            // A null flag becomes a random boolean; a non-null flag is inverted, so the value always changes.
+            case 1 -> returnDocuments = returnDocuments == null ? randomBoolean() : returnDocuments == false;
+            // Fail loudly if the random range above is widened without adding a matching case.
+            default -> throw new AssertionError("Illegal randomisation branch");
+        }
+        return new MixedbreadRerankTaskSettings(topNDocsOnly, returnDocuments);
+    }
+
+    /**
+     * Builds a mutable task-settings map containing only the non-null arguments.
+     *
+     * @param topN            when non-null, stored under {@link MixedbreadRerankTaskSettings#TOP_N}
+     * @param returnDocuments when non-null, stored under {@link MixedbreadRerankTaskSettings#RETURN_DOCUMENTS}
+     */
+    public static Map<String, Object> getTaskSettingsMap(@Nullable Integer topN, @Nullable Boolean returnDocuments) {
+        var map = new HashMap<String, Object>();
+
+        if (topN != null) {
+            map.put(MixedbreadRerankTaskSettings.TOP_N, topN);
+        }
+
+        if (returnDocuments != null) {
+            map.put(MixedbreadRerankTaskSettings.RETURN_DOCUMENTS, returnDocuments);
+        }
+
+        return map;
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/response/MixedbreadRerankResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/response/MixedbreadRerankResponseEntityTests.java
new file mode 100644
index 0000000000000..2d21fdc5e4b97
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/response/MixedbreadRerankResponseEntityTests.java
@@ -0,0 +1,125 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.mixedbread.response;
+
+import org.apache.http.HttpResponse;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.List;
+
+import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Tests that {@link MixedbreadRerankResponseEntity#fromResponse} turns a rerank response body into
+ * {@link RankedDocsResults}, both without and with the per-result {@code input} text.
+ */
+public class MixedbreadRerankResponseEntityTests extends ESTestCase {
+
+    private static final String HARPER_LEE = "Harper Lee, an American novelist";
+    private static final String NOVEL_BY_HARPER_LEE = "To Kill a Mockingbird is a novel by Harper Lee";
+    private static final String JANE_AUSTEN = "Jane Austen was an English novelist";
+
+    /** Expected ranked docs for {@link #RESPONSE_LITERAL}: index and score only, no text. */
+    private static final List<RankedDocsResults.RankedDoc> RESPONSE_LITERAL_DOCS = List.of(
+        new RankedDocsResults.RankedDoc(0, 0.98291015625F, null),
+        new RankedDocsResults.RankedDoc(2, 0.61962890625F, null),
+        new RankedDocsResults.RankedDoc(3, 0.3642578125F, null)
+    );
+
+    /** Expected ranked docs for {@link #RESPONSE_LITERAL_WITH_INPUT}: same indices and scores, plus the input text. */
+    private static final List<RankedDocsResults.RankedDoc> RESPONSE_LITERAL_DOCS_WITH_TEXT = List.of(
+        new RankedDocsResults.RankedDoc(0, 0.98291015625F, HARPER_LEE),
+        new RankedDocsResults.RankedDoc(2, 0.61962890625F, NOVEL_BY_HARPER_LEE),
+        new RankedDocsResults.RankedDoc(3, 0.3642578125F, JANE_AUSTEN)
+    );
+
+    public void testResponseLiteral() throws IOException {
+        InferenceServiceResults parsedResults = MixedbreadRerankResponseEntity.fromResponse(
+            new HttpResult(mock(HttpResponse.class), RESPONSE_LITERAL.getBytes(StandardCharsets.UTF_8))
+        );
+
+        assertThat(parsedResults, instanceOf(RankedDocsResults.class));
+        assertThat(((RankedDocsResults) parsedResults).getRankedDocs(), is(RESPONSE_LITERAL_DOCS));
+    }
+
+    public void testResponseLiteralWithDocumentsAsString() throws IOException {
+        InferenceServiceResults parsedResults = MixedbreadRerankResponseEntity.fromResponse(
+            new HttpResult(mock(HttpResponse.class), RESPONSE_LITERAL_WITH_INPUT.getBytes(StandardCharsets.UTF_8))
+        );
+        assertThat(parsedResults, instanceOf(RankedDocsResults.class));
+        assertThat(((RankedDocsResults) parsedResults).getRankedDocs(), is(RESPONSE_LITERAL_DOCS_WITH_TEXT));
+    }
+
+    /** Response body in which the results carry no {@code input} text. */
+    private static final String RESPONSE_LITERAL = """
+        {
+            "usage": {
+                "prompt_tokens": 162,
+                "total_tokens": 162,
+                "completion_tokens": 0
+            },
+            "model": "mixedbread-ai/mxbai-rerank-xsmall-v1",
+            "data": [
+                {
+                    "index": 0,
+                    "score": 0.98291015625,
+                    "object": "rank_result"
+                },
+                {
+                    "index": 2,
+                    "score": 0.61962890625,
+                    "object": "rank_result"
+                },
+                {
+                    "index": 3,
+                    "score": 0.3642578125,
+                    "object": "rank_result"
+                }
+            ],
+            "object": "list",
+            "top_k": 3,
+            "return_input": false
+        }
+        """;
+
+    /**
+     * Response body in which every result carries its {@code input} text.
+     * {@code return_input} is {@code true} here so the fixture agrees with the {@code input} fields it contains.
+     */
+    private static final String RESPONSE_LITERAL_WITH_INPUT = Strings.format("""
+        {
+            "usage": {
+                "prompt_tokens": 162,
+                "total_tokens": 162,
+                "completion_tokens": 0
+            },
+            "model": "mixedbread-ai/mxbai-rerank-xsmall-v1",
+            "data": [
+                {
+                    "index": 0,
+                    "score": 0.98291015625,
+                    "input": "%s",
+                    "object": "rank_result"
+                },
+                {
+                    "index": 2,
+                    "score": 0.61962890625,
+                    "input": "%s",
+                    "object": "rank_result"
+                },
+                {
+                    "index": 3,
+                    "score": 0.3642578125,
+                    "input": "%s",
+                    "object": "rank_result"
+                }
+            ],
+            "object": "list",
+            "top_k": 3,
+            "return_input": true
+        }
+        """, HARPER_LEE, NOVEL_BY_HARPER_LEE, JANE_AUSTEN);
+}