diff --git a/docs/changelog/111181.yaml b/docs/changelog/111181.yaml new file mode 100644 index 0000000000000..7f9f5937b7652 --- /dev/null +++ b/docs/changelog/111181.yaml @@ -0,0 +1,5 @@ +pr: 111181 +summary: "[Inference API] Add Alibaba Cloud AI Search Model support to Inference API" +area: Machine Learning +type: enhancement +issues: [ ] diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 41fa34bb5a4a3..c68a33c6df6c4 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -200,6 +200,7 @@ static TransportVersion def(int id) { public static final TransportVersion ESQL_ES_FIELD_CACHED_SERIALIZATION = def(8_730_00_0); public static final TransportVersion ADD_MANAGE_ROLES_PRIVILEGE = def(8_731_00_0); public static final TransportVersion REPOSITORIES_TELEMETRY = def(8_732_00_0); + public static final TransportVersion ML_INFERENCE_ALIBABACLOUD_SEARCH_ADDED = def(8_733_00_0); /* * STOP! READ THIS FIRST! 
No, really, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 489a81b642492..d4810ba930b44 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -25,6 +25,13 @@ import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsTaskSettings; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankServiceSettings; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankTaskSettings; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseServiceSettings; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseTaskSettings; import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockSecretSettings; import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionTaskSettings; @@ -117,6 +124,7 @@ public static List getNamedWriteables() { 
addAnthropicNamedWritables(namedWriteables); addAmazonBedrockNamedWriteables(namedWriteables); addEisNamedWriteables(namedWriteables); + addAlibabaCloudSearchNamedWriteables(namedWriteables); return namedWriteables; } @@ -482,6 +490,59 @@ private static void addAnthropicNamedWritables(List namedWriteables) { + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + AlibabaCloudSearchServiceSettings.NAME, + AlibabaCloudSearchServiceSettings::new + ) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + AlibabaCloudSearchEmbeddingsServiceSettings.NAME, + AlibabaCloudSearchEmbeddingsServiceSettings::new + ) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry( + TaskSettings.class, + AlibabaCloudSearchEmbeddingsTaskSettings.NAME, + AlibabaCloudSearchEmbeddingsTaskSettings::new + ) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + AlibabaCloudSearchSparseServiceSettings.NAME, + AlibabaCloudSearchSparseServiceSettings::new + ) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry( + TaskSettings.class, + AlibabaCloudSearchSparseTaskSettings.NAME, + AlibabaCloudSearchSparseTaskSettings::new + ) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + AlibabaCloudSearchRerankServiceSettings.NAME, + AlibabaCloudSearchRerankServiceSettings::new + ) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry( + TaskSettings.class, + AlibabaCloudSearchRerankTaskSettings.NAME, + AlibabaCloudSearchRerankTaskSettings::new + ) + ); + + } + private static void addEisNamedWriteables(List namedWriteables) { namedWriteables.add( new NamedWriteableRegistry.Entry( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 9d85bbf751250..dff93a63d0647 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -74,6 +74,7 @@ import org.elasticsearch.xpack.inference.rest.RestInferenceAction; import org.elasticsearch.xpack.inference.rest.RestPutInferenceModelAction; import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchService; import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockService; import org.elasticsearch.xpack.inference.services.anthropic.AnthropicService; import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService; @@ -237,6 +238,7 @@ public List getInferenceServiceFactories() { context -> new MistralService(httpFactory.get(), serviceComponents.get()), context -> new AnthropicService(httpFactory.get(), serviceComponents.get()), context -> new AmazonBedrockService(httpFactory.get(), amazonBedrockFactory.get(), serviceComponents.get()), + context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get()), ElasticsearchInternalService::new ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchActionCreator.java new file mode 100644 index 0000000000000..218ca2ef39ed6 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchActionCreator.java @@ -0,0 +1,53 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.alibabacloudsearch; + +import org.elasticsearch.inference.InputType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankModel; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel; + +import java.util.Map; +import java.util.Objects; + +/** + * Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the alibaba cloud search model type. + */ +public class AlibabaCloudSearchActionCreator implements AlibabaCloudSearchActionVisitor { + private final Sender sender; + private final ServiceComponents serviceComponents; + + public AlibabaCloudSearchActionCreator(Sender sender, ServiceComponents serviceComponents) { + this.sender = Objects.requireNonNull(sender); + this.serviceComponents = Objects.requireNonNull(serviceComponents); + } + + @Override + public ExecutableAction create(AlibabaCloudSearchEmbeddingsModel model, Map taskSettings, InputType inputType) { + var overriddenModel = AlibabaCloudSearchEmbeddingsModel.of(model, taskSettings, inputType); + + return new AlibabaCloudSearchEmbeddingsAction(sender, overriddenModel, serviceComponents); + } + + @Override + public ExecutableAction create(AlibabaCloudSearchSparseModel model, Map taskSettings, InputType inputType) { + var overriddenModel = AlibabaCloudSearchSparseModel.of(model, taskSettings, inputType); + + return new AlibabaCloudSearchSparseAction(sender, 
overriddenModel, serviceComponents); + } + + @Override + public ExecutableAction create(AlibabaCloudSearchRerankModel model, Map taskSettings) { + var overriddenModel = AlibabaCloudSearchRerankModel.of(model, taskSettings); + + return new AlibabaCloudSearchRerankAction(sender, overriddenModel, serviceComponents); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchActionVisitor.java new file mode 100644 index 0000000000000..69ae903c7b38f --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchActionVisitor.java @@ -0,0 +1,24 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.action.alibabacloudsearch; + +import org.elasticsearch.inference.InputType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankModel; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel; + +import java.util.Map; + +/** + * Visitor that produces an {@link ExecutableAction} for an Alibaba Cloud Search model. There is one {@code create} overload per + * model class (embeddings, sparse, rerank); the rerank overload takes no {@link InputType}. + * NOTE(review): {@code taskSettings} is declared as a raw {@code Map} here; its type arguments may have been lost - confirm. + */ +public interface AlibabaCloudSearchActionVisitor { + ExecutableAction create(AlibabaCloudSearchEmbeddingsModel model, Map taskSettings, InputType inputType); + + ExecutableAction create(AlibabaCloudSearchSparseModel model, Map taskSettings, InputType inputType); + + ExecutableAction create(AlibabaCloudSearchRerankModel model, Map taskSettings); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchEmbeddingsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchEmbeddingsAction.java new file mode 100644 index 0000000000000..7a22bbf6b4bfd --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchEmbeddingsAction.java @@ -0,0 +1,57 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.action.alibabacloudsearch; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount; +import org.elasticsearch.xpack.inference.external.http.sender.AlibabaCloudSearchEmbeddingsRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel; + +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException; + +public class AlibabaCloudSearchEmbeddingsAction implements ExecutableAction { + private final AlibabaCloudSearchAccount account; + private final AlibabaCloudSearchEmbeddingsModel model; + private final String failedToSendRequestErrorMessage; + private final Sender sender; + private final AlibabaCloudSearchEmbeddingsRequestManager requestCreator; + + public AlibabaCloudSearchEmbeddingsAction(Sender sender, AlibabaCloudSearchEmbeddingsModel model, ServiceComponents serviceComponents) { + this.model = Objects.requireNonNull(model); + this.sender = Objects.requireNonNull(sender); + this.account = new AlibabaCloudSearchAccount(this.model.getSecretSettings().apiKey()); + this.failedToSendRequestErrorMessage = 
constructFailedToSendRequestMessage(null, "AlibabaCloud Search text embeddings"); + this.requestCreator = AlibabaCloudSearchEmbeddingsRequestManager.of(account, model, serviceComponents.threadPool()); + } + + @Override + public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener listener) { + try { + ActionListener wrappedListener = wrapFailuresInElasticsearchException( + failedToSendRequestErrorMessage, + listener + ); + sender.send(requestCreator, inferenceInputs, timeout, wrappedListener); + } catch (ElasticsearchException e) { + listener.onFailure(e); + } catch (Exception e) { + listener.onFailure(createInternalServerError(e, failedToSendRequestErrorMessage)); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchRerankAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchRerankAction.java new file mode 100644 index 0000000000000..88229ce63463b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchRerankAction.java @@ -0,0 +1,61 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.action.alibabacloudsearch; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount; +import org.elasticsearch.xpack.inference.external.http.sender.AlibabaCloudSearchRerankRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankModel; + +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException; + +public class AlibabaCloudSearchRerankAction implements ExecutableAction { + private static final Logger logger = LogManager.getLogger(AlibabaCloudSearchRerankAction.class); + + private final AlibabaCloudSearchAccount account; + private final AlibabaCloudSearchRerankModel model; + private final String failedToSendRequestErrorMessage; + private final Sender sender; + private final AlibabaCloudSearchRerankRequestManager requestCreator; + + public AlibabaCloudSearchRerankAction(Sender sender, AlibabaCloudSearchRerankModel model, ServiceComponents serviceComponents) { + this.model = Objects.requireNonNull(model); + this.account = new 
AlibabaCloudSearchAccount(this.model.getSecretSettings().apiKey()); + this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(null, "AlibabaCloud Search rerank"); + this.sender = Objects.requireNonNull(sender); + this.requestCreator = AlibabaCloudSearchRerankRequestManager.of(account, model, serviceComponents.threadPool()); + } + + @Override + public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener listener) { + try { + ActionListener wrappedListener = wrapFailuresInElasticsearchException( + failedToSendRequestErrorMessage, + listener + ); + sender.send(requestCreator, inferenceInputs, timeout, wrappedListener); + } catch (ElasticsearchException e) { + listener.onFailure(e); + } catch (Exception e) { + listener.onFailure(createInternalServerError(e, failedToSendRequestErrorMessage)); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchSparseAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchSparseAction.java new file mode 100644 index 0000000000000..2cd31ff83d200 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchSparseAction.java @@ -0,0 +1,61 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.action.alibabacloudsearch; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount; +import org.elasticsearch.xpack.inference.external.http.sender.AlibabaCloudSearchSparseRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel; + +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException; + +public class AlibabaCloudSearchSparseAction implements ExecutableAction { + private static final Logger logger = LogManager.getLogger(AlibabaCloudSearchSparseAction.class); + + private final AlibabaCloudSearchAccount account; + private final AlibabaCloudSearchSparseModel model; + private final String failedToSendRequestErrorMessage; + private final Sender sender; + private final AlibabaCloudSearchSparseRequestManager requestCreator; + + public AlibabaCloudSearchSparseAction(Sender sender, AlibabaCloudSearchSparseModel model, ServiceComponents serviceComponents) { + this.model = Objects.requireNonNull(model); + this.account = new 
AlibabaCloudSearchAccount(this.model.getSecretSettings().apiKey()); + this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(null, "AlibabaCloud Search sparse embeddings"); + this.sender = Objects.requireNonNull(sender); + requestCreator = AlibabaCloudSearchSparseRequestManager.of(account, model, serviceComponents.threadPool()); + } + + @Override + public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener listener) { + try { + ActionListener wrappedListener = wrapFailuresInElasticsearchException( + failedToSendRequestErrorMessage, + listener + ); + sender.send(requestCreator, inferenceInputs, timeout, wrappedListener); + } catch (ElasticsearchException e) { + listener.onFailure(e); + } catch (Exception e) { + listener.onFailure(createInternalServerError(e, failedToSendRequestErrorMessage)); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/alibabacloudsearch/AlibabaCloudSearchAccount.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/alibabacloudsearch/AlibabaCloudSearchAccount.java new file mode 100644 index 0000000000000..6aabbe20cc355 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/alibabacloudsearch/AlibabaCloudSearchAccount.java @@ -0,0 +1,19 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.alibabacloudsearch; + +import org.elasticsearch.common.settings.SecureString; + +import java.util.Objects; + +public record AlibabaCloudSearchAccount(SecureString apiKey) { + + public AlibabaCloudSearchAccount { + Objects.requireNonNull(apiKey); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/alibabacloudsearch/AlibabaCloudSearchResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/alibabacloudsearch/AlibabaCloudSearchResponseHandler.java new file mode 100644 index 0000000000000..05d51372d9cdc --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/alibabacloudsearch/AlibabaCloudSearchResponseHandler.java @@ -0,0 +1,63 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.alibabacloudsearch; + +import org.apache.logging.log4j.Logger; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.response.alibabacloudsearch.AlibabaCloudSearchErrorResponseEntity; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; + +import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody; + +/** + * Defines how to handle various errors returned from the AlibabaCloudSearch integration. 
+ */ +public class AlibabaCloudSearchResponseHandler extends BaseResponseHandler { + + public AlibabaCloudSearchResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, AlibabaCloudSearchErrorResponseEntity::fromResponse); + } + + @Override + public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result) + throws RetryException { + checkForFailureStatusCode(request, result); + checkForEmptyBody(throttlerManager, logger, request, result); + } + + /** + * Validates the status code throws an RetryException if not in the range [200, 300). + * + * @param request The http request + * @param result The http response and body + * @throws RetryException Throws if status code is {@code >= 300 or < 200 } + */ + void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException { + int statusCode = result.response().getStatusLine().getStatusCode(); + if (statusCode >= 200 && statusCode < 300) { + return; + } + + // handle error codes + if (statusCode >= 500) { + throw new RetryException(false, buildError(SERVER_ERROR, request, result)); + } else if (statusCode == 429) { + throw new RetryException(true, buildError(RATE_LIMIT, request, result)); + } else if (statusCode == 401) { + throw new RetryException(false, buildError(AUTHENTICATION, request, result)); + } else if (statusCode >= 300 && statusCode < 400) { + throw new RetryException(false, buildError(REDIRECTION, request, result)); + } else { + throw new RetryException(false, buildError(UNSUCCESSFUL, request, result)); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchEmbeddingsRequestManager.java new file mode 100644 index 0000000000000..55c699bf26e82 --- /dev/null +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchEmbeddingsRequestManager.java @@ -0,0 +1,77 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount; +import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchEmbeddingsRequest; +import org.elasticsearch.xpack.inference.external.response.alibabacloudsearch.AlibabaCloudSearchEmbeddingsResponseEntity; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel; + +import java.util.List; +import java.util.Objects; +import java.util.function.Supplier; + +public class AlibabaCloudSearchEmbeddingsRequestManager extends AlibabaCloudSearchRequestManager { + private static final Logger logger = LogManager.getLogger(AlibabaCloudSearchEmbeddingsRequestManager.class); + + private static final ResponseHandler HANDLER = createEmbeddingsHandler(); + + private static ResponseHandler createEmbeddingsHandler() { + return new AlibabaCloudSearchResponseHandler( + "alibaba cloud search text embedding", + 
AlibabaCloudSearchEmbeddingsResponseEntity::fromResponse + ); + } + + public static AlibabaCloudSearchEmbeddingsRequestManager of( + AlibabaCloudSearchAccount account, + AlibabaCloudSearchEmbeddingsModel model, + ThreadPool threadPool + ) { + return new AlibabaCloudSearchEmbeddingsRequestManager( + Objects.requireNonNull(account), + Objects.requireNonNull(model), + Objects.requireNonNull(threadPool) + ); + } + + private final AlibabaCloudSearchEmbeddingsModel model; + + private final AlibabaCloudSearchAccount account; + + private AlibabaCloudSearchEmbeddingsRequestManager( + AlibabaCloudSearchAccount account, + AlibabaCloudSearchEmbeddingsModel model, + ThreadPool threadPool + ) { + super(threadPool, model); + this.account = Objects.requireNonNull(account); + this.model = Objects.requireNonNull(model); + } + + @Override + public void execute( + InferenceInputs inferenceInputs, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + List input = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + AlibabaCloudSearchEmbeddingsRequest request = new AlibabaCloudSearchEmbeddingsRequest(account, input, model); + + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchRequestManager.java new file mode 100644 index 0000000000000..c8ade15ac5057 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchRequestManager.java @@ -0,0 +1,28 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchModel; + +import java.util.Objects; + +abstract class AlibabaCloudSearchRequestManager extends BaseRequestManager { + + protected AlibabaCloudSearchRequestManager(ThreadPool threadPool, AlibabaCloudSearchModel model) { + super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings()); + } + + record RateLimitGrouping(int apiKeyHash) { + public static RateLimitGrouping of(AlibabaCloudSearchModel model) { + Objects.requireNonNull(model); + + return new RateLimitGrouping(model.rateLimitServiceSettings().hashCode()); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchRerankRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchRerankRequestManager.java new file mode 100644 index 0000000000000..446db40aa5ae5 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchRerankRequestManager.java @@ -0,0 +1,77 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount; +import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchRerankRequest; +import org.elasticsearch.xpack.inference.external.response.alibabacloudsearch.AlibabaCloudSearchRerankResponseEntity; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankModel; + +import java.util.Objects; +import java.util.function.Supplier; + +public class AlibabaCloudSearchRerankRequestManager extends AlibabaCloudSearchRequestManager { + private static final Logger logger = LogManager.getLogger(AlibabaCloudSearchRerankRequestManager.class); + private static final ResponseHandler HANDLER = createRerankHandler(); + + private static ResponseHandler createRerankHandler() { + return new AlibabaCloudSearchResponseHandler("alibaba cloud search rerank", AlibabaCloudSearchRerankResponseEntity::fromResponse); + } + + public static AlibabaCloudSearchRerankRequestManager of( + AlibabaCloudSearchAccount account, + AlibabaCloudSearchRerankModel model, + ThreadPool threadPool + ) { + return new AlibabaCloudSearchRerankRequestManager( + Objects.requireNonNull(account), + Objects.requireNonNull(model), + Objects.requireNonNull(threadPool) + ); + } + + private final AlibabaCloudSearchRerankModel model; + + private final AlibabaCloudSearchAccount account; + + 
private AlibabaCloudSearchRerankRequestManager( + AlibabaCloudSearchAccount account, + AlibabaCloudSearchRerankModel model, + ThreadPool threadPool + ) { + super(threadPool, model); + this.account = account; + this.model = model; + } + + @Override + public void execute( + InferenceInputs inferenceInputs, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + var rerankInput = QueryAndDocsInputs.of(inferenceInputs); + AlibabaCloudSearchRerankRequest request = new AlibabaCloudSearchRerankRequest( + account, + rerankInput.getQuery(), + rerankInput.getChunks(), + model + ); + + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchSparseRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchSparseRequestManager.java new file mode 100644 index 0000000000000..b0cc524bb4cbe --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchSparseRequestManager.java @@ -0,0 +1,77 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount; +import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchSparseRequest; +import org.elasticsearch.xpack.inference.external.response.alibabacloudsearch.AlibabaCloudSearchSparseResponseEntity; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel; + +import java.util.List; +import java.util.Objects; +import java.util.function.Supplier; + +public class AlibabaCloudSearchSparseRequestManager extends AlibabaCloudSearchRequestManager { + private static final Logger logger = LogManager.getLogger(AlibabaCloudSearchSparseRequestManager.class); + + private static final ResponseHandler HANDLER = createEmbeddingsHandler(); + + private static ResponseHandler createEmbeddingsHandler() { + return new AlibabaCloudSearchResponseHandler( + "alibaba cloud search sparse embedding", + AlibabaCloudSearchSparseResponseEntity::fromResponse + ); + } + + public static AlibabaCloudSearchSparseRequestManager of( + AlibabaCloudSearchAccount account, + AlibabaCloudSearchSparseModel model, + ThreadPool threadPool + ) { + return new AlibabaCloudSearchSparseRequestManager( + Objects.requireNonNull(account), + Objects.requireNonNull(model), + Objects.requireNonNull(threadPool) + ); + } + + private final AlibabaCloudSearchSparseModel model; + + 
private final AlibabaCloudSearchAccount account; + + private AlibabaCloudSearchSparseRequestManager( + AlibabaCloudSearchAccount account, + AlibabaCloudSearchSparseModel model, + ThreadPool threadPool + ) { + super(threadPool, model); + this.account = Objects.requireNonNull(account); + this.model = Objects.requireNonNull(model); + } + + @Override + public void execute( + InferenceInputs inferenceInputs, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + List input = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + AlibabaCloudSearchSparseRequest request = new AlibabaCloudSearchSparseRequest(account, input, model); + + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequest.java new file mode 100644 index 0000000000000..081854903405e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequest.java @@ -0,0 +1,111 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.utils.URIBuilder; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsTaskSettings; + +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.buildUri; +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; + +public class AlibabaCloudSearchEmbeddingsRequest extends AlibabaCloudSearchRequest { + + private final AlibabaCloudSearchAccount account; + private final List input; + private final URI uri; + private final AlibabaCloudSearchEmbeddingsTaskSettings taskSettings; + private final String model; + private final String host; + private final String workspaceName; + private final String httpSchema; + private final String inferenceEntityId; + + public AlibabaCloudSearchEmbeddingsRequest( + AlibabaCloudSearchAccount account, + List input, + AlibabaCloudSearchEmbeddingsModel embeddingsModel + ) { + Objects.requireNonNull(embeddingsModel); + + this.account = Objects.requireNonNull(account); + this.input = Objects.requireNonNull(input); + taskSettings = embeddingsModel.getTaskSettings(); + model = 
embeddingsModel.getServiceSettings().getCommonSettings().modelId(); + host = embeddingsModel.getServiceSettings().getCommonSettings().getHost(); + workspaceName = embeddingsModel.getServiceSettings().getCommonSettings().getWorkspaceName(); + httpSchema = embeddingsModel.getServiceSettings().getCommonSettings().getHttpSchema() != null + ? embeddingsModel.getServiceSettings().getCommonSettings().getHttpSchema() + : "https"; + uri = buildUri(null, AlibabaCloudSearchUtils.SERVICE_NAME, this::buildDefaultUri); + inferenceEntityId = embeddingsModel.getInferenceEntityId(); + } + + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(uri); + + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString(new AlibabaCloudSearchEmbeddingsRequestEntity(input, taskSettings)).getBytes(StandardCharsets.UTF_8) + ); + httpPost.setEntity(byteEntity); + + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); + httpPost.setHeader(createAuthBearerHeader(account.apiKey())); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public String getInferenceEntityId() { + return inferenceEntityId; + } + + @Override + public URI getURI() { + return uri; + } + + @Override + public Request truncate() { + return this; + } + + @Override + public boolean[] getTruncationInfo() { + return null; + } + + URI buildDefaultUri() throws URISyntaxException { + return new URIBuilder().setScheme(httpSchema) + .setHost(host) + .setPathSegments( + AlibabaCloudSearchUtils.VERSION_3, + AlibabaCloudSearchUtils.OPENAPI_PATH, + AlibabaCloudSearchUtils.WORKSPACE_PATH, + workspaceName, + AlibabaCloudSearchUtils.TEXT_EMBEDDING_PATH, + model + ) + .build(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestEntity.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestEntity.java new file mode 100644 index 0000000000000..c2367aeff3070 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestEntity.java @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch; + +import org.elasticsearch.inference.InputType; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsTaskSettings; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsTaskSettings.invalidInputTypeMessage; + +public record AlibabaCloudSearchEmbeddingsRequestEntity(List input, AlibabaCloudSearchEmbeddingsTaskSettings taskSettings) + implements + ToXContentObject { + + private static final String SEARCH_DOCUMENT = "document"; + private static final String SEARCH_QUERY = "query"; + + private static final String TEXTS_FIELD = "input"; + + static final String INPUT_TYPE_FIELD = "input_type"; + + public AlibabaCloudSearchEmbeddingsRequestEntity { + Objects.requireNonNull(input); + Objects.requireNonNull(taskSettings); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TEXTS_FIELD, input); + + String inputType = 
covertToString(taskSettings.getInputType()); + if (inputType != null) { + builder.field(INPUT_TYPE_FIELD, inputType); + } + + builder.endObject(); + return builder; + } + + // default for testing + static String covertToString(InputType inputType) { + if (inputType == null) { + return null; + } + + return switch (inputType) { + case INGEST -> SEARCH_DOCUMENT; + case SEARCH -> SEARCH_QUERY; + default -> { + assert false : invalidInputTypeMessage(inputType); + yield null; + } + }; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRequest.java new file mode 100644 index 0000000000000..75fc12e1bad31 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRequest.java @@ -0,0 +1,22 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch; + +import org.elasticsearch.xpack.inference.external.request.Request; + +public abstract class AlibabaCloudSearchRequest implements Request { + private final long startTime; + + public AlibabaCloudSearchRequest() { + this.startTime = System.currentTimeMillis(); + } + + public long getStartTime() { + return startTime; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequest.java new file mode 100644 index 0000000000000..878bcc6e6a0db --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequest.java @@ -0,0 +1,113 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.utils.URIBuilder; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankModel; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankTaskSettings; + +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.buildUri; +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; + +public class AlibabaCloudSearchRerankRequest implements Request { + private final AlibabaCloudSearchAccount account; + private final String query; + private final List input; + private final URI uri; + private final AlibabaCloudSearchRerankTaskSettings taskSettings; + private final String model; + private final String host; + private final String workspaceName; + private final String httpSchema; + private final String inferenceEntityId; + + public AlibabaCloudSearchRerankRequest( + AlibabaCloudSearchAccount account, + String query, + List input, + AlibabaCloudSearchRerankModel rerankModel + ) { + Objects.requireNonNull(rerankModel); + + this.account = Objects.requireNonNull(account); + this.query = Objects.requireNonNull(query); + this.input = Objects.requireNonNull(input); + taskSettings = 
rerankModel.getTaskSettings(); + model = rerankModel.getServiceSettings().getCommonSettings().modelId(); + host = rerankModel.getServiceSettings().getCommonSettings().getHost(); + workspaceName = rerankModel.getServiceSettings().getCommonSettings().getWorkspaceName(); + httpSchema = rerankModel.getServiceSettings().getCommonSettings().getHttpSchema() != null + ? rerankModel.getServiceSettings().getCommonSettings().getHttpSchema() + : "https"; + uri = buildUri(null, AlibabaCloudSearchUtils.SERVICE_NAME, this::buildDefaultUri); + inferenceEntityId = rerankModel.getInferenceEntityId(); + } + + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(uri); + + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString(new AlibabaCloudSearchRerankRequestEntity(query, input, taskSettings)).getBytes(StandardCharsets.UTF_8) + ); + httpPost.setEntity(byteEntity); + + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); + httpPost.setHeader(createAuthBearerHeader(account.apiKey())); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public String getInferenceEntityId() { + return inferenceEntityId; + } + + @Override + public URI getURI() { + return uri; + } + + @Override + public Request truncate() { + return this; + } + + @Override + public boolean[] getTruncationInfo() { + return null; + } + + URI buildDefaultUri() throws URISyntaxException { + return new URIBuilder().setScheme(httpSchema) + .setHost(host) + .setPathSegments( + AlibabaCloudSearchUtils.VERSION_3, + AlibabaCloudSearchUtils.OPENAPI_PATH, + AlibabaCloudSearchUtils.WORKSPACE_PATH, + workspaceName, + AlibabaCloudSearchUtils.RERANK_PATH, + model + ) + .build(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequestEntity.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequestEntity.java new file mode 100644 index 0000000000000..054e373e3e525 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequestEntity.java @@ -0,0 +1,42 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch; + +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankTaskSettings; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public record AlibabaCloudSearchRerankRequestEntity(String query, List input, AlibabaCloudSearchRerankTaskSettings taskSettings) + implements + ToXContentObject { + + private static final String SEARCH_QUERY = "query"; + private static final String TEXTS_FIELD = "docs"; + + public AlibabaCloudSearchRerankRequestEntity { + Objects.requireNonNull(query); + Objects.requireNonNull(input); + Objects.requireNonNull(taskSettings); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + { + builder.field(SEARCH_QUERY, query); + builder.field(TEXTS_FIELD, input); + } + builder.endObject(); + return builder; + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchSparseRequest.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchSparseRequest.java new file mode 100644 index 0000000000000..c7b4c314b07a7 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchSparseRequest.java @@ -0,0 +1,111 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.utils.URIBuilder; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseTaskSettings; + +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.buildUri; +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; + +public class AlibabaCloudSearchSparseRequest extends AlibabaCloudSearchRequest { + + private final AlibabaCloudSearchAccount account; + private final List input; + private final URI 
uri; + private final AlibabaCloudSearchSparseTaskSettings taskSettings; + private final String model; + private final String host; + private final String workspaceName; + private final String httpSchema; + private final String inferenceEntityId; + + public AlibabaCloudSearchSparseRequest( + AlibabaCloudSearchAccount account, + List input, + AlibabaCloudSearchSparseModel sparseEmbeddingsModel + ) { + Objects.requireNonNull(sparseEmbeddingsModel); + + this.account = Objects.requireNonNull(account); + this.input = Objects.requireNonNull(input); + taskSettings = sparseEmbeddingsModel.getTaskSettings(); + model = sparseEmbeddingsModel.getServiceSettings().getCommonSettings().modelId(); + host = sparseEmbeddingsModel.getServiceSettings().getCommonSettings().getHost(); + workspaceName = sparseEmbeddingsModel.getServiceSettings().getCommonSettings().getWorkspaceName(); + httpSchema = sparseEmbeddingsModel.getServiceSettings().getCommonSettings().getHttpSchema() != null + ? sparseEmbeddingsModel.getServiceSettings().getCommonSettings().getHttpSchema() + : "https"; + uri = buildUri(null, AlibabaCloudSearchUtils.SERVICE_NAME, this::buildDefaultUri); + inferenceEntityId = sparseEmbeddingsModel.getInferenceEntityId(); + } + + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(uri); + + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString(new AlibabaCloudSearchSparseRequestEntity(input, taskSettings)).getBytes(StandardCharsets.UTF_8) + ); + httpPost.setEntity(byteEntity); + + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); + httpPost.setHeader(createAuthBearerHeader(account.apiKey())); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public String getInferenceEntityId() { + return inferenceEntityId; + } + + @Override + public URI getURI() { + return uri; + } + + @Override + public Request truncate() { + return this; + } + + @Override + public boolean[] 
getTruncationInfo() { + return null; + } + + URI buildDefaultUri() throws URISyntaxException { + return new URIBuilder().setScheme(httpSchema) + .setHost(host) + .setPathSegments( + AlibabaCloudSearchUtils.VERSION_3, + AlibabaCloudSearchUtils.OPENAPI_PATH, + AlibabaCloudSearchUtils.WORKSPACE_PATH, + workspaceName, + AlibabaCloudSearchUtils.SPARSE_EMBEDDING_PATH, + model + ) + .build(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchSparseRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchSparseRequestEntity.java new file mode 100644 index 0000000000000..3aec226bfc277 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchSparseRequestEntity.java @@ -0,0 +1,47 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch; + +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseTaskSettings; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public record AlibabaCloudSearchSparseRequestEntity(List input, AlibabaCloudSearchSparseTaskSettings taskSettings) + implements + ToXContentObject { + + private static final String TEXTS_FIELD = "input"; + + static final String INPUT_TYPE_FIELD = "input_type"; + + static final String RETURN_TOKEN_FIELD = "return_token"; + + public AlibabaCloudSearchSparseRequestEntity { + Objects.requireNonNull(input); + Objects.requireNonNull(taskSettings); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TEXTS_FIELD, input); + String inputType = AlibabaCloudSearchEmbeddingsRequestEntity.covertToString(taskSettings.getInputType()); + if (inputType != null) { + builder.field(INPUT_TYPE_FIELD, inputType); + } + if (taskSettings.isReturnToken() != null) { + builder.field(RETURN_TOKEN_FIELD, taskSettings.isReturnToken()); + } + builder.endObject(); + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchUtils.java new file mode 100644 index 0000000000000..7d671471976f5 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchUtils.java @@ -0,0 +1,18 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch; + +public class AlibabaCloudSearchUtils { + public static final String SERVICE_NAME = "alibabacloud-ai-search"; + public static final String VERSION_3 = "v3"; + public static final String OPENAPI_PATH = "openapi"; + public static final String WORKSPACE_PATH = "workspaces"; + public static final String TEXT_EMBEDDING_PATH = "text-embedding"; + public static final String SPARSE_EMBEDDING_PATH = "text-sparse-embedding"; + public static final String RERANK_PATH = "ranker"; +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchEmbeddingsResponseEntity.java new file mode 100644 index 0000000000000..33fa645b107bc --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchEmbeddingsResponseEntity.java @@ -0,0 +1,109 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.response.alibabacloudsearch; + +import org.elasticsearch.common.xcontent.XContentParserUtils; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; + +public class AlibabaCloudSearchEmbeddingsResponseEntity extends AlibabaCloudSearchResponseEntity { + private static final String FAILED_TO_FIND_FIELD_TEMPLATE = + "Failed to find required field [%s] in AlibabaCloud Search embeddings response"; + + /** + * Parses the AlibabaCloud Search embedding json response. + * For a request like: + * + *
+     * <pre><code>
+     * {
+     *  "input": ["hello this is my name", "I wish I was there!"]
+     * }
+     * </code>
+     * </pre>
+ * + * The response would look like: + * + *
+     * <pre><code>
+     * {
+     *     "request_id": "B4AB89C8-B135-xxxx-A6F8-2BAB801A2CE4",
+     *     "latency": 38,
+     *     "usage": {
+     *         "token_count": 3072
+     *     },
+     *     "result": {
+     *         "embeddings": [
+     *             {
+     *                 "index": 0,
+     *                 "embedding": [
+     *                     -0.02868066355586052,
+     *                     0.022033605724573135,
+     *                     -0.0417383536696434,
+     *                     -0.044081952422857285,
+     *                     0.02141784131526947,
+     *                     -8.240503375418484E-4,
+     *                     -0.01309406291693449,
+     *                     -0.02169642224907875,
+     *                     -0.03996409475803375,
+     *                     0.008053945377469063,
+     *                     ...
+     *                     -0.05131729692220688,
+     *                     -0.016595875844359398
+     *                 ]
+     *             }
+     *         ]
+     *     }
+     * }
+     * </code>
+     * </pre>
+ */ + public static InferenceTextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException { + return fromResponse(request, response, parser -> { + positionParserAtTokenAfterField(parser, "embeddings", FAILED_TO_FIND_FIELD_TEMPLATE); + + List embeddingList = XContentParserUtils.parseList( + parser, + AlibabaCloudSearchEmbeddingsResponseEntity::parseEmbeddingObject + ); + + return new InferenceTextEmbeddingFloatResults(embeddingList); + }); + } + + private static InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding parseEmbeddingObject(XContentParser parser) + throws IOException { + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + + positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE); + + List embeddingValues = XContentParserUtils.parseList(parser, AlibabaCloudSearchEmbeddingsResponseEntity::parseEmbeddingList); + + // the parser is currently sitting at an ARRAY_END so go to the next token + parser.nextToken(); + // if there are additional fields within this object, lets skip them, so we can begin parsing the next embedding array + parser.skipChildren(); + + return InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding.of(embeddingValues); + } + + private static float parseEmbeddingList(XContentParser parser) throws IOException { + XContentParser.Token token = parser.currentToken(); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, token, parser); + return parser.floatValue(); + } + + private AlibabaCloudSearchEmbeddingsResponseEntity() {} +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchErrorResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchErrorResponseEntity.java new file mode 100644 index 
0000000000000..77a0c6ecc7cc8
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchErrorResponseEntity.java
@@ -0,0 +1,69 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.response.alibabacloudsearch;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xcontent.XContentParserConfiguration;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.http.retry.ErrorMessage;
+
+/**
+ * Holds the {@code message} field extracted from an AlibabaCloud Search error response body.
+ */
+public class AlibabaCloudSearchErrorResponseEntity implements ErrorMessage {
+    private static final Logger logger = LogManager.getLogger(AlibabaCloudSearchErrorResponseEntity.class);
+
+    private final String errorMessage;
+
+    private AlibabaCloudSearchErrorResponseEntity(String errorMessage) {
+        this.errorMessage = errorMessage;
+    }
+
+    @Override
+    public String getErrorMessage() {
+        return errorMessage;
+    }
+
+    /**
+     * An example error response for an invalid request parameter would look like
+     * <pre>{@code
+     *     {
+     *       "request_id": "651B3087-8A07-xxxx-xxxx-9C4E7B60F52D",
+     *       "latency": 0,
+     *       "code": "InvalidParameter",
+     *       "message": "JSON parse error: Cannot deserialize value of type `InputType` from String \"xxx\""
+     *     }
+     * }</pre>
+     *
+     * @param response The error response
+     * @return An error entity if the response is JSON with the above structure
+     * or null if the body cannot be parsed or does not contain a string message field
+     */
+    public static AlibabaCloudSearchErrorResponseEntity fromResponse(HttpResult response) {
+        try (
+            XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON)
+                .createParser(XContentParserConfiguration.EMPTY, response.body())
+        ) {
+            var responseMap = jsonParser.map();
+            if (logger.isDebugEnabled()) {
+                logger.debug("Received error response: {}", responseMap);
+            }
+
+            // only a string message is usable; anything else is treated like a missing field instead of
+            // relying on a ClassCastException being caught below
+            if (responseMap.get("message") instanceof String message) {
+                return new AlibabaCloudSearchErrorResponseEntity(message);
+            }
+        } catch (Exception ignored) {
+            // deliberately best effort: an unparsable error body simply yields no error entity
+        }
+
+        return null;
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchRerankResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchRerankResponseEntity.java
new file mode 100644
index 0000000000000..9a18c56e86475
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchRerankResponseEntity.java
@@ -0,0 +1,139 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */ + +package org.elasticsearch.xpack.inference.external.response.alibabacloudsearch; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.io.IOException; + +import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList; +import static org.elasticsearch.common.xcontent.XContentParserUtils.throwUnknownField; +import static org.elasticsearch.common.xcontent.XContentParserUtils.throwUnknownToken; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; + +public class AlibabaCloudSearchRerankResponseEntity { + private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in AlibabaCloud Search rerank response"; + + private static final Logger logger = LogManager.getLogger(AlibabaCloudSearchRerankResponseEntity.class); + + /** + * Parses the AlibabaCloud Search rerank json response. + * For a request like: + * + *
+     * 
+     * {
+     *  "query": "上海有什么好玩的",
+     *  "docs" : ["上海有许多好玩的地方",
+     *             "北京有许多好玩的地方"]
+     * }
+     * 
+     * 
+ * + * The response would look like: + * + *
+     * 
+     *     {
+     *   "request_id": "450fcb80-f796-xxxx-xxxx-e1e86d29aa9f",
+     *   "latency": 564.903929,
+     *   "usage": {
+     *     "doc_count": 2
+     *   },
+     *   "result": {
+     *    "scores":[
+     *      {
+     *        "index":1,
+     *        "score": 1.37
+     *      },
+     *      {
+     *        "index":0,
+     *        "score": -0.3
+     *      }
+     *    ]
+     *   }
+     * }
+     * 
+     * 
+     */
+    public static InferenceServiceResults fromResponse(Request request, HttpResult response) throws IOException {
+        var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
+
+        try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
+            moveToFirstToken(jsonParser);
+
+            XContentParser.Token token = jsonParser.currentToken();
+            ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser);
+
+            // descend to result.scores, which must be an array of {index, score} objects
+            positionParserAtTokenAfterField(jsonParser, "result", FAILED_TO_FIND_FIELD_TEMPLATE);
+
+            positionParserAtTokenAfterField(jsonParser, "scores", FAILED_TO_FIND_FIELD_TEMPLATE);
+
+            token = jsonParser.currentToken();
+            if (token == XContentParser.Token.START_ARRAY) {
+                return new RankedDocsResults(parseList(jsonParser, AlibabaCloudSearchRerankResponseEntity::parseRankedDocObject));
+            } else {
+                throwUnknownToken(token, jsonParser);
+            }
+
+            // This should never be reached. The above code should either return successfully or hit the throwUnknownToken
+            // or throw a parsing exception
+            throw new IllegalStateException("Reached an invalid state while parsing the AlibabaCloudSearch response");
+        }
+    }
+
+    /**
+     * Parses one element of the {@code scores} array into a ranked document.
+     * Only the {@code index} and {@code score} fields are accepted; any other field name is rejected via
+     * {@code throwUnknownField}. A missing field is logged as a warning and replaced by {@code -1}.
+     *
+     * @param parser the parser positioned on the element's START_OBJECT token
+     * @return the ranked document; its text is always {@code null} because this method reads no text from the response
+     * @throws IOException if the underlying parser fails
+     */
+    private static RankedDocsResults.RankedDoc parseRankedDocObject(XContentParser parser) throws IOException {
+        ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
+        // boxed so that "field absent" can be told apart from a legitimate value of -1 (scores may be negative)
+        Integer index = null;
+        Float score = null;
+        parser.nextToken();
+        while (parser.currentToken() != XContentParser.Token.END_OBJECT) {
+            if (parser.currentToken() == XContentParser.Token.FIELD_NAME) {
+                switch (parser.currentName()) {
+                    case "index":
+                        parser.nextToken(); // move to VALUE_NUMBER
+                        index = parser.intValue();
+                        parser.nextToken(); // move to next FIELD_NAME or END_OBJECT
+                        break;
+                    case "score":
+                        parser.nextToken(); // move to VALUE_NUMBER
+                        score = parser.floatValue();
+                        parser.nextToken(); // move to next FIELD_NAME or END_OBJECT
+                        break;
+                    default:
+                        throwUnknownField(parser.currentName(), parser);
+                }
+            } else {
+                parser.nextToken();
+            }
+        }
+
+        if (index == null) {
+            logger.warn("Failed to find required field [index] in AlibabaCloudSearch rerank response");
+            index = -1;
+        }
+        if (score == null) {
+            logger.warn("Failed to find required field [score] in AlibabaCloudSearch rerank response");
+            score = -1f;
+        }
+
+        // no document text is parsed from the response, so the ranked doc carries none
+        return new RankedDocsResults.RankedDoc(index, score, null);
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchResponseEntity.java
new file mode 100644
index 0000000000000..156f2c1d4078c
--- /dev/null
+++
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchResponseEntity.java
@@ -0,0 +1,78 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.response.alibabacloudsearch;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
+import org.elasticsearch.common.xcontent.XContentParserUtils;
+import org.elasticsearch.core.CheckedFunction;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xcontent.XContentParserConfiguration;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchRequest;
+
+import java.io.IOException;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
+
+/**
+ * Shared parsing of the top-level envelope of an AlibabaCloud Search response
+ * ({@code request_id}, {@code latency}, {@code usage} and {@code result}).
+ */
+public abstract class AlibabaCloudSearchResponseEntity {
+    private static final Logger logger = LogManager.getLogger(AlibabaCloudSearchResponseEntity.class);
+
+    /**
+     * Walks the top-level fields of the response body, delegates the {@code result} value to {@code function}
+     * and logs the remaining envelope fields at debug level.
+     *
+     * @param request  the request that produced the response; must be an {@link AlibabaCloudSearchRequest}
+     * @param response the HTTP result whose body is parsed as JSON
+     * @param function parses the value of the {@code result} field; invoked with the parser positioned on that value
+     * @param <R>      the type produced by {@code function}
+     * @return the value returned by {@code function}, or {@code null} if no {@code result} field was encountered
+     * @throws IOException if parsing fails
+     */
+    public static <R> R fromResponse(Request request, HttpResult response, CheckedFunction<XContentParser, R, IOException> function)
+        throws IOException {
+        var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
+        AlibabaCloudSearchRequest alibabaCloudSearchRequest = (AlibabaCloudSearchRequest) request;
+
+        try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
+            moveToFirstToken(parser);
+
+            XContentParser.Token token = parser.currentToken();
+            XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, token, parser);
+
+            R result = null;
+            String requestID = null;
+            float latency = 0;
+            Map<String, Object> usage = null;
+            // NOTE(review): the loop stops at the first END_OBJECT it sees. If function leaves the parser inside
+            // the result object, fields that follow "result" are presumably never read - confirm whether that matters.
+            while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
+                String currentFieldName = parser.currentName();
+                parser.nextToken(); // move from the field name to its value
+                switch (currentFieldName) {
+                    case "result":
+                        result = function.apply(parser);
+                        break;
+                    case "request_id":
+                        requestID = parser.text();
+                        break;
+                    case "latency":
+                        latency = parser.floatValue();
+                        break;
+                    case "usage":
+                        usage = parser.map();
+                        break;
+                    default:
+                        parser.skipChildren();
+                }
+            }
+
+            logger.debug(
+                "AlibabaCloud Search uri [{}] response: request_id [{}], latency [{}ms], client cost [{}ms], usage [{}]",
+                request.getURI().getPath(),
+                requestID,
+                latency,
+                System.currentTimeMillis() - alibabaCloudSearchRequest.getStartTime(),
+                usage
+            );
+            return result;
+        }
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchSparseResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchSparseResponseEntity.java
new file mode 100644
index 0000000000000..c903bfd188116
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchSparseResponseEntity.java
@@ -0,0 +1,199 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */ + +package org.elasticsearch.xpack.inference.external.response.alibabacloudsearch; + +import org.elasticsearch.common.xcontent.XContentParserUtils; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.core.ml.search.WeightedToken; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; + +public class AlibabaCloudSearchSparseResponseEntity extends AlibabaCloudSearchResponseEntity { + private static final String FAILED_TO_FIND_FIELD_TEMPLATE = + "Failed to find required field [%s] in AlibabaCloud Search sparse embeddings response"; + + /** + * Parses the AlibabaCloud Search sparse embedding json response. + * For a request like: + * + *
+     * 
+     * {
+     *  "texts": ["hello this is my name", "I wish I was there!"]
+     * }
+     * 
+     * 
+ * + * The response would look like: + * + *
+     * 
+     *     {
+     *   "request_id": "DDC4306F-xxxx-xxxx-xxxx-92C5CEA756A0",
+     *   "latency": 25,
+     *   "usage": {
+     *     "token_count": 11
+     *   },
+     *   "result": {
+     *     "sparse_embeddings": [
+     *       {
+     *         "index": 0,
+     *         "embedding": [
+     *           {
+     *             "token_id": 6,
+     *             "weight": 0.1014404296875,
+     *             "token": ""
+     *           },
+     *           {
+     *             "token_id": 163040,
+     *             "weight": 0.2841796875,
+     *             "token": "科学技术"
+     *           },
+     *           {
+     *             "token_id": 354,
+     *             "weight": 0.1431884765625,
+     *             "token": "是"
+     *           },
+     *           {
+     *             "token_id": 5998,
+     *             "weight": 0.1614990234375,
+     *             "token": "第一"
+     *           },
+     *           {
+     *             "token_id": 8550,
+     *             "weight": 0.239013671875,
+     *             "token": "生产"
+     *           },
+     *           {
+     *             "token_id": 2017,
+     *             "weight": 0.161376953125,
+     *             "token": "力"
+     *           }
+     *         ]
+     *       },
+     *       {
+     *         "index": 1,
+     *         "embedding": [
+     *           {
+     *             "token_id": 9803,
+     *             "weight": 0.1949462890625,
+     *             "token": "open"
+     *           },
+     *           {
+     *             "token_id": 86250,
+     *             "weight": 0.317138671875,
+     *             "token": "search"
+     *           },
+     *           {
+     *             "token_id": 5889,
+     *             "weight": 0.175048828125,
+     *             "token": "产品"
+     *           },
+     *           {
+     *             "token_id": 2564,
+     *             "weight": 0.1163330078125,
+     *             "token": "文"
+     *           },
+     *           {
+     *             "token_id": 59529,
+     *             "weight": 0.16650390625,
+     *             "token": "档"
+     *           }
+     *         ]
+     *       }
+     *     ]
+     *   }
+     * }
+     * 
+     * 
+     */
+    public static SparseEmbeddingResults fromResponse(Request request, HttpResult response) throws IOException {
+        return fromResponse(request, response, jsonParser -> {
+            // "sparse_embeddings" is an array of objects; each element is handed to parseEmbeddingObject
+            positionParserAtTokenAfterField(jsonParser, "sparse_embeddings", FAILED_TO_FIND_FIELD_TEMPLATE);
+
+            List<SparseEmbeddingResults.Embedding> embeddingList = XContentParserUtils.parseList(
+                jsonParser,
+                AlibabaCloudSearchSparseResponseEntity::parseEmbeddingObject
+            );
+
+            return new SparseEmbeddingResults(embeddingList);
+        });
+    }
+
+    /**
+     * Parses one element of the {@code sparse_embeddings} array.
+     *
+     * @param parser the parser positioned on the element's START_OBJECT token
+     * @return the embedding built from the valid weighted tokens of the {@code embedding} array
+     * @throws IOException if the underlying parser fails
+     */
+    private static SparseEmbeddingResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException {
+        XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
+
+        positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE);
+
+        List<WeightedToken> tokens = parseWeightedTokenList(parser);
+
+        // the parser is currently sitting at an ARRAY_END so go to the next token
+        parser.nextToken();
+        // if there are additional fields within this object, lets skip them, so we can begin parsing the next embedding array
+        parser.skipChildren();
+
+        return new SparseEmbeddingResults.Embedding(tokens, false);
+    }
+
+    /**
+     * Reads an array of token objects, keeping only the entries that {@link #parseEmbeddingList} accepts.
+     *
+     * @param parser the parser positioned on the START_ARRAY token of an {@code embedding} array
+     * @return the accepted tokens in response order; empty if the array is empty
+     * @throws IOException if the underlying parser fails
+     */
+    private static List<WeightedToken> parseWeightedTokenList(XContentParser parser) throws IOException {
+        XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
+        if (parser.nextToken() == XContentParser.Token.END_ARRAY) {
+            return List.of();
+        }
+        final ArrayList<WeightedToken> list = new ArrayList<>();
+        do {
+            WeightedToken token = parseEmbeddingList(parser);
+            if (token != null) {
+                list.add(token);
+            }
+        } while (parser.nextToken() != XContentParser.Token.END_ARRAY);
+        return list;
+    }
+
+    /**
+     * Parses a single token object of an {@code embedding} array.
+     * The token name is taken from {@code token} when that key is present, otherwise from {@code token_id}.
+     *
+     * @param parser the parser positioned on the object's START_OBJECT token
+     * @return the weighted token, or {@code null} if the entry has no weight, a weight that is not positive,
+     * or a token name rejected by {@link #invalidToken}
+     * @throws IOException if the underlying parser fails
+     */
+    private static WeightedToken parseEmbeddingList(XContentParser parser) throws IOException {
+        XContentParser.Token token = parser.currentToken();
+        XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, token, parser);
+        Map<String, Object> values = parser.map();
+        Object tokenName = values.containsKey("token") ? values.get("token") : values.get("token_id");
+
+        Object weightValue = values.get("weight");
+        if (weightValue == null) {
+            // an entry without a weight cannot be used; drop it like any other invalid entry instead of failing with a NPE
+            return null;
+        }
+        float weight = Float.parseFloat(weightValue.toString());
+
+        if (invalidToken(tokenName) || weight <= 0.0f) {
+            return null;
+        }
+
+        return new WeightedToken(tokenName.toString(), weight);
+    }
+
+    /**
+     * @param tokenName the raw token name taken from the response, may be {@code null}
+     * @return {@code true} if the name is {@code null}, empty, or contains a dot
+     */
+    private static boolean invalidToken(Object tokenName) {
+        if (tokenName == null) {
+            return true;
+        }
+
+        String token = tokenName.toString();
+        return token.isEmpty() || token.contains(".");
+    }
+
+    // static helpers only, never instantiated
+    private AlibabaCloudSearchSparseResponseEntity() {}
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchModel.java
new file mode 100644
index 0000000000000..e26953754b1d6
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchModel.java
@@ -0,0 +1,49 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.alibabacloudsearch;
+
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.ServiceSettings;
+import org.elasticsearch.inference.TaskSettings;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.external.action.alibabacloudsearch.AlibabaCloudSearchActionVisitor;
+
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * Base class of the AlibabaCloud Search models. It keeps the rate limit settings next to the model and
+ * lets each concrete model create its own action through an {@link AlibabaCloudSearchActionVisitor}.
+ */
+public abstract class AlibabaCloudSearchModel extends Model {
+    private final AlibabaCloudSearchRateLimitServiceSettings rateLimitServiceSettings;
+
+    /**
+     * @param configurations           the model configurations passed to {@link Model}
+     * @param secrets                  the model secrets passed to {@link Model}
+     * @param rateLimitServiceSettings the rate limit settings of this model, must not be {@code null}
+     */
+    public AlibabaCloudSearchModel(
+        ModelConfigurations configurations,
+        ModelSecrets secrets,
+        AlibabaCloudSearchRateLimitServiceSettings rateLimitServiceSettings
+    ) {
+        super(configurations, secrets);
+        this.rateLimitServiceSettings = Objects.requireNonNull(rateLimitServiceSettings);
+    }
+
+    /** Copies {@code model} with different task settings, keeping its rate limit settings. */
+    protected AlibabaCloudSearchModel(AlibabaCloudSearchModel model, TaskSettings taskSettings) {
+        super(model, taskSettings);
+        rateLimitServiceSettings = model.rateLimitServiceSettings();
+    }
+
+    /** Copies {@code model} with different service settings, keeping its rate limit settings. */
+    protected AlibabaCloudSearchModel(AlibabaCloudSearchModel model, ServiceSettings serviceSettings) {
+        super(model, serviceSettings);
+        rateLimitServiceSettings = model.rateLimitServiceSettings();
+    }
+
+    /**
+     * Creates the action that executes an inference request for this model.
+     *
+     * @param creator      the visitor that builds the action for the concrete model type
+     * @param taskSettings the task settings supplied with the request
+     * @param inputType    the input type of the request
+     * @return the action to execute
+     */
+    public abstract ExecutableAction accept(AlibabaCloudSearchActionVisitor creator, Map<String, Object> taskSettings, InputType inputType);
+
+    public AlibabaCloudSearchRateLimitServiceSettings rateLimitServiceSettings() {
+        return rateLimitServiceSettings;
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchRateLimitServiceSettings.java
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchRateLimitServiceSettings.java
new file mode 100644
index 0000000000000..db09f1d66e77b
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchRateLimitServiceSettings.java
@@ -0,0 +1,15 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.alibabacloudsearch;
+
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+
+/**
+ * Implemented by AlibabaCloud Search service settings that expose rate limit settings.
+ */
+public interface AlibabaCloudSearchRateLimitServiceSettings {
+    /**
+     * @return the rate limit settings held by these service settings
+     */
+    RateLimitSettings rateLimitSettings();
+
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java
new file mode 100644
index 0000000000000..0c48c99b4b81e
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java
@@ -0,0 +1,318 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */ + +package org.elasticsearch.xpack.inference.services.alibabacloudsearch; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.external.action.alibabacloudsearch.AlibabaCloudSearchActionCreator; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; +import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchUtils; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsServiceSettings; +import 
org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankModel; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel; + +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.elasticsearch.xpack.core.inference.action.InferenceAction.Request.DEFAULT_TIMEOUT; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; + +public class AlibabaCloudSearchService extends SenderService { + public static final String NAME = AlibabaCloudSearchUtils.SERVICE_NAME; + + public AlibabaCloudSearchService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { + super(factory, serviceComponents); + } + + @Override + public String name() { + return NAME; + } + + @Override + public void parseRequestConfig( + String inferenceEntityId, + TaskType taskType, + Map config, + Set platformArchitectures, + ActionListener parsedModelListener + ) { + try { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + AlibabaCloudSearchModel model = createModel( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + serviceSettingsMap, + TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), + ConfigurationParseContext.REQUEST + ); + + throwIfNotEmptyMap(config, NAME); + throwIfNotEmptyMap(serviceSettingsMap, NAME); + throwIfNotEmptyMap(taskSettingsMap, NAME); + + 
parsedModelListener.onResponse(model); + } catch (Exception e) { + parsedModelListener.onFailure(e); + } + } + + private static AlibabaCloudSearchModel createModelWithoutLoggingDeprecations( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + @Nullable Map secretSettings, + String failureMessage + ) { + return createModel( + inferenceEntityId, + taskType, + serviceSettings, + taskSettings, + secretSettings, + failureMessage, + ConfigurationParseContext.PERSISTENT + ); + } + + private static AlibabaCloudSearchModel createModel( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + @Nullable Map secretSettings, + String failureMessage, + ConfigurationParseContext context + ) { + return switch (taskType) { + case TEXT_EMBEDDING -> new AlibabaCloudSearchEmbeddingsModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + taskSettings, + secretSettings, + context + ); + case SPARSE_EMBEDDING -> new AlibabaCloudSearchSparseModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + taskSettings, + secretSettings, + context + ); + case RERANK -> new AlibabaCloudSearchRerankModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + taskSettings, + secretSettings, + context + ); + default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + }; + } + + @Override + public AlibabaCloudSearchModel parsePersistedConfigWithSecrets( + String inferenceEntityId, + TaskType taskType, + Map config, + Map secrets + ) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); + Map secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); + + return createModelWithoutLoggingDeprecations( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + secretSettingsMap, 
+ parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + ); + } + + @Override + public AlibabaCloudSearchModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); + + return createModelWithoutLoggingDeprecations( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + null, + parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + ); + } + + @Override + public void doInfer( + Model model, + String query, + List input, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener listener + ) { + if (model instanceof AlibabaCloudSearchModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + AlibabaCloudSearchModel alibabaCloudSearchModel = (AlibabaCloudSearchModel) model; + var actionCreator = new AlibabaCloudSearchActionCreator(getSender(), getServiceComponents()); + + var action = alibabaCloudSearchModel.accept(actionCreator, taskSettings, inputType); + action.execute(new QueryAndDocsInputs(query, input), timeout, listener); + } + + @Override + public void doInfer( + Model model, + List input, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener listener + ) { + if (model instanceof AlibabaCloudSearchModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + AlibabaCloudSearchModel alibabaCloudSearchModel = (AlibabaCloudSearchModel) model; + var actionCreator = new AlibabaCloudSearchActionCreator(getSender(), getServiceComponents()); + + var action = alibabaCloudSearchModel.accept(actionCreator, taskSettings, inputType); + action.execute(new DocumentsOnlyInput(input), timeout, listener); + } + + @Override + protected void doChunkedInfer( + Model model, + @Nullable String query, + List input, + Map 
taskSettings, + InputType inputType, + ChunkingOptions chunkingOptions, + TimeValue timeout, + ActionListener> listener + ) { + listener.onFailure(new ElasticsearchStatusException("Chunking not supported by the {} service", RestStatus.BAD_REQUEST, NAME)); + } + + /** + * For text embedding models get the embedding size and + * update the service settings. + * + * @param model The new model + * @param listener The listener + */ + @Override + public void checkModelConfig(Model model, ActionListener listener) { + if (model instanceof AlibabaCloudSearchEmbeddingsModel embeddingsModel) { + ServiceUtils.getEmbeddingSize( + model, + this, + listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateModelWithEmbeddingDetails(embeddingsModel, size))) + ); + } else { + checkAlibabaCloudSearchServiceConfig(model, this, listener); + } + } + + private AlibabaCloudSearchEmbeddingsModel updateModelWithEmbeddingDetails(AlibabaCloudSearchEmbeddingsModel model, int embeddingSize) { + AlibabaCloudSearchEmbeddingsServiceSettings serviceSettings = new AlibabaCloudSearchEmbeddingsServiceSettings( + new AlibabaCloudSearchServiceSettings( + model.getServiceSettings().getCommonSettings().modelId(), + model.getServiceSettings().getCommonSettings().getHost(), + model.getServiceSettings().getCommonSettings().getWorkspaceName(), + model.getServiceSettings().getCommonSettings().getHttpSchema(), + model.getServiceSettings().getCommonSettings().rateLimitSettings() + ), + SimilarityMeasure.DOT_PRODUCT, + embeddingSize, + model.getServiceSettings().getMaxInputTokens() + ); + + return new AlibabaCloudSearchEmbeddingsModel(model, serviceSettings); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_ALIBABACLOUD_SEARCH_ADDED; + } + + /** + * For other models except of text embedding + * check the model's service settings and task settings + * + * @param model The new model + * @param service The inferenceService + * @param listener 
The listener + */ + private void checkAlibabaCloudSearchServiceConfig(Model model, InferenceService service, ActionListener listener) { + String input = ALIBABA_CLOUD_SEARCH_SERVICE_CONFIG_INPUT; + String query = model.getTaskType().equals(TaskType.RERANK) ? ALIBABA_CLOUD_SEARCH_SERVICE_CONFIG_QUERY : null; + + service.infer( + model, + query, + List.of(input), + Map.of(), + InputType.INGEST, + DEFAULT_TIMEOUT, + // Respond through the delegate handed to the callback rather than the captured outer listener; + // the inference result itself is ignored, a successful call is the whole check + listener.delegateFailureAndWrap((delegate, r) -> delegate.onResponse(model)) + ); + } + + private static final String ALIBABA_CLOUD_SEARCH_SERVICE_CONFIG_INPUT = "input"; + private static final String ALIBABA_CLOUD_SEARCH_SERVICE_CONFIG_QUERY = "query"; +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceSettings.java new file mode 100644 index 0000000000000..3500bdf814e16 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceSettings.java @@ -0,0 +1,193 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.alibabacloudsearch; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; + +public class AlibabaCloudSearchServiceSettings extends FilteredXContentObject + implements + ServiceSettings, + AlibabaCloudSearchRateLimitServiceSettings { + + public static final String NAME = "alibabacloud_search_service_settings"; + public static final String SERVICE_ID = "service_id"; + public static final String HOST = "host"; + public static final String WORKSPACE_NAME = "workspace"; + public static final String HTTP_SCHEMA_NAME = "http_schema"; + + private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(1_000); + + public static AlibabaCloudSearchServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + + String modelId = extractRequiredString(map, SERVICE_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + String 
host = extractRequiredString(map, HOST, ModelConfigurations.SERVICE_SETTINGS, validationException); + var workspaceName = extractRequiredString(map, WORKSPACE_NAME, ModelConfigurations.SERVICE_SETTINGS, validationException); + var httpSchema = extractOptionalString(map, HTTP_SCHEMA_NAME, ModelConfigurations.SERVICE_SETTINGS, validationException); + + if (httpSchema != null) { + var validSchemas = Set.of("https", "http"); + if (validSchemas.contains(httpSchema) == false) { + validationException.addValidationError("Invalid value for [http_schema]. Must be one of [https, http]"); + } + } + + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + AlibabaCloudSearchService.NAME, + context + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new AlibabaCloudSearchServiceSettings(modelId, host, workspaceName, httpSchema, rateLimitSettings); + } + + private final String serviceId; + private final String host; + private final String workspaceName; + private final String httpSchema; + private final RateLimitSettings rateLimitSettings; + + public AlibabaCloudSearchServiceSettings( + String serviceId, + String host, + String workspaceName, + @Nullable String httpSchema, + @Nullable RateLimitSettings rateLimitSettings + ) { + this.serviceId = serviceId; + this.host = host; + this.workspaceName = workspaceName; + this.httpSchema = httpSchema; + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + } + + public AlibabaCloudSearchServiceSettings(StreamInput in) throws IOException { + serviceId = in.readString(); + host = in.readString(); + workspaceName = in.readString(); + httpSchema = in.readOptionalString(); + rateLimitSettings = new RateLimitSettings(in); + } + + @Override + public String modelId() { + return serviceId; + } + + public String getHost() { + return host; + } + + public String 
getWorkspaceName() { + return workspaceName; + } + + public String getHttpSchema() { + return httpSchema; + } + + @Override + public RateLimitSettings rateLimitSettings() { + return rateLimitSettings; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + toXContentFragment(builder, params); + + builder.endObject(); + return builder; + } + + public XContentBuilder toXContentFragment(XContentBuilder builder, Params params) throws IOException { + return toXContentFragmentOfExposedFields(builder, params); + } + + @Override + public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + if (serviceId != null) { + builder.field(SERVICE_ID, serviceId); + } + builder.field(HOST, host); + builder.field(WORKSPACE_NAME, workspaceName); + if (httpSchema != null) { + builder.field(HTTP_SCHEMA_NAME, httpSchema); + } + rateLimitSettings.toXContent(builder, params); + + return builder; + } + + @Override + public ToXContentObject getFilteredXContentObject() { + return this; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_ALIBABACLOUD_SEARCH_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(serviceId); + out.writeString(host); + out.writeString(workspaceName); + out.writeOptionalString(httpSchema); + rateLimitSettings.writeTo(out); + } + + /** + * Compares every field that is written by {@link #writeTo}, including the rate limit settings. + */ + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AlibabaCloudSearchServiceSettings that = (AlibabaCloudSearchServiceSettings) o; + return Objects.equals(serviceId, that.serviceId) + && Objects.equals(host, that.host) + && Objects.equals(workspaceName, that.workspaceName) + && Objects.equals(httpSchema, that.httpSchema) + && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + 
+ /** + * Hashes the same fields that {@link #equals(Object)} compares. + */ + @Override + public int hashCode() { + return Objects.hash(serviceId, host, workspaceName, httpSchema, rateLimitSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsModel.java new file mode 100644 index 0000000000000..87e5e59ae3434 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsModel.java @@ -0,0 +1,104 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.alibabacloudsearch.AlibabaCloudSearchActionVisitor; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchModel; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import java.util.Map; + +public class AlibabaCloudSearchEmbeddingsModel extends AlibabaCloudSearchModel { + public static AlibabaCloudSearchEmbeddingsModel of( + AlibabaCloudSearchEmbeddingsModel model, + Map taskSettings, + InputType inputType + ) { + var requestTaskSettings = 
AlibabaCloudSearchEmbeddingsTaskSettings.fromMap(taskSettings); + return new AlibabaCloudSearchEmbeddingsModel( + model, + AlibabaCloudSearchEmbeddingsTaskSettings.of(model.getTaskSettings(), requestTaskSettings, inputType) + ); + } + + public AlibabaCloudSearchEmbeddingsModel( + String modelId, + TaskType taskType, + String service, + Map serviceSettings, + Map taskSettings, + @Nullable Map secrets, + ConfigurationParseContext context + ) { + this( + modelId, + taskType, + service, + AlibabaCloudSearchEmbeddingsServiceSettings.fromMap(serviceSettings, context), + AlibabaCloudSearchEmbeddingsTaskSettings.fromMap(taskSettings), + DefaultSecretSettings.fromMap(secrets) + ); + } + + // should only be used for testing + AlibabaCloudSearchEmbeddingsModel( + String modelId, + TaskType taskType, + String service, + AlibabaCloudSearchEmbeddingsServiceSettings serviceSettings, + AlibabaCloudSearchEmbeddingsTaskSettings taskSettings, + @Nullable DefaultSecretSettings secretSettings + ) { + super( + new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings), + new ModelSecrets(secretSettings), + serviceSettings.getCommonSettings() + ); + } + + private AlibabaCloudSearchEmbeddingsModel( + AlibabaCloudSearchEmbeddingsModel model, + AlibabaCloudSearchEmbeddingsTaskSettings taskSettings + ) { + super(model, taskSettings); + } + + public AlibabaCloudSearchEmbeddingsModel( + AlibabaCloudSearchEmbeddingsModel model, + AlibabaCloudSearchEmbeddingsServiceSettings serviceSettings + ) { + super(model, serviceSettings); + } + + @Override + public AlibabaCloudSearchEmbeddingsServiceSettings getServiceSettings() { + return (AlibabaCloudSearchEmbeddingsServiceSettings) super.getServiceSettings(); + } + + @Override + public AlibabaCloudSearchEmbeddingsTaskSettings getTaskSettings() { + return (AlibabaCloudSearchEmbeddingsTaskSettings) super.getTaskSettings(); + } + + @Override + public DefaultSecretSettings getSecretSettings() { + return (DefaultSecretSettings) 
super.getSecretSettings(); + } + + @Override + public ExecutableAction accept(AlibabaCloudSearchActionVisitor visitor, Map taskSettings, InputType inputType) { + return visitor.create(this, taskSettings, inputType); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsServiceSettings.java new file mode 100644 index 0000000000000..76dfd01f333da --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsServiceSettings.java @@ -0,0 +1,152 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType; + +public class AlibabaCloudSearchEmbeddingsServiceSettings implements ServiceSettings { + public static final String NAME = "alibabacloud_search_embeddings_service_settings"; + + public static AlibabaCloudSearchEmbeddingsServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + var commonServiceSettings = AlibabaCloudSearchServiceSettings.fromMap(map, context); + + SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); + Integer dims = removeAsType(map, DIMENSIONS, Integer.class); + Integer 
maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new AlibabaCloudSearchEmbeddingsServiceSettings(commonServiceSettings, similarity, dims, maxInputTokens); + } + + private final AlibabaCloudSearchServiceSettings commonSettings; + private final SimilarityMeasure similarity; + private final Integer dimensions; + private final Integer maxInputTokens; + + public AlibabaCloudSearchEmbeddingsServiceSettings( + AlibabaCloudSearchServiceSettings commonSettings, + @Nullable SimilarityMeasure similarity, + @Nullable Integer dimensions, + @Nullable Integer maxInputTokens + ) { + this.commonSettings = commonSettings; + this.similarity = similarity; + this.dimensions = dimensions; + this.maxInputTokens = maxInputTokens; + } + + public AlibabaCloudSearchEmbeddingsServiceSettings(StreamInput in) throws IOException { + commonSettings = new AlibabaCloudSearchServiceSettings(in); + similarity = in.readOptionalEnum(SimilarityMeasure.class); + dimensions = in.readOptionalVInt(); + maxInputTokens = in.readOptionalVInt(); + } + + public AlibabaCloudSearchServiceSettings getCommonSettings() { + return commonSettings; + } + + public SimilarityMeasure getSimilarity() { + return similarity; + } + + public Integer getDimensions() { + return dimensions; + } + + public Integer getMaxInputTokens() { + return maxInputTokens; + } + + @Override + public String modelId() { + return commonSettings.modelId(); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + commonSettings.toXContentFragment(builder, params); + if (similarity != null) { + builder.field(SIMILARITY, similarity); + } + if (dimensions != null) { + builder.field(DIMENSIONS, dimensions); + } + if (maxInputTokens != null) { + 
builder.field(MAX_INPUT_TOKENS, maxInputTokens); + } + builder.endObject(); + return builder; + } + + @Override + public ToXContentObject getFilteredXContentObject() { + return this; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_ALIBABACLOUD_SEARCH_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + commonSettings.writeTo(out); + out.writeOptionalEnum(similarity); + out.writeOptionalVInt(dimensions); + out.writeOptionalVInt(maxInputTokens); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AlibabaCloudSearchEmbeddingsServiceSettings that = (AlibabaCloudSearchEmbeddingsServiceSettings) o; + return Objects.equals(commonSettings, that.commonSettings) + && Objects.equals(similarity, that.similarity) + && Objects.equals(dimensions, that.dimensions) + && Objects.equals(maxInputTokens, that.maxInputTokens); + } + + @Override + public int hashCode() { + return Objects.hash(commonSettings, similarity, dimensions, maxInputTokens); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsTaskSettings.java new file mode 100644 index 0000000000000..abfd49940b67b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsTaskSettings.java @@ -0,0 +1,173 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.EnumSet; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum; + +/** + * Defines the task settings for the alibaba cloud search text embeddings service. + * + *

+ * + * See api docs for details. + *

+ */ +public class AlibabaCloudSearchEmbeddingsTaskSettings implements TaskSettings { + + public static final String NAME = "alibabacloud_search_embeddings_task_settings"; + public static final AlibabaCloudSearchEmbeddingsTaskSettings EMPTY_SETTINGS = new AlibabaCloudSearchEmbeddingsTaskSettings( + (InputType) null + ); + static final String INPUT_TYPE = "input_type"; + static final EnumSet VALID_REQUEST_VALUES = EnumSet.of(InputType.INGEST, InputType.SEARCH); + + public static AlibabaCloudSearchEmbeddingsTaskSettings fromMap(Map map) { + if (map == null || map.isEmpty()) { + return EMPTY_SETTINGS; + } + + ValidationException validationException = new ValidationException(); + + InputType inputType = extractOptionalEnum( + map, + INPUT_TYPE, + ModelConfigurations.TASK_SETTINGS, + InputType::fromString, + VALID_REQUEST_VALUES, + validationException + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new AlibabaCloudSearchEmbeddingsTaskSettings(inputType); + } + + /** + * Creates a new {@link AlibabaCloudSearchEmbeddingsTaskSettings} by preferring non-null fields from the provided parameters. + * For the input type, preference is given to requestInputType if it is not null and not UNSPECIFIED. + * Then preference is given to the requestTaskSettings and finally to originalSettings even if the value is null. + *

+ * The input type is the only field these task settings hold, so no other field is merged. + * + * @param originalSettings the settings stored as part of the inference entity configuration + * @param requestTaskSettings the settings passed in within the task_settings field of the request + * @param requestInputType the input type passed in the request parameters + * @return a constructed {@link AlibabaCloudSearchEmbeddingsTaskSettings} + */ + public static AlibabaCloudSearchEmbeddingsTaskSettings of( + AlibabaCloudSearchEmbeddingsTaskSettings originalSettings, + AlibabaCloudSearchEmbeddingsTaskSettings requestTaskSettings, + InputType requestInputType + ) { + var inputTypeToUse = getValidInputType(originalSettings, requestTaskSettings, requestInputType); + + return new AlibabaCloudSearchEmbeddingsTaskSettings(inputTypeToUse); + } + + /** + * Picks the input type to use: the request parameter when it is one of {@link #VALID_REQUEST_VALUES}, + * otherwise the request task settings value when it is non-null, otherwise the stored value (which may be null). + */ + private static InputType getValidInputType( + AlibabaCloudSearchEmbeddingsTaskSettings originalSettings, + AlibabaCloudSearchEmbeddingsTaskSettings requestTaskSettings, + InputType requestInputType + ) { + InputType inputTypeToUse = originalSettings.inputType; + + if (VALID_REQUEST_VALUES.contains(requestInputType)) { + inputTypeToUse = requestInputType; + } else if (requestTaskSettings.inputType != null) { + inputTypeToUse = requestTaskSettings.inputType; + } + + return inputTypeToUse; + } + + private final InputType inputType; + + public AlibabaCloudSearchEmbeddingsTaskSettings(StreamInput in) throws IOException { + this(in.readOptionalEnum(InputType.class)); + } + + public AlibabaCloudSearchEmbeddingsTaskSettings(@Nullable InputType inputType) { + validateInputType(inputType); + this.inputType = inputType; + } + + private static void validateInputType(InputType inputType) { + if (inputType == null) { + return; + } + + assert VALID_REQUEST_VALUES.contains(inputType) : invalidInputTypeMessage(inputType); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params 
params) throws IOException { + builder.startObject(); + if (inputType != null) { + builder.field(INPUT_TYPE, inputType); + } + + builder.endObject(); + return builder; + } + + public InputType getInputType() { + return inputType; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_ALIBABACLOUD_SEARCH_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalEnum(inputType); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AlibabaCloudSearchEmbeddingsTaskSettings that = (AlibabaCloudSearchEmbeddingsTaskSettings) o; + return Objects.equals(inputType, that.inputType); + } + + @Override + public int hashCode() { + return Objects.hash(inputType); + } + + public static String invalidInputTypeMessage(InputType inputType) { + return Strings.format("received invalid input type value [%s]", inputType.toString()); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/rerank/AlibabaCloudSearchRerankModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/rerank/AlibabaCloudSearchRerankModel.java new file mode 100644 index 0000000000000..a9152b6edd4c5 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/rerank/AlibabaCloudSearchRerankModel.java @@ -0,0 +1,94 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.alibabacloudsearch.AlibabaCloudSearchActionVisitor; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchModel; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import java.util.Map; + +public class AlibabaCloudSearchRerankModel extends AlibabaCloudSearchModel { + public static AlibabaCloudSearchRerankModel of(AlibabaCloudSearchRerankModel model, Map taskSettings) { + var requestTaskSettings = AlibabaCloudSearchRerankTaskSettings.fromMap(taskSettings); + return new AlibabaCloudSearchRerankModel( + model, + AlibabaCloudSearchRerankTaskSettings.of(model.getTaskSettings(), requestTaskSettings) + ); + } + + public AlibabaCloudSearchRerankModel( + String modelId, + TaskType taskType, + String service, + Map serviceSettings, + Map taskSettings, + @Nullable Map secrets, + ConfigurationParseContext context + ) { + this( + modelId, + taskType, + service, + AlibabaCloudSearchRerankServiceSettings.fromMap(serviceSettings, context), + AlibabaCloudSearchRerankTaskSettings.fromMap(taskSettings), + DefaultSecretSettings.fromMap(secrets) + ); + } + + // should only be used for testing + AlibabaCloudSearchRerankModel( + String modelId, + TaskType taskType, + String service, + AlibabaCloudSearchRerankServiceSettings serviceSettings, + AlibabaCloudSearchRerankTaskSettings taskSettings, + @Nullable DefaultSecretSettings secretSettings + ) { + super( + new ModelConfigurations(modelId, 
taskType, service, serviceSettings, taskSettings), + new ModelSecrets(secretSettings), + serviceSettings.getCommonSettings() + ); + } + + private AlibabaCloudSearchRerankModel(AlibabaCloudSearchRerankModel model, AlibabaCloudSearchRerankTaskSettings taskSettings) { + super(model, taskSettings); + } + + public AlibabaCloudSearchRerankModel(AlibabaCloudSearchRerankModel model, AlibabaCloudSearchRerankServiceSettings serviceSettings) { + super(model, serviceSettings); + } + + @Override + public AlibabaCloudSearchRerankServiceSettings getServiceSettings() { + return (AlibabaCloudSearchRerankServiceSettings) super.getServiceSettings(); + } + + @Override + public AlibabaCloudSearchRerankTaskSettings getTaskSettings() { + return (AlibabaCloudSearchRerankTaskSettings) super.getTaskSettings(); + } + + @Override + public DefaultSecretSettings getSecretSettings() { + return (DefaultSecretSettings) super.getSecretSettings(); + } + + @Override + public ExecutableAction accept(AlibabaCloudSearchActionVisitor visitor, Map taskSettings, InputType inputType) { + return visitor.create(this, taskSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/rerank/AlibabaCloudSearchRerankServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/rerank/AlibabaCloudSearchRerankServiceSettings.java new file mode 100644 index 0000000000000..42c7238aefa7f --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/rerank/AlibabaCloudSearchRerankServiceSettings.java @@ -0,0 +1,97 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +public class AlibabaCloudSearchRerankServiceSettings implements ServiceSettings { + public static final String NAME = "alibabacloud_search_rerank_service_settings"; + + public static AlibabaCloudSearchRerankServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + var commonServiceSettings = AlibabaCloudSearchServiceSettings.fromMap(map, context); + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new AlibabaCloudSearchRerankServiceSettings(commonServiceSettings); + } + + private final AlibabaCloudSearchServiceSettings commonSettings; + + public AlibabaCloudSearchRerankServiceSettings(AlibabaCloudSearchServiceSettings commonSettings) { + this.commonSettings = commonSettings; + } + + public AlibabaCloudSearchRerankServiceSettings(StreamInput in) throws IOException { + commonSettings = new AlibabaCloudSearchServiceSettings(in); + } + + public AlibabaCloudSearchServiceSettings getCommonSettings() { + return commonSettings; + } + + @Override + public String modelId() { + return commonSettings.modelId(); + } + + @Override + public String getWriteableName() { + return NAME; 
+ } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + commonSettings.toXContentFragment(builder, params); + builder.endObject(); + return builder; + } + + @Override + public ToXContentObject getFilteredXContentObject() { + return this; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_ALIBABACLOUD_SEARCH_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + commonSettings.writeTo(out); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AlibabaCloudSearchRerankServiceSettings that = (AlibabaCloudSearchRerankServiceSettings) o; + return Objects.equals(commonSettings, that.commonSettings); + } + + @Override + public int hashCode() { + return Objects.hash(commonSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/rerank/AlibabaCloudSearchRerankTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/rerank/AlibabaCloudSearchRerankTaskSettings.java new file mode 100644 index 0000000000000..e9fb468eab7fb --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/rerank/AlibabaCloudSearchRerankTaskSettings.java @@ -0,0 +1,101 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +/** + * Defines the task settings for the AlibabaCloudSearch rerank service. + * + *

+ * See api docs for details. + *

+ */ +public class AlibabaCloudSearchRerankTaskSettings implements TaskSettings { + public static final String NAME = "alibabacloud_search_rerank_task_settings"; + + static final AlibabaCloudSearchRerankTaskSettings EMPTY_SETTINGS = new AlibabaCloudSearchRerankTaskSettings(); + + public static AlibabaCloudSearchRerankTaskSettings fromMap(Map map) { + ValidationException validationException = new ValidationException(); + + if (map == null || map.isEmpty()) { + return EMPTY_SETTINGS; + } + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return of(); + } + + /** + * Creates a new {@link AlibabaCloudSearchRerankTaskSettings}. + * The rerank task currently has no configurable fields, so both arguments are ignored and empty settings are returned. + * @return a constructed {@link AlibabaCloudSearchRerankTaskSettings} + */ + public static AlibabaCloudSearchRerankTaskSettings of( + AlibabaCloudSearchRerankTaskSettings originalSettings, + AlibabaCloudSearchRerankTaskSettings requestTaskSettings + ) { + return new AlibabaCloudSearchRerankTaskSettings(); + } + + public static AlibabaCloudSearchRerankTaskSettings of() { + return new AlibabaCloudSearchRerankTaskSettings(); + } + + public AlibabaCloudSearchRerankTaskSettings(StreamInput in) { + this(); + } + + public AlibabaCloudSearchRerankTaskSettings() {} + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_ALIBABACLOUD_SEARCH_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException {} + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + return true; + } + + @Override + public int 
hashCode() { + return Objects.hash(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseModel.java new file mode 100644 index 0000000000000..b551ba389136b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseModel.java @@ -0,0 +1,98 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.alibabacloudsearch.AlibabaCloudSearchActionVisitor; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchModel; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import java.util.Map; + +public class AlibabaCloudSearchSparseModel extends AlibabaCloudSearchModel { + public static AlibabaCloudSearchSparseModel of( + AlibabaCloudSearchSparseModel model, + Map taskSettings, + InputType inputType + ) { + var requestTaskSettings = AlibabaCloudSearchSparseTaskSettings.fromMap(taskSettings); + return new AlibabaCloudSearchSparseModel( + model, + 
AlibabaCloudSearchSparseTaskSettings.of(model.getTaskSettings(), requestTaskSettings, inputType) + ); + } + + public AlibabaCloudSearchSparseModel( + String modelId, + TaskType taskType, + String service, + Map serviceSettings, + Map taskSettings, + @Nullable Map secrets, + ConfigurationParseContext context + ) { + this( + modelId, + taskType, + service, + AlibabaCloudSearchSparseServiceSettings.fromMap(serviceSettings, context), + AlibabaCloudSearchSparseTaskSettings.fromMap(taskSettings), + DefaultSecretSettings.fromMap(secrets) + ); + } + + // should only be used for testing + AlibabaCloudSearchSparseModel( + String modelId, + TaskType taskType, + String service, + AlibabaCloudSearchSparseServiceSettings serviceSettings, + AlibabaCloudSearchSparseTaskSettings taskSettings, + @Nullable DefaultSecretSettings secretSettings + ) { + super( + new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings), + new ModelSecrets(secretSettings), + serviceSettings.getCommonSettings() + ); + } + + private AlibabaCloudSearchSparseModel(AlibabaCloudSearchSparseModel model, AlibabaCloudSearchSparseTaskSettings taskSettings) { + super(model, taskSettings); + } + + public AlibabaCloudSearchSparseModel(AlibabaCloudSearchSparseModel model, AlibabaCloudSearchSparseServiceSettings serviceSettings) { + super(model, serviceSettings); + } + + @Override + public AlibabaCloudSearchSparseServiceSettings getServiceSettings() { + return (AlibabaCloudSearchSparseServiceSettings) super.getServiceSettings(); + } + + @Override + public AlibabaCloudSearchSparseTaskSettings getTaskSettings() { + return (AlibabaCloudSearchSparseTaskSettings) super.getTaskSettings(); + } + + @Override + public DefaultSecretSettings getSecretSettings() { + return (DefaultSecretSettings) super.getSecretSettings(); + } + + @Override + public ExecutableAction accept(AlibabaCloudSearchActionVisitor visitor, Map taskSettings, InputType inputType) { + return visitor.create(this, taskSettings, 
inputType); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseServiceSettings.java new file mode 100644 index 0000000000000..fe44c936c4e61 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseServiceSettings.java @@ -0,0 +1,97 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +public class AlibabaCloudSearchSparseServiceSettings implements ServiceSettings { + public static final String NAME = "alibabacloud_search_sparse_embeddings_service_settings"; + + public static AlibabaCloudSearchSparseServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + var commonServiceSettings = 
AlibabaCloudSearchServiceSettings.fromMap(map, context); + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new AlibabaCloudSearchSparseServiceSettings(commonServiceSettings); + } + + private final AlibabaCloudSearchServiceSettings commonSettings; + + public AlibabaCloudSearchSparseServiceSettings(AlibabaCloudSearchServiceSettings commonSettings) { + this.commonSettings = commonSettings; + } + + public AlibabaCloudSearchSparseServiceSettings(StreamInput in) throws IOException { + commonSettings = new AlibabaCloudSearchServiceSettings(in); + } + + public AlibabaCloudSearchServiceSettings getCommonSettings() { + return commonSettings; + } + + @Override + public String modelId() { + return commonSettings.modelId(); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + commonSettings.toXContentFragment(builder, params); + builder.endObject(); + return builder; + } + + @Override + public ToXContentObject getFilteredXContentObject() { + return this; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_ALIBABACLOUD_SEARCH_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + commonSettings.writeTo(out); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AlibabaCloudSearchSparseServiceSettings that = (AlibabaCloudSearchSparseServiceSettings) o; + return Objects.equals(commonSettings, that.commonSettings); + } + + @Override + public int hashCode() { + return Objects.hash(commonSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseTaskSettings.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseTaskSettings.java new file mode 100644 index 0000000000000..2b1e9ace1b24c --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseTaskSettings.java @@ -0,0 +1,186 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.EnumSet; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum; + +/** + * Defines the task settings for the alibabacloud search text sparse embeddings service. + * + *

+ * + * See api docs for details. + *

+ */ +public class AlibabaCloudSearchSparseTaskSettings implements TaskSettings { + + public static final String NAME = "alibabacloud_search_sparse_embeddings_task_settings"; + public static final AlibabaCloudSearchSparseTaskSettings EMPTY_SETTINGS = new AlibabaCloudSearchSparseTaskSettings(null, null); + static final String INPUT_TYPE = "input_type"; + static final String RETURN_TOKEN = "return_token"; + static final EnumSet VALID_REQUEST_VALUES = EnumSet.of(InputType.INGEST, InputType.SEARCH); + + public static AlibabaCloudSearchSparseTaskSettings fromMap(Map map) { + if (map == null || map.isEmpty()) { + return EMPTY_SETTINGS; + } + + ValidationException validationException = new ValidationException(); + + InputType inputType = extractOptionalEnum( + map, + INPUT_TYPE, + ModelConfigurations.TASK_SETTINGS, + InputType::fromString, + VALID_REQUEST_VALUES, + validationException + ); + + Boolean returnToken = extractOptionalBoolean(map, RETURN_TOKEN, validationException); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new AlibabaCloudSearchSparseTaskSettings(inputType, returnToken); + } + + /** + * Creates a new {@link AlibabaCloudSearchSparseTaskSettings} by preferring non-null fields from the provided parameters. + * For the input type, preference is given to requestInputType if it is one of the valid request values (INGEST or SEARCH). + * Then preference is given to the requestTaskSettings and finally to originalSettings even if the value is null. + *

+ * Similarly, for the return_token field preference is given to requestTaskSettings if it is not null and then to + * originalSettings. + * + * @param originalSettings the settings stored as part of the inference entity configuration + * @param requestTaskSettings the settings passed in within the task_settings field of the request + * @param requestInputType the input type passed in the request parameters + * @return a constructed {@link AlibabaCloudSearchSparseTaskSettings} + */ + public static AlibabaCloudSearchSparseTaskSettings of( + AlibabaCloudSearchSparseTaskSettings originalSettings, + AlibabaCloudSearchSparseTaskSettings requestTaskSettings, + InputType requestInputType + ) { + var inputTypeToUse = getValidInputType(originalSettings, requestTaskSettings, requestInputType); + var returnToken = requestTaskSettings.isReturnToken() != null + ? requestTaskSettings.isReturnToken() + : originalSettings.isReturnToken(); + return new AlibabaCloudSearchSparseTaskSettings(inputTypeToUse, returnToken); + } + + private static InputType getValidInputType( + AlibabaCloudSearchSparseTaskSettings originalSettings, + AlibabaCloudSearchSparseTaskSettings requestTaskSettings, + InputType requestInputType + ) { + InputType inputTypeToUse = originalSettings.inputType; + + if (VALID_REQUEST_VALUES.contains(requestInputType)) { + inputTypeToUse = requestInputType; + } else if (requestTaskSettings.inputType != null) { + inputTypeToUse = requestTaskSettings.inputType; + } + + return inputTypeToUse; + } + + private final InputType inputType; + private final Boolean returnToken; + + public AlibabaCloudSearchSparseTaskSettings(StreamInput in) throws IOException { + this(in.readOptionalEnum(InputType.class), in.readOptionalBoolean()); + } + + public AlibabaCloudSearchSparseTaskSettings(@Nullable InputType inputType, Boolean returnToken) { + validateInputType(inputType); + this.inputType = inputType; + this.returnToken = returnToken; + } + + private static void 
validateInputType(InputType inputType) { + if (inputType == null) { + return; + } + + assert VALID_REQUEST_VALUES.contains(inputType) : invalidInputTypeMessage(inputType); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (inputType != null) { + builder.field(INPUT_TYPE, inputType); + } + if (returnToken != null) { + builder.field(RETURN_TOKEN, returnToken); + } + builder.endObject(); + return builder; + } + + public InputType getInputType() { + return inputType; + } + + public Boolean isReturnToken() { + return returnToken; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_ALIBABACLOUD_SEARCH_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalEnum(inputType); + out.writeOptionalBoolean(returnToken); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AlibabaCloudSearchSparseTaskSettings that = (AlibabaCloudSearchSparseTaskSettings) o; + return Objects.equals(inputType, that.inputType) && Objects.equals(returnToken, that.returnToken); + } + + @Override + public int hashCode() { + return Objects.hash(inputType, returnToken); + } + + public static String invalidInputTypeMessage(InputType inputType) { + return Strings.format("received invalid input type value [%s]", inputType.toString()); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InputTypeTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InputTypeTests.java index d275d00373cbe..055b4581e067b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InputTypeTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InputTypeTests.java @@ -14,4 +14,8 @@ public class InputTypeTests extends ESTestCase { public static InputType randomWithoutUnspecified() { return randomFrom(InputType.INGEST, InputType.SEARCH, InputType.CLUSTERING, InputType.CLASSIFICATION); } + + public static InputType randomWithIngestAndSearch() { + return randomFrom(InputType.INGEST, InputType.SEARCH); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestEntityTests.java new file mode 100644 index 0000000000000..6aaab219c331d --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestEntityTests.java @@ -0,0 +1,57 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsTaskSettings; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.List; + +import static org.hamcrest.CoreMatchers.is; + +public class AlibabaCloudSearchEmbeddingsRequestEntityTests extends ESTestCase { + public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { + var entity = new AlibabaCloudSearchEmbeddingsRequestEntity( + List.of("abc"), + new AlibabaCloudSearchEmbeddingsTaskSettings(InputType.INGEST) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, is(""" + {"input":["abc"],"input_type":"document"}""")); + } + + public void testXContent_WritesNoOptionalFields_WhenTheyAreNotDefined() throws IOException { + var entity = new AlibabaCloudSearchEmbeddingsRequestEntity(List.of("abc"), AlibabaCloudSearchEmbeddingsTaskSettings.EMPTY_SETTINGS); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, is(""" + {"input":["abc"]}""")); + } + + public void testConvertToString_ThrowsAssertionFailure_WhenInputTypeIsUnspecified() { + var thrownException = expectThrows( + AssertionError.class, + () -> AlibabaCloudSearchEmbeddingsRequestEntity.covertToString(InputType.UNSPECIFIED) + ); + 
MatcherAssert.assertThat(thrownException.getMessage(), is("received invalid input type value [unspecified]")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestTests.java new file mode 100644 index 0000000000000..378401f589b19 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestTests.java @@ -0,0 +1,63 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsServiceSettingsTests; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsTaskSettingsTests; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static 
org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class AlibabaCloudSearchEmbeddingsRequestTests extends ESTestCase { + public void testCreateRequest() throws IOException { + var request = createRequest( + List.of("abc"), + AlibabaCloudSearchEmbeddingsModelTests.createModel( + "embedding_test", + TaskType.TEXT_EMBEDDING, + AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap("embeddings_test", "host", "default"), + AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), + getSecretSettingsMap("secret") + ) + ); + + var httpRequest = request.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + MatcherAssert.assertThat( + httpPost.getURI().toString(), + is("https://host/v3/openapi/workspaces/default/text-embedding/embeddings_test") + ); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc")))); + } + + public static AlibabaCloudSearchEmbeddingsRequest createRequest(List input, AlibabaCloudSearchEmbeddingsModel model) { + var account = new AlibabaCloudSearchAccount(model.getSecretSettings().apiKey()); + return new AlibabaCloudSearchEmbeddingsRequest(account, input, model); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequestEntityTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequestEntityTests.java new file mode 100644 index 0000000000000..8f981d64d36eb --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequestEntityTests.java @@ -0,0 +1,34 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankTaskSettings; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.List; + +import static org.hamcrest.CoreMatchers.is; + +public class AlibabaCloudSearchRerankRequestEntityTests extends ESTestCase { + public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { + var entity = new AlibabaCloudSearchRerankRequestEntity("query", List.of("abc"), new AlibabaCloudSearchRerankTaskSettings()); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, is(""" + {"query":"query","docs":["abc"]}""")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchSparseRequestEntityTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchSparseRequestEntityTests.java new file mode 100644 index 0000000000000..6ae209bc3c6f1 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchSparseRequestEntityTests.java @@ -0,0 +1,49 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseTaskSettings; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.List; + +import static org.hamcrest.CoreMatchers.is; + +public class AlibabaCloudSearchSparseRequestEntityTests extends ESTestCase { + public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { + var entity = new AlibabaCloudSearchSparseRequestEntity( + List.of("abc"), + new AlibabaCloudSearchSparseTaskSettings(InputType.INGEST, true) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, is(""" + {"input":["abc"],"input_type":"document","return_token":true}""")); + } + + public void testXContent_WritesNoOptionalFields_WhenTheyAreNotDefined() throws IOException { + var 
entity = new AlibabaCloudSearchSparseRequestEntity(List.of("abc"), AlibabaCloudSearchSparseTaskSettings.EMPTY_SETTINGS); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, is(""" + {"input":["abc"]}""")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchSparseRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchSparseRequestTests.java new file mode 100644 index 0000000000000..74fc225820641 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchSparseRequestTests.java @@ -0,0 +1,63 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModelTests; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseServiceSettingsTests; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseTaskSettingsTests; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class AlibabaCloudSearchSparseRequestTests extends ESTestCase { + public void testCreateRequest() throws IOException { + var request = createRequest( + List.of("abc"), + AlibabaCloudSearchSparseModelTests.createModel( + "embedding_test", + TaskType.TEXT_EMBEDDING, + AlibabaCloudSearchSparseServiceSettingsTests.getServiceSettingsMap("embeddings_test", "host", "default"), + AlibabaCloudSearchSparseTaskSettingsTests.getTaskSettingsMap(null, null), + getSecretSettingsMap("secret") + ) + ); + + var httpRequest = request.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + MatcherAssert.assertThat( + 
httpPost.getURI().toString(), + is("https://host/v3/openapi/workspaces/default/text-sparse-embedding/embeddings_test") + ); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc")))); + } + + public static AlibabaCloudSearchSparseRequest createRequest(List input, AlibabaCloudSearchSparseModel model) { + var account = new AlibabaCloudSearchAccount(model.getSecretSettings().apiKey()); + return new AlibabaCloudSearchSparseRequest(account, input, model); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchEmbeddingsResponseEntityTests.java new file mode 100644 index 0000000000000..33fa6a2a542cb --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchEmbeddingsResponseEntityTests.java @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.response.alibabacloudsearch; + +import org.apache.http.HttpResponse; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchRequest; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import java.util.List; + +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class AlibabaCloudSearchEmbeddingsResponseEntityTests extends ESTestCase { + public void testFromResponse_CreatesResultsForASingleItem() throws IOException, URISyntaxException { + String responseJson = """ + { + "request_id": "B4AB89C8-B135-xxxx-A6F8-2BAB801A2CE4", + "latency": 38, + "usage": { + "token_count": 3072 + }, + "result": { + "embeddings": [ + { + "index": 0, + "embedding": [ + -0.02868066355586052, + 0.022033605724573135 + ] + } + ] + } + } + """; + + AlibabaCloudSearchRequest request = mock(AlibabaCloudSearchRequest.class); + URI uri = new URI("mock_uri"); + when(request.getURI()).thenReturn(uri); + + InferenceTextEmbeddingFloatResults parsedResults = AlibabaCloudSearchEmbeddingsResponseEntity.fromResponse( + request, + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat( + parsedResults.embeddings(), + is( + List.of( + new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding( + new float[] { -0.02868066355586052f, 0.022033605724573135f } + ) + ) + ) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchErrorResponseEntityTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchErrorResponseEntityTests.java new file mode 100644 index 0000000000000..a03349c66b6d5 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchErrorResponseEntityTests.java @@ -0,0 +1,35 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.alibabacloudsearch; + +import org.apache.http.HttpResponse; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.HttpResult; + +import java.nio.charset.StandardCharsets; + +import static org.mockito.Mockito.mock; + +public class AlibabaCloudSearchErrorResponseEntityTests extends ESTestCase { + public void testFromResponse() { + String responseJson = """ + { + "request_id": "651B3087-8A07-4BF3-B931-9C4E7B60F52D", + "latency": 0, + "code": "InvalidParameter", + "message": "JSON parse error: Cannot deserialize value of type `InputType` from String \\"xxx\\"" + } + """; + + AlibabaCloudSearchErrorResponseEntity errorMessage = AlibabaCloudSearchErrorResponseEntity.fromResponse( + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + assertNotNull(errorMessage); + assertEquals("JSON parse error: Cannot deserialize value of type `InputType` from String \"xxx\"", errorMessage.getErrorMessage()); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchRerankResponseEntityTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchRerankResponseEntityTests.java new file mode 100644 index 0000000000000..bebc8bb66f207 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchRerankResponseEntityTests.java @@ -0,0 +1,71 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.alibabacloudsearch; + +import org.apache.http.HttpResponse; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; + +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class AlibabaCloudSearchRerankResponseEntityTests extends ESTestCase { + + public void testFromResponse_CreatesResultsForASingleItem() throws IOException { + InferenceServiceResults parsedResults = AlibabaCloudSearchRerankResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseLiteral.getBytes(StandardCharsets.UTF_8)) + ); + + MatcherAssert.assertThat(parsedResults, instanceOf(RankedDocsResults.class)); + List expected = responseLiteralDocs(); + for (int i = 0; i < ((RankedDocsResults) parsedResults).getRankedDocs().size(); i++) { + 
assertThat(((RankedDocsResults) parsedResults).getRankedDocs().get(i).index(), is(expected.get(i).index())); + } + } + + private final String responseLiteral = """ + { + "request_id": "450fcb80-f796-46c1-8d69-e1e86d29aa9f", + "latency": 564.903929, + "usage": { + "doc_count": 2 + }, + "result": { + "scores":[ + { + "index":1, + "score": 1.37 + }, + { + "index":0, + "score": -0.3 + } + ] + } + } + """; + + private ArrayList responseLiteralDocs() { + var list = new ArrayList(); + + list.add(new RankedDocsResults.RankedDoc(1, 1.37F, null)); + list.add(new RankedDocsResults.RankedDoc(0, -0.3F, null)); + return list; + }; +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchSparseResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchSparseResponseEntityTests.java new file mode 100644 index 0000000000000..a6d3a4b77d74f --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchSparseResponseEntityTests.java @@ -0,0 +1,85 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.response.alibabacloudsearch; + +import org.apache.http.HttpResponse; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.core.ml.search.WeightedToken; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchRequest; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import java.util.List; + +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class AlibabaCloudSearchSparseResponseEntityTests extends ESTestCase { + public void testFromResponse_CreatesResultsForASingleItem() throws IOException, URISyntaxException { + String responseJson = """ + { + "request_id": "DDC4306F-xxxx-xxxx-xxxx-92C5CEA756A0", + "latency": 25, + "usage": { + "token_count": 11 + }, + "result": { + "sparse_embeddings": [ + { + "index": 0, + "embedding": [ + { + "token_id": 6, + "weight": 0.1014404296875 + }, + { + "token_id": 163040, + "weight": 0.2841796875 + }, + { + "token_id": 354, + "weight": 0.1431884765625 + } + ] + } + ] + } + } + """; + + AlibabaCloudSearchRequest request = mock(AlibabaCloudSearchRequest.class); + URI uri = new URI("mock_uri"); + when(request.getURI()).thenReturn(uri); + + SparseEmbeddingResults parsedResults = AlibabaCloudSearchSparseResponseEntity.fromResponse( + request, + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat( + parsedResults.embeddings(), + is( + List.of( + new SparseEmbeddingResults.Embedding( + List.of( + new WeightedToken("6", 0.1014404296875f), + new WeightedToken("163040", 0.2841796875f), + new WeightedToken("354", 0.1431884765625f) + ), + false + ) + ) + ) + ); + } 
+} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceSettingsTests.java new file mode 100644 index 0000000000000..d7965a38c845b --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceSettingsTests.java @@ -0,0 +1,125 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.alibabacloudsearch; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.is; + +public class AlibabaCloudSearchServiceSettingsTests extends AbstractWireSerializingTestCase { + /** + * The created settings can have a url set to null. 
+ */ + public static AlibabaCloudSearchServiceSettings createRandom() { + var model = randomAlphaOfLength(15); + String host = randomAlphaOfLength(15); + String workspaceName = randomAlphaOfLength(10); + String httpSchema = "https"; + return new AlibabaCloudSearchServiceSettings(model, host, workspaceName, httpSchema, RateLimitSettingsTests.createRandom()); + } + + public void testFromMap() throws URISyntaxException { + var model = "model"; + var host = "host"; + var workspaceName = "default"; + var httpSchema = "https"; + var serviceSettings = AlibabaCloudSearchServiceSettings.fromMap( + new HashMap<>( + Map.of( + AlibabaCloudSearchServiceSettings.SERVICE_ID, + model, + AlibabaCloudSearchServiceSettings.HOST, + host, + AlibabaCloudSearchServiceSettings.WORKSPACE_NAME, + workspaceName, + AlibabaCloudSearchServiceSettings.HTTP_SCHEMA_NAME, + httpSchema + ) + ), + null + ); + + MatcherAssert.assertThat(serviceSettings, is(new AlibabaCloudSearchServiceSettings(model, host, workspaceName, httpSchema, null))); + } + + public void testFromMap_WithRateLimit() { + var model = "model"; + var host = "host"; + var workspaceName = "default"; + var httpSchema = "https"; + var serviceSettings = AlibabaCloudSearchServiceSettings.fromMap( + new HashMap<>( + Map.of( + AlibabaCloudSearchServiceSettings.SERVICE_ID, + model, + AlibabaCloudSearchServiceSettings.HOST, + host, + AlibabaCloudSearchServiceSettings.WORKSPACE_NAME, + workspaceName, + AlibabaCloudSearchServiceSettings.HTTP_SCHEMA_NAME, + httpSchema, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 3)) + ) + ), + null + ); + + MatcherAssert.assertThat( + serviceSettings, + is(new AlibabaCloudSearchServiceSettings(model, host, workspaceName, httpSchema, new RateLimitSettings(3))) + ); + } + + public void testXContent() throws IOException { + var entity = new AlibabaCloudSearchServiceSettings("model_id_name", "host_name", "workspace_name", null, null); + + XContentBuilder 
builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(""" + {"service_id":"model_id_name","host":"host_name","workspace":"workspace_name","rate_limit":{"requests_per_minute":1000}}""")); + } + + @Override + protected Writeable.Reader instanceReader() { + return AlibabaCloudSearchServiceSettings::new; + } + + @Override + protected AlibabaCloudSearchServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected AlibabaCloudSearchServiceSettings mutateInstance(AlibabaCloudSearchServiceSettings instance) throws IOException { + return null; + } + + public static Map getServiceSettingsMap(String serviceId, String host, String workspaceName) { + var map = new HashMap(); + map.put(AlibabaCloudSearchServiceSettings.SERVICE_ID, serviceId); + map.put(AlibabaCloudSearchServiceSettings.HOST, host); + map.put(AlibabaCloudSearchServiceSettings.WORKSPACE_NAME, workspaceName); + return map; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java new file mode 100644 index 0000000000000..cc70b61226fe3 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java @@ -0,0 +1,172 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.alibabacloudsearch; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsServiceSettingsTests; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsTaskSettingsTests; +import org.hamcrest.MatcherAssert; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static 
org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.Mockito.mock; + +public class AlibabaCloudSearchServiceTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + } + + public void testParseRequestConfig_CreatesAnEmbeddingsModel() throws IOException { + try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + ActionListener modelVerificationListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class)); + + var embeddingsModel = (AlibabaCloudSearchEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("service_id")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getHost(), is("host")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getWorkspaceName(), is("default")); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + }, e -> fail("Model parsing should have succeeded " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + 
AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap("service_id", "host", "default"), + AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), + getSecretSettingsMap("secret") + ), + Set.of(), + modelVerificationListener + ); + } + } + + public void testCheckModelConfig() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool)) { + @Override + public void doInfer( + Model model, + List input, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener listener + ) { + InferenceTextEmbeddingFloatResults results = new InferenceTextEmbeddingFloatResults( + List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { -0.028680f, 0.022033f })) + ); + + listener.onResponse(results); + } + }) { + Map serviceSettingsMap = new HashMap<>(); + serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.SERVICE_ID, "service_id"); + serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.HOST, "host"); + serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.WORKSPACE_NAME, "default"); + serviceSettingsMap.put(ServiceFields.DIMENSIONS, 1536); + + Map taskSettingsMap = new HashMap<>(); + + Map secretSettingsMap = new HashMap<>(); + secretSettingsMap.put("api_key", "secret"); + + var model = AlibabaCloudSearchEmbeddingsModelTests.createModel( + "service", + TaskType.TEXT_EMBEDDING, + serviceSettingsMap, + taskSettingsMap, + secretSettingsMap + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.checkModelConfig(model, listener); + var result = listener.actionGet(TIMEOUT); + + Map expectedServiceSettingsMap = new HashMap<>(); + expectedServiceSettingsMap.put(AlibabaCloudSearchServiceSettings.SERVICE_ID, "service_id"); + expectedServiceSettingsMap.put(AlibabaCloudSearchServiceSettings.HOST, "host"); + 
expectedServiceSettingsMap.put(AlibabaCloudSearchServiceSettings.WORKSPACE_NAME, "default"); + expectedServiceSettingsMap.put(ServiceFields.SIMILARITY, "DOT_PRODUCT"); + expectedServiceSettingsMap.put(ServiceFields.DIMENSIONS, 2); + + Map expectedTaskSettingsMap = new HashMap<>(); + + Map expectedSecretSettingsMap = new HashMap<>(); + expectedSecretSettingsMap.put("api_key", "secret"); + + var expectedModel = AlibabaCloudSearchEmbeddingsModelTests.createModel( + "service", + TaskType.TEXT_EMBEDDING, + expectedServiceSettingsMap, + expectedTaskSettingsMap, + expectedSecretSettingsMap + ); + + MatcherAssert.assertThat(result, is(expectedModel)); + } + } + + private Map getRequestConfigMap( + Map serviceSettings, + Map taskSettings, + Map secretSettings + ) { + var builtServiceSettings = new HashMap<>(); + builtServiceSettings.putAll(serviceSettings); + builtServiceSettings.putAll(secretSettings); + + return new HashMap<>( + Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsModelTests.java new file mode 100644 index 0000000000000..fca0ee11e5c78 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsModelTests.java @@ -0,0 +1,71 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchUtils; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests; +import org.hamcrest.MatcherAssert; + +import java.util.Map; + +import static org.hamcrest.Matchers.is; + +public class AlibabaCloudSearchEmbeddingsModelTests extends ESTestCase { + public void testOverride() { + AlibabaCloudSearchEmbeddingsTaskSettings taskSettings = AlibabaCloudSearchEmbeddingsTaskSettingsTests.createRandom(); + var model = createModel( + "service", + TaskType.TEXT_EMBEDDING, + AlibabaCloudSearchEmbeddingsServiceSettingsTests.createRandom(), + taskSettings, + DefaultSecretSettingsTests.createRandom() + ); + + var overriddenModel = AlibabaCloudSearchEmbeddingsModel.of(model, Map.of(), taskSettings.getInputType()); + MatcherAssert.assertThat(overriddenModel, is(model)); + } + + public static AlibabaCloudSearchEmbeddingsModel createModel( + String modelId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + @Nullable Map secrets + ) { + return new AlibabaCloudSearchEmbeddingsModel( + modelId, + taskType, + AlibabaCloudSearchUtils.SERVICE_NAME, + serviceSettings, + taskSettings, + secrets, + null + ); + } + + public static AlibabaCloudSearchEmbeddingsModel createModel( + String modelId, + TaskType taskType, + AlibabaCloudSearchEmbeddingsServiceSettings serviceSettings, + AlibabaCloudSearchEmbeddingsTaskSettings taskSettings, + @Nullable DefaultSecretSettings secretSettings + ) { + return new AlibabaCloudSearchEmbeddingsModel( + modelId, + taskType, + AlibabaCloudSearchUtils.SERVICE_NAME, + serviceSettings, + taskSettings, + secretSettings + ); + } +} 
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsServiceSettingsTests.java new file mode 100644 index 0000000000000..815e6d0311195 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsServiceSettingsTests.java @@ -0,0 +1,96 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettingsTests; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.is; + +public class AlibabaCloudSearchEmbeddingsServiceSettingsTests extends AbstractWireSerializingTestCase< + AlibabaCloudSearchEmbeddingsServiceSettings> { + public static AlibabaCloudSearchEmbeddingsServiceSettings createRandom() { + var commonSettings = AlibabaCloudSearchServiceSettingsTests.createRandom(); + var similarity = SimilarityMeasure.DOT_PRODUCT; + var dims = 1536; + var maxInputTokens = 512; + return new 
AlibabaCloudSearchEmbeddingsServiceSettings(commonSettings, similarity, dims, maxInputTokens); + } + + public void testFromMap() { + var similarity = SimilarityMeasure.DOT_PRODUCT.toString(); + var dims = 1536; + var maxInputTokens = 512; + var model = "model"; + var host = "host"; + var workspaceName = "default"; + var httpSchema = "https"; + var serviceSettings = AlibabaCloudSearchEmbeddingsServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.SIMILARITY, + similarity, + ServiceFields.DIMENSIONS, + dims, + ServiceFields.MAX_INPUT_TOKENS, + maxInputTokens, + AlibabaCloudSearchServiceSettings.HOST, + host, + AlibabaCloudSearchServiceSettings.SERVICE_ID, + model, + AlibabaCloudSearchServiceSettings.WORKSPACE_NAME, + workspaceName, + AlibabaCloudSearchServiceSettings.HTTP_SCHEMA_NAME, + httpSchema + ) + ), + null + ); + + MatcherAssert.assertThat( + serviceSettings, + is( + new AlibabaCloudSearchEmbeddingsServiceSettings( + new AlibabaCloudSearchServiceSettings(model, host, workspaceName, httpSchema, null), + SimilarityMeasure.DOT_PRODUCT, + dims, + maxInputTokens + ) + ) + ); + } + + @Override + protected Writeable.Reader instanceReader() { + return AlibabaCloudSearchEmbeddingsServiceSettings::new; + } + + @Override + protected AlibabaCloudSearchEmbeddingsServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected AlibabaCloudSearchEmbeddingsServiceSettings mutateInstance(AlibabaCloudSearchEmbeddingsServiceSettings instance) + throws IOException { + return null; + } + + public static Map getServiceSettingsMap(String serviceId, String host, String workspaceName) { + return AlibabaCloudSearchServiceSettingsTests.getServiceSettingsMap(serviceId, host, workspaceName); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsTaskSettingsTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsTaskSettingsTests.java
new file mode 100644
index 0000000000000..244685d8e9833
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsTaskSettingsTests.java
@@ -0,0 +1,82 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.hamcrest.MatcherAssert;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.InputTypeTests.randomWithIngestAndSearch;
+import static org.hamcrest.Matchers.is;
+
+/**
+ * Wire-serialization and {@code fromMap} tests for {@link AlibabaCloudSearchEmbeddingsTaskSettings}.
+ */
+public class AlibabaCloudSearchEmbeddingsTaskSettingsTests extends AbstractWireSerializingTestCase<
+    AlibabaCloudSearchEmbeddingsTaskSettings> {
+    /** Returns task settings with either a {@code null} input type or one from {@code randomWithIngestAndSearch()}. */
+    public static AlibabaCloudSearchEmbeddingsTaskSettings createRandom() {
+        var inputType = randomBoolean() ? randomWithIngestAndSearch() : null;
+
+        return new AlibabaCloudSearchEmbeddingsTaskSettings(inputType);
+    }
+
+    /** A map holding {@code "ingest"} must produce settings equal to ones built with {@link InputType#INGEST}. */
+    public void testFromMap() {
+        MatcherAssert.assertThat(
+            AlibabaCloudSearchEmbeddingsTaskSettings.fromMap(
+                new HashMap<>(Map.of(AlibabaCloudSearchEmbeddingsTaskSettings.INPUT_TYPE, "ingest"))
+            ),
+            is(new AlibabaCloudSearchEmbeddingsTaskSettings(InputType.INGEST))
+        );
+    }
+
+    /** An empty map must produce settings equal to ones built with a {@code null} input type. */
+    public void testFromMap_WhenInputTypeIsNull() {
+        // A typed local (rather than a bare null literal) fixes which single-argument constructor is selected.
+        InputType inputType = null;
+        MatcherAssert.assertThat(
+            AlibabaCloudSearchEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of())),
+            is(new AlibabaCloudSearchEmbeddingsTaskSettings(inputType))
+        );
+    }
+
+    @Override
+    protected Writeable.Reader<AlibabaCloudSearchEmbeddingsTaskSettings> instanceReader() {
+        return AlibabaCloudSearchEmbeddingsTaskSettings::new;
+    }
+
+    @Override
+    protected AlibabaCloudSearchEmbeddingsTaskSettings createTestInstance() {
+        return createRandom();
+    }
+
+    /** Returns a freshly generated instance that is not equal to {@code instance}. */
+    @Override
+    protected AlibabaCloudSearchEmbeddingsTaskSettings mutateInstance(AlibabaCloudSearchEmbeddingsTaskSettings instance)
+        throws IOException {
+        return randomValueOtherThan(instance, AlibabaCloudSearchEmbeddingsTaskSettingsTests::createRandom);
+    }
+
+    /** Builds a task-settings map; the input type entry is omitted when {@code inputType} is {@code null}. */
+    public static Map<String, Object> getTaskSettingsMap(@Nullable InputType inputType) {
+        var map = new HashMap<String, Object>();
+
+        if (inputType != null) {
+            map.put(AlibabaCloudSearchEmbeddingsTaskSettings.INPUT_TYPE, inputType.toString());
+        }
+
+        return map;
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseModelTests.java
new file mode 100644
index 0000000000000..4e9179b66c36f
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseModelTests.java
@@ -0,0 +1,77 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse;
+
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchUtils;
+import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
+import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests;
+import org.hamcrest.MatcherAssert;
+
+import java.util.Map;
+
+import static org.hamcrest.Matchers.is;
+
+/**
+ * Tests for {@link AlibabaCloudSearchSparseModel}, plus factory helpers for building models in tests.
+ */
+public class AlibabaCloudSearchSparseModelTests extends ESTestCase {
+    /** Overriding with an empty task-settings map and the model's own input type must yield an equal model. */
+    public void testOverride() {
+        AlibabaCloudSearchSparseTaskSettings taskSettings = AlibabaCloudSearchSparseTaskSettingsTests.createRandom();
+        var model = createModel(
+            "service",
+            TaskType.SPARSE_EMBEDDING,
+            AlibabaCloudSearchSparseServiceSettingsTests.createRandom(),
+            taskSettings,
+            DefaultSecretSettingsTests.createRandom()
+        );
+
+        var overriddenModel = AlibabaCloudSearchSparseModel.of(model, Map.of(), taskSettings.getInputType());
+        MatcherAssert.assertThat(overriddenModel, is(model));
+    }
+
+    /** Builds a model from raw settings maps; the constructor's final argument is passed as {@code null}. */
+    public static AlibabaCloudSearchSparseModel createModel(
+        String modelId,
+        TaskType taskType,
+        Map<String, Object> serviceSettings,
+        Map<String, Object> taskSettings,
+        @Nullable Map<String, Object> secrets
+    ) {
+        return new AlibabaCloudSearchSparseModel(
+            modelId,
+            taskType,
+            AlibabaCloudSearchUtils.SERVICE_NAME,
+            serviceSettings,
+            taskSettings,
+            secrets,
+            null
+        );
+    }
+
+    /** Builds a model from already-constructed service, task and secret settings objects. */
+    public static AlibabaCloudSearchSparseModel createModel(
+        String modelId,
+        TaskType taskType,
+        AlibabaCloudSearchSparseServiceSettings serviceSettings,
+        AlibabaCloudSearchSparseTaskSettings taskSettings,
+        @Nullable DefaultSecretSettings secretSettings
+    ) {
+        return new AlibabaCloudSearchSparseModel(
+            modelId,
+            taskType,
+            AlibabaCloudSearchUtils.SERVICE_NAME,
+            serviceSettings,
+            taskSettings,
+            secretSettings
+        );
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseServiceSettingsTests.java
new file mode 100644
index 0000000000000..8dc635a52f06f
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseServiceSettingsTests.java
@@ -0,0 +1,84 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettingsTests;
+import org.hamcrest.MatcherAssert;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.is;
+
+/**
+ * Wire-serialization and {@code fromMap} tests for {@link AlibabaCloudSearchSparseServiceSettings}.
+ */
+public class AlibabaCloudSearchSparseServiceSettingsTests extends AbstractWireSerializingTestCase<AlibabaCloudSearchSparseServiceSettings> {
+    /** Wraps randomly generated common service settings in sparse service settings. */
+    public static AlibabaCloudSearchSparseServiceSettings createRandom() {
+        var commonSettings = AlibabaCloudSearchServiceSettingsTests.createRandom();
+        return new AlibabaCloudSearchSparseServiceSettings(commonSettings);
+    }
+
+    /** Parsing host, service id, workspace and HTTP schema must equal settings built directly from those values. */
+    public void testFromMap() {
+        var model = "model";
+        var host = "host";
+        var workspaceName = "default";
+        var httpSchema = "https";
+        var serviceSettings = AlibabaCloudSearchSparseServiceSettings.fromMap(
+            new HashMap<>(
+                Map.of(
+                    AlibabaCloudSearchServiceSettings.HOST,
+                    host,
+                    AlibabaCloudSearchServiceSettings.SERVICE_ID,
+                    model,
+                    AlibabaCloudSearchServiceSettings.WORKSPACE_NAME,
+                    workspaceName,
+                    AlibabaCloudSearchServiceSettings.HTTP_SCHEMA_NAME,
+                    httpSchema
+                )
+            ),
+            null
+        );
+
+        MatcherAssert.assertThat(
+            serviceSettings,
+            is(
+                new AlibabaCloudSearchSparseServiceSettings(
+                    new AlibabaCloudSearchServiceSettings(model, host, workspaceName, httpSchema, null)
+                )
+            )
+        );
+    }
+
+    @Override
+    protected Writeable.Reader<AlibabaCloudSearchSparseServiceSettings> instanceReader() {
+        return AlibabaCloudSearchSparseServiceSettings::new;
+    }
+
+    @Override
+    protected AlibabaCloudSearchSparseServiceSettings createTestInstance() {
+        return createRandom();
+    }
+
+    /** Returns a freshly generated instance that is not equal to {@code instance}. */
+    @Override
+    protected AlibabaCloudSearchSparseServiceSettings mutateInstance(AlibabaCloudSearchSparseServiceSettings instance) throws IOException {
+        return randomValueOtherThan(instance, AlibabaCloudSearchSparseServiceSettingsTests::createRandom);
+    }
+
+    /** Delegates to {@code AlibabaCloudSearchServiceSettingsTests.getServiceSettingsMap} with the same arguments. */
+    public static Map<String, Object> getServiceSettingsMap(String serviceId, String host, String workspaceName) {
+        return AlibabaCloudSearchServiceSettingsTests.getServiceSettingsMap(serviceId, host, workspaceName);
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseTaskSettingsTests.java
new file mode 100644
index 0000000000000..b16d96f9a081b
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseTaskSettingsTests.java
@@ -0,0 +1,82 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.hamcrest.MatcherAssert;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.InputTypeTests.randomWithIngestAndSearch;
+import static org.hamcrest.Matchers.is;
+
+/**
+ * Wire-serialization and {@code fromMap} tests for {@link AlibabaCloudSearchSparseTaskSettings}.
+ */
+public class AlibabaCloudSearchSparseTaskSettingsTests extends AbstractWireSerializingTestCase<AlibabaCloudSearchSparseTaskSettings> {
+    /** Returns settings with a random return-token flag and a {@code null} or {@code randomWithIngestAndSearch()} input type. */
+    public static AlibabaCloudSearchSparseTaskSettings createRandom() {
+        var inputType = randomBoolean() ? randomWithIngestAndSearch() : null;
+        var returnToken = randomBoolean();
+
+        return new AlibabaCloudSearchSparseTaskSettings(inputType, returnToken);
+    }
+
+    /** A map holding only {@code "ingest"} must equal settings built with {@link InputType#INGEST} and a null return token. */
+    public void testFromMap() {
+        MatcherAssert.assertThat(
+            AlibabaCloudSearchSparseTaskSettings.fromMap(new HashMap<>(Map.of(AlibabaCloudSearchSparseTaskSettings.INPUT_TYPE, "ingest"))),
+            is(new AlibabaCloudSearchSparseTaskSettings(InputType.INGEST, null))
+        );
+    }
+
+    /** An empty map must produce settings equal to ones built with a {@code null} input type and return token. */
+    public void testFromMap_WhenInputTypeIsNull() {
+        InputType inputType = null;
+        MatcherAssert.assertThat(
+            AlibabaCloudSearchSparseTaskSettings.fromMap(new HashMap<>(Map.of())),
+            is(new AlibabaCloudSearchSparseTaskSettings(inputType, null))
+        );
+    }
+
+    @Override
+    protected Writeable.Reader<AlibabaCloudSearchSparseTaskSettings> instanceReader() {
+        return AlibabaCloudSearchSparseTaskSettings::new;
+    }
+
+    @Override
+    protected AlibabaCloudSearchSparseTaskSettings createTestInstance() {
+        return createRandom();
+    }
+
+    /** Returns a freshly generated instance that is not equal to {@code instance}. */
+    @Override
+    protected AlibabaCloudSearchSparseTaskSettings mutateInstance(AlibabaCloudSearchSparseTaskSettings instance) throws IOException {
+        return randomValueOtherThan(instance, AlibabaCloudSearchSparseTaskSettingsTests::createRandom);
+    }
+
+    /** Builds a task-settings map; each entry is omitted when its argument is {@code null}. */
+    public static Map<String, Object> getTaskSettingsMap(@Nullable InputType inputType, @Nullable Boolean returnToken) {
+        var map = new HashMap<String, Object>();
+
+        if (inputType != null) {
+            map.put(AlibabaCloudSearchSparseTaskSettings.INPUT_TYPE, inputType.toString());
+        }
+
+        if (returnToken != null) {
+            map.put(AlibabaCloudSearchSparseTaskSettings.RETURN_TOKEN, returnToken);
+        }
+
+        return map;
+    }
+}