From 09ca56049f0a4450173ed07cd15394c2c59efa52 Mon Sep 17 00:00:00 2001 From: Huaixinww <141887897+Huaixinww@users.noreply.github.com> Date: Thu, 12 Sep 2024 21:14:25 +0800 Subject: [PATCH] [Inference API] Add Completion Inference API for Alibaba Cloud AI Search Model (#112512) --- docs/changelog/112512.yaml | 5 + .../InferenceNamedWriteablesProvider.java | 16 ++ .../AlibabaCloudSearchActionCreator.java | 8 + .../AlibabaCloudSearchActionVisitor.java | 3 + .../AlibabaCloudSearchCompletionAction.java | 88 +++++++++++ ...baCloudSearchCompletionRequestManager.java | 76 ++++++++++ .../AlibabaCloudSearchUtils.java | 1 + .../AlibabaCloudSearchCompletionRequest.java | 112 ++++++++++++++ ...abaCloudSearchCompletionRequestEntity.java | 61 ++++++++ ...baCloudSearchCompletionResponseEntity.java | 93 ++++++++++++ .../AlibabaCloudSearchService.java | 10 ++ .../AlibabaCloudSearchCompletionModel.java | 100 +++++++++++++ ...aCloudSearchCompletionServiceSettings.java | 97 +++++++++++++ ...babaCloudSearchCompletionTaskSettings.java | 137 ++++++++++++++++++ ...oudSearchCompletionRequestEntityTests.java | 81 +++++++++++ ...babaCloudSearchCompletionRequestTests.java | 64 ++++++++ ...udSearchCompletionResponseEntityTests.java | 53 +++++++ ...libabaCloudSearchCompletionModelTests.java | 70 +++++++++ ...dSearchCompletionServiceSettingsTests.java | 97 +++++++++++++ ...loudSearchCompletionTaskSettingsTests.java | 61 ++++++++ 20 files changed, 1233 insertions(+) create mode 100644 docs/changelog/112512.yaml create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchCompletionAction.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchCompletionRequestManager.java create mode 100644 
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/completion/AlibabaCloudSearchCompletionRequest.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/completion/AlibabaCloudSearchCompletionRequestEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchCompletionResponseEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionModel.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionServiceSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionTaskSettings.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchCompletionRequestEntityTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchCompletionRequestTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchCompletionResponseEntityTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionModelTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionServiceSettingsTests.java create mode 100644 
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionTaskSettingsTests.java diff --git a/docs/changelog/112512.yaml b/docs/changelog/112512.yaml new file mode 100644 index 0000000000000..a9812784ccfca --- /dev/null +++ b/docs/changelog/112512.yaml @@ -0,0 +1,5 @@ +pr: 112512 +summary: Add Completion Inference API for Alibaba Cloud AI Search Model +area: Machine Learning +type: enhancement +issues: [] diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 62c2c5fd61992..3f555a076b0a9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -27,6 +27,8 @@ import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionServiceSettings; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionTaskSettings; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsTaskSettings; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankServiceSettings; @@ -543,6 +545,20 @@ private static void addAlibabaCloudSearchNamedWriteables(List taskSettings) { + var 
overriddenModel = AlibabaCloudSearchCompletionModel.of(model, taskSettings); + + return new AlibabaCloudSearchCompletionAction(sender, overriddenModel, serviceComponents); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchActionVisitor.java index 69ae903c7b38f..b158ee4a780c0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchActionVisitor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchActionVisitor.java @@ -9,6 +9,7 @@ import org.elasticsearch.inference.InputType; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModel; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankModel; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel; @@ -21,4 +22,6 @@ public interface AlibabaCloudSearchActionVisitor { ExecutableAction create(AlibabaCloudSearchSparseModel model, Map taskSettings, InputType inputType); ExecutableAction create(AlibabaCloudSearchRerankModel model, Map taskSettings); + + ExecutableAction create(AlibabaCloudSearchCompletionModel model, Map taskSettings); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchCompletionAction.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchCompletionAction.java new file mode 100644 index 0000000000000..dc1d31f3e59df --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchCompletionAction.java @@ -0,0 +1,88 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.alibabacloudsearch; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount; +import org.elasticsearch.xpack.inference.external.http.sender.AlibabaCloudSearchCompletionRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModel; + +import java.util.Objects; + +import static org.elasticsearch.core.Strings.format; +import static 
org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException; + +public class AlibabaCloudSearchCompletionAction implements ExecutableAction { + private static final Logger logger = LogManager.getLogger(AlibabaCloudSearchCompletionAction.class); + + private final AlibabaCloudSearchAccount account; + private final AlibabaCloudSearchCompletionModel model; + private final String failedToSendRequestErrorMessage; + private final Sender sender; + private final AlibabaCloudSearchCompletionRequestManager requestCreator; + + public AlibabaCloudSearchCompletionAction(Sender sender, AlibabaCloudSearchCompletionModel model, ServiceComponents serviceComponents) { + this.model = Objects.requireNonNull(model); + this.sender = Objects.requireNonNull(sender); + this.account = new AlibabaCloudSearchAccount(this.model.getSecretSettings().apiKey()); + this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(null, "AlibabaCloud Search completion"); + this.requestCreator = AlibabaCloudSearchCompletionRequestManager.of(account, model, serviceComponents.threadPool()); + } + + @Override + public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener listener) { + if (inferenceInputs instanceof DocumentsOnlyInput == false) { + listener.onFailure( + new ElasticsearchStatusException( + format("Invalid inference input type, task type [%s] do not support Field [query]", TaskType.COMPLETION), + RestStatus.INTERNAL_SERVER_ERROR + ) + ); + return; + } + + var docsOnlyInput = (DocumentsOnlyInput) inferenceInputs; + if (docsOnlyInput.getInputs().size() % 2 == 0) { + listener.onFailure( + new ElasticsearchStatusException( + "Alibaba Completion's inputs must be an odd number. 
The last input is the current query, " + + "all preceding inputs are the completion history as pairs of user input and the assistant's response.", + RestStatus.BAD_REQUEST + ) + ); + return; + } + + try { + ActionListener wrappedListener = wrapFailuresInElasticsearchException( + failedToSendRequestErrorMessage, + listener + ); + sender.send(requestCreator, inferenceInputs, timeout, wrappedListener); + } catch (ElasticsearchException e) { + listener.onFailure(e); + } catch (Exception e) { + listener.onFailure(createInternalServerError(e, failedToSendRequestErrorMessage)); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchCompletionRequestManager.java new file mode 100644 index 0000000000000..a0a44e62f9f73 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchCompletionRequestManager.java @@ -0,0 +1,76 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount; +import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.completion.AlibabaCloudSearchCompletionRequest; +import org.elasticsearch.xpack.inference.external.response.alibabacloudsearch.AlibabaCloudSearchCompletionResponseEntity; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModel; + +import java.util.List; +import java.util.Objects; +import java.util.function.Supplier; + +public class AlibabaCloudSearchCompletionRequestManager extends AlibabaCloudSearchRequestManager { + private static final Logger logger = LogManager.getLogger(AlibabaCloudSearchCompletionRequestManager.class); + + private static final ResponseHandler HANDLER = createCompletionHandler(); + + private static ResponseHandler createCompletionHandler() { + return new AlibabaCloudSearchResponseHandler( + "alibaba cloud search completion", + AlibabaCloudSearchCompletionResponseEntity::fromResponse + ); + } + + public static AlibabaCloudSearchCompletionRequestManager of( + AlibabaCloudSearchAccount account, + AlibabaCloudSearchCompletionModel model, + ThreadPool threadPool + ) { + return new AlibabaCloudSearchCompletionRequestManager( + Objects.requireNonNull(account), + Objects.requireNonNull(model), + Objects.requireNonNull(threadPool) + ); + } + + private 
final AlibabaCloudSearchCompletionModel model; + + private final AlibabaCloudSearchAccount account; + + private AlibabaCloudSearchCompletionRequestManager( + AlibabaCloudSearchAccount account, + AlibabaCloudSearchCompletionModel model, + ThreadPool threadPool + ) { + super(threadPool, model); + this.account = Objects.requireNonNull(account); + this.model = Objects.requireNonNull(model); + } + + @Override + public void execute( + InferenceInputs inferenceInputs, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + List input = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + AlibabaCloudSearchCompletionRequest request = new AlibabaCloudSearchCompletionRequest(account, input, model); + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchUtils.java index 7d671471976f5..ba7c74848b3be 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchUtils.java @@ -15,4 +15,5 @@ public class AlibabaCloudSearchUtils { public static final String TEXT_EMBEDDING_PATH = "text-embedding"; public static final String SPARSE_EMBEDDING_PATH = "text-sparse-embedding"; public static final String RERANK_PATH = "ranker"; + public static final String COMPLETION_PATH = "text-generation"; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/completion/AlibabaCloudSearchCompletionRequest.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/completion/AlibabaCloudSearchCompletionRequest.java new file mode 100644 index 0000000000000..12c2574003083 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/completion/AlibabaCloudSearchCompletionRequest.java @@ -0,0 +1,112 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.completion; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.utils.URIBuilder; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchRequest; +import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchUtils; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModel; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionTaskSettings; + +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.buildUri; +import static 
org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; + +public class AlibabaCloudSearchCompletionRequest extends AlibabaCloudSearchRequest { + private final AlibabaCloudSearchAccount account; + private final List input; + private final URI uri; + private final AlibabaCloudSearchCompletionTaskSettings taskSettings; + private final String model; + private final String host; + private final String workspaceName; + private final String httpSchema; + private final String inferenceEntityId; + + public AlibabaCloudSearchCompletionRequest( + AlibabaCloudSearchAccount account, + List input, + AlibabaCloudSearchCompletionModel completionModel + ) { + Objects.requireNonNull(completionModel); + + this.account = Objects.requireNonNull(account); + this.input = Objects.requireNonNull(input); + taskSettings = completionModel.getTaskSettings(); + model = completionModel.getServiceSettings().getCommonSettings().modelId(); + host = completionModel.getServiceSettings().getCommonSettings().getHost(); + workspaceName = completionModel.getServiceSettings().getCommonSettings().getWorkspaceName(); + httpSchema = completionModel.getServiceSettings().getCommonSettings().getHttpSchema() != null + ? 
completionModel.getServiceSettings().getCommonSettings().getHttpSchema() + : "https"; + uri = buildUri(null, AlibabaCloudSearchUtils.SERVICE_NAME, this::buildDefaultUri); + inferenceEntityId = completionModel.getInferenceEntityId(); + } + + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(uri); + + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString(new AlibabaCloudSearchCompletionRequestEntity(input, taskSettings, model)).getBytes(StandardCharsets.UTF_8) + ); + httpPost.setEntity(byteEntity); + + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); + httpPost.setHeader(createAuthBearerHeader(account.apiKey())); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public String getInferenceEntityId() { + return inferenceEntityId; + } + + @Override + public URI getURI() { + return uri; + } + + @Override + public Request truncate() { + return this; + } + + @Override + public boolean[] getTruncationInfo() { + return null; + } + + URI buildDefaultUri() throws URISyntaxException { + return new URIBuilder().setScheme(httpSchema) + .setHost(host) + .setPathSegments( + AlibabaCloudSearchUtils.VERSION_3, + AlibabaCloudSearchUtils.OPENAPI_PATH, + AlibabaCloudSearchUtils.WORKSPACE_PATH, + workspaceName, + AlibabaCloudSearchUtils.COMPLETION_PATH, + model + ) + .build(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/completion/AlibabaCloudSearchCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/completion/AlibabaCloudSearchCompletionRequestEntity.java new file mode 100644 index 0000000000000..32e95f842440a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/completion/AlibabaCloudSearchCompletionRequestEntity.java @@ -0,0 +1,61 
@@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.completion; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionTaskSettings; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public record AlibabaCloudSearchCompletionRequestEntity( + List messages, + AlibabaCloudSearchCompletionTaskSettings taskSettings, + @Nullable String model +) implements ToXContentObject { + + private static final String MESSAGE = "messages"; + private static final String PARAMETERS = "parameters"; + private static final String ROLE_FIELD = "role"; + private static final String ROLE_USER = "user"; + private static final String ROLE_ASSISTANT = "assistant"; + private static final String CONTENT_FIELD = "content"; + + public AlibabaCloudSearchCompletionRequestEntity { + Objects.requireNonNull(messages); + Objects.requireNonNull(taskSettings); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + { + builder.startArray(MESSAGE); + { + for (int i = 0; i < messages.size(); i++) { + builder.startObject(); + { + String roleValue = i % 2 == 0 ? 
ROLE_USER : ROLE_ASSISTANT; + builder.field(ROLE_FIELD, roleValue); + builder.field(CONTENT_FIELD, messages.get(i)); + } + builder.endObject(); + } + } + builder.endArray(); + if (taskSettings.getParameters() != null) { + builder.field(PARAMETERS, taskSettings.getParameters()); + } + } + builder.endObject(); + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchCompletionResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchCompletionResponseEntity.java new file mode 100644 index 0000000000000..5ce5809519efe --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchCompletionResponseEntity.java @@ -0,0 +1,93 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.response.alibabacloudsearch; + +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; + +public class AlibabaCloudSearchCompletionResponseEntity extends AlibabaCloudSearchResponseEntity { + private static final String FAILED_TO_FIND_FIELD_TEMPLATE = + "Failed to find required field [%s] in AlibabaCloud Search completion response"; + + /** + * Parses the AlibabaCloud Search completion json response. + * For a request like: + * + *
+     * 
+     * {
+     *     "messages": [
+     *         {
+     *             "role": "system",
+     *             "content": "你是一个机器人助手"
+     *         },
+     *         {
+     *             "role": "user",
+     *             "content": "河南的省会是哪里"
+     *         },
+     *         {
+     *             "role": "assistant",
+     *             "content": "郑州"
+     *         },
+     *         {
+     *             "role": "user",
+     *             "content": "那里有什么好玩的"
+     *         }
+     *     ],
+     *     "stream": false
+     * }
+     * 
+     * 
+ * + * The response would look like: + * + *
+     * 
+     * {
+     *   "request_id": "450fcb80-f796-46c1-8d69-e1e86d29aa9f",
+     *   "latency": 564.903929,
+     *   "result": {
+     *     "text":"郑州是一个历史文化悠久且现代化的城市,有很多好玩的地方。以下是一些推荐的旅游景点:
+     *     嵩山少林寺:作为少林武术的发源地,嵩山少林寺一直以来都是游客向往的地方。在这里,你可以欣赏到精彩的武术表演,领略少林功夫的魅力。
+     *     黄河游览区:黄河是中华民族的母亲河,而在郑州,你可以乘坐游船观赏黄河的多种风情,感受大河之美。
+     *     郑州动物园:这是一个适合全家游玩的景点,拥有各种珍稀动物,如大熊猫、金丝猴等,让孩子们近距离接触动物,增长见识。
+     *     郑州博物馆:如果你对历史文化感兴趣,那么郑州博物馆是一个不错的选择。这里收藏了大量珍贵的文物,展示了郑州地区的历史变迁和文化传承。
+     *     郑州世纪公园:这是一个大型的城市公园,拥有美丽的湖泊、花园和休闲设施。在这里,你可以进行散步、慢跑等户外活动,享受大自然的宁静与和谐。
+     *     以上只是郑州众多好玩地方的一部分,实际上郑州还有很多其他值得一游的景点。希望你在郑州的旅行能够愉快!"
+     *   },
+     *   "usage": {
+     *       "output_tokens": 6320,
+     *       "input_tokens": 35,
+     *       "total_tokens": 6355
+     *   }
+     *
+     * }
+     * 
+     * 
+ */ + + public static ChatCompletionResults fromResponse(Request request, HttpResult response) throws IOException { + return fromResponse(request, response, jsonParser -> { + positionParserAtTokenAfterField(jsonParser, "text", FAILED_TO_FIND_FIELD_TEMPLATE); + + XContentParser.Token contentToken = jsonParser.currentToken(); + ensureExpectedToken(XContentParser.Token.VALUE_STRING, contentToken, jsonParser); + String content = jsonParser.text(); + + return new ChatCompletionResults(List.of(new ChatCompletionResults.Result(content))); + }); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index 0c48c99b4b81e..173f2bbf131b2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -33,6 +33,7 @@ import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModel; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankModel; @@ -149,6 +150,15 @@ private static AlibabaCloudSearchModel createModel( secretSettings, context ); + case COMPLETION -> new AlibabaCloudSearchCompletionModel( + 
inferenceEntityId, + taskType, + NAME, + serviceSettings, + taskSettings, + secretSettings, + context + ); default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); }; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionModel.java new file mode 100644 index 0000000000000..f5140981bfbe0 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionModel.java @@ -0,0 +1,100 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */
+
+package org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion;
+
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.external.action.alibabacloudsearch.AlibabaCloudSearchActionVisitor;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchModel;
+import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
+
+import java.util.Map;
+
+/**
+ * Model for the AlibabaCloudSearch completion task. It bundles
+ * {@link AlibabaCloudSearchCompletionServiceSettings}, {@link AlibabaCloudSearchCompletionTaskSettings}
+ * and {@link DefaultSecretSettings}.
+ */
+public class AlibabaCloudSearchCompletionModel extends AlibabaCloudSearchModel {
+    /**
+     * Creates a copy of {@code model} whose task settings are the stored task settings combined with
+     * the task settings parsed from {@code taskSettings}.
+     *
+     * @param model        the model to copy
+     * @param taskSettings the raw task settings map to parse and apply on top of the model's own
+     * @return a new model carrying the combined task settings
+     */
+    public static AlibabaCloudSearchCompletionModel of(AlibabaCloudSearchCompletionModel model, Map<String, Object> taskSettings) {
+        var requestTaskSettings = AlibabaCloudSearchCompletionTaskSettings.fromMap(taskSettings);
+        return new AlibabaCloudSearchCompletionModel(
+            model,
+            AlibabaCloudSearchCompletionTaskSettings.of(model.getTaskSettings(), requestTaskSettings)
+        );
+    }
+
+    /**
+     * Builds a model from unparsed settings maps; each map is parsed by the matching settings class.
+     *
+     * @param modelId         the identifier passed to {@link ModelConfigurations}
+     * @param taskType        the task type passed to {@link ModelConfigurations}
+     * @param service         the service name passed to {@link ModelConfigurations}
+     * @param serviceSettings raw service settings, parsed with {@code context}
+     * @param taskSettings    raw task settings
+     * @param secrets         raw secret settings, may be {@code null}
+     * @param context         the parse context forwarded to the service settings parser
+     */
+    public AlibabaCloudSearchCompletionModel(
+        String modelId,
+        TaskType taskType,
+        String service,
+        Map<String, Object> serviceSettings,
+        Map<String, Object> taskSettings,
+        @Nullable Map<String, Object> secrets,
+        ConfigurationParseContext context
+    ) {
+        this(
+            modelId,
+            taskType,
+            service,
+            AlibabaCloudSearchCompletionServiceSettings.fromMap(serviceSettings, context),
+            AlibabaCloudSearchCompletionTaskSettings.fromMap(taskSettings),
+            DefaultSecretSettings.fromMap(secrets)
+        );
+    }
+
+    // should only be used for testing
+    AlibabaCloudSearchCompletionModel(
+        String modelId,
+        TaskType taskType,
+        String service,
+        AlibabaCloudSearchCompletionServiceSettings serviceSettings,
+        AlibabaCloudSearchCompletionTaskSettings taskSettings,
+        @Nullable DefaultSecretSettings secretSettings
+    ) {
+        super(
+            new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings),
+            new ModelSecrets(secretSettings),
+            serviceSettings.getCommonSettings()
+        );
+    }
+
+    /** Copy constructor that swaps in new task settings; used by {@link #of}. */
+    private AlibabaCloudSearchCompletionModel(
+        AlibabaCloudSearchCompletionModel model,
+        AlibabaCloudSearchCompletionTaskSettings taskSettings
+    ) {
+        super(model, taskSettings);
+    }
+
+    /** Copy constructor that swaps in new service settings. */
+    public AlibabaCloudSearchCompletionModel(
+        AlibabaCloudSearchCompletionModel model,
+        AlibabaCloudSearchCompletionServiceSettings serviceSettings
+    ) {
+        super(model, serviceSettings);
+    }
+
+    @Override
+    public AlibabaCloudSearchCompletionServiceSettings getServiceSettings() {
+        return (AlibabaCloudSearchCompletionServiceSettings) super.getServiceSettings();
+    }
+
+    @Override
+    public AlibabaCloudSearchCompletionTaskSettings getTaskSettings() {
+        return (AlibabaCloudSearchCompletionTaskSettings) super.getTaskSettings();
+    }
+
+    @Override
+    public DefaultSecretSettings getSecretSettings() {
+        return (DefaultSecretSettings) super.getSecretSettings();
+    }
+
+    /**
+     * Asks {@code visitor} to create the action for this model. {@code inputType} is not used here.
+     */
+    @Override
+    public ExecutableAction accept(AlibabaCloudSearchActionVisitor visitor, Map<String, Object> taskSettings, InputType inputType) {
+        return visitor.create(this, taskSettings);
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionServiceSettings.java
new file mode 100644
index 0000000000000..631ec8a8648e8
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionServiceSettings.java
@@ -0,0 +1,97 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. 
Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.inference.ServiceSettings;
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * Service settings for the AlibabaCloudSearch completion task. This class is a thin wrapper around the
+ * shared {@link AlibabaCloudSearchServiceSettings}; it adds no completion-specific fields.
+ */
+public class AlibabaCloudSearchCompletionServiceSettings implements ServiceSettings {
+    public static final String NAME = "alibabacloud_search_completion_service_settings";
+
+    /**
+     * Parses the common AlibabaCloudSearch service settings out of {@code map} and wraps them.
+     *
+     * @param map     the raw service settings
+     * @param context the parse context forwarded to the common settings parser
+     * @return the parsed completion service settings
+     */
+    public static AlibabaCloudSearchCompletionServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
+        ValidationException validationException = new ValidationException();
+        var commonServiceSettings = AlibabaCloudSearchServiceSettings.fromMap(map, context);
+        // NOTE(review): nothing in this method adds an error to validationException, so as written this
+        // check cannot fire; any validation failure would have to come from the call above. Confirm.
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        return new AlibabaCloudSearchCompletionServiceSettings(commonServiceSettings);
+    }
+
+    private final AlibabaCloudSearchServiceSettings commonSettings;
+
+    public AlibabaCloudSearchCompletionServiceSettings(AlibabaCloudSearchServiceSettings commonSettings) {
+        this.commonSettings = commonSettings;
+    }
+
+    /** Reads the settings from the wire; the stream carries only the common settings. */
+    public AlibabaCloudSearchCompletionServiceSettings(StreamInput in) throws IOException {
+        commonSettings = new AlibabaCloudSearchServiceSettings(in);
+    }
+
+    public AlibabaCloudSearchServiceSettings getCommonSettings() {
+        return commonSettings;
+    }
+
+    @Override
+    public String modelId() {
+        return commonSettings.modelId();
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    /** Writes the common settings' fields inside a single enclosing object. */
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        commonSettings.toXContentFragment(builder, params);
+        builder.endObject();
+        return builder;
+    }
+
+    /** Returns this object unchanged; no fields are filtered out here. */
+    @Override
+    public ToXContentObject getFilteredXContentObject() {
+        return this;
+    }
+
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        return TransportVersions.ML_INFERENCE_ALIBABACLOUD_SEARCH_ADDED;
+    }
+
+    /** Mirror of the {@link StreamInput} constructor: only the common settings are written. */
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        commonSettings.writeTo(out);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        AlibabaCloudSearchCompletionServiceSettings that = (AlibabaCloudSearchCompletionServiceSettings) o;
+        return Objects.equals(commonSettings, that.commonSettings);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(commonSettings);
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionTaskSettings.java
new file mode 100644
index 0000000000000..6fb726e60835a
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionTaskSettings.java
@@ -0,0 +1,137 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.TaskSettings;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ServiceUtils;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * Defines the task settings for the AlibabaCloudSearch completion service.
+ *
+ * <p>
+ * The only setting is an optional {@code parameters} map.
+ * See the AlibabaCloudSearch api docs for details.
+ * </p>
+ */
+public class AlibabaCloudSearchCompletionTaskSettings implements TaskSettings {
+    public static final String NAME = "alibabacloud_search_completion_task_settings";
+    public static final String PARAMETERS = "parameters";
+
+    /** Shared instance with no parameters; returned by {@link #fromMap} for a null or empty map. */
+    static final AlibabaCloudSearchCompletionTaskSettings EMPTY_SETTINGS = new AlibabaCloudSearchCompletionTaskSettings(
+        (Map<String, Object>) null
+    );
+
+    /**
+     * Parses task settings from a raw map, removing the {@value #PARAMETERS} entry from it.
+     *
+     * @param map the raw task settings, may be {@code null} or empty
+     * @return {@link #EMPTY_SETTINGS} for a null or empty map, otherwise settings wrapping the parsed parameters
+     * @throws ValidationException if extracting {@value #PARAMETERS} recorded a validation error
+     */
+    @SuppressWarnings("unchecked")
+    public static AlibabaCloudSearchCompletionTaskSettings fromMap(Map<String, Object> map) {
+        if (map == null || map.isEmpty()) {
+            return EMPTY_SETTINGS;
+        }
+
+        ValidationException validationException = new ValidationException();
+        Map<String, Object> parameters = ServiceUtils.removeAsType(map, PARAMETERS, Map.class, validationException);
+
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        return of(parameters);
+    }
+
+    /**
+     * Creates a new {@link AlibabaCloudSearchCompletionTaskSettings} by combining the stored and the request settings.
+     * When both carry parameters the two maps are merged and the request value wins on a key clash;
+     * otherwise whichever parameters are present are used. Either argument may be {@code null}, which is
+     * treated the same as settings without parameters.
+     *
+     * @param originalSettings the settings stored as part of the inference entity configuration
+     * @param requestTaskSettings the settings passed in within the task_settings field of the request
+     * @return a constructed {@link AlibabaCloudSearchCompletionTaskSettings}
+     */
+    public static AlibabaCloudSearchCompletionTaskSettings of(
+        @Nullable AlibabaCloudSearchCompletionTaskSettings originalSettings,
+        @Nullable AlibabaCloudSearchCompletionTaskSettings requestTaskSettings
+    ) {
+        Map<String, Object> original = originalSettings == null ? null : originalSettings.parameters;
+        Map<String, Object> requested = requestTaskSettings == null ? null : requestTaskSettings.parameters;
+
+        if (original == null) {
+            return new AlibabaCloudSearchCompletionTaskSettings(requested);
+        }
+        if (requested == null) {
+            return new AlibabaCloudSearchCompletionTaskSettings(original);
+        }
+
+        // putAll rather than Map.merge: merge rejects null values, and a request may carry an explicit null.
+        var merged = new HashMap<>(original);
+        merged.putAll(requested);
+        return new AlibabaCloudSearchCompletionTaskSettings(merged);
+    }
+
+    public static AlibabaCloudSearchCompletionTaskSettings of(Map<String, Object> parameters) {
+        return new AlibabaCloudSearchCompletionTaskSettings(parameters);
+    }
+
+    private final Map<String, Object> parameters;
+
+    /** Reads the settings from the wire as a single generic map. */
+    public AlibabaCloudSearchCompletionTaskSettings(StreamInput in) throws IOException {
+        this(in.readGenericMap());
+    }
+
+    public AlibabaCloudSearchCompletionTaskSettings(@Nullable Map<String, Object> parameters) {
+        this.parameters = parameters;
+    }
+
+    /** Writes an object that contains {@value #PARAMETERS} only when parameters are set. */
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        if (parameters != null) {
+            builder.field(PARAMETERS, parameters);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        return TransportVersions.ML_INFERENCE_ALIBABACLOUD_SEARCH_ADDED;
+    }
+
+    /** Mirror of the {@link StreamInput} constructor: the parameters are written as one generic map. */
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeGenericMap(parameters);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        AlibabaCloudSearchCompletionTaskSettings that = (AlibabaCloudSearchCompletionTaskSettings) o;
+        return Objects.equals(parameters, that.parameters);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(parameters);
+    }
+
+    /** @return the parameters map, or {@code null} when none were configured */
+    public Map<String, Object> getParameters() {
+        return parameters;
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchCompletionRequestEntityTests.java
new file mode 100644
index 0000000000000..5603218b1205a
--- /dev/null
+++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchCompletionRequestEntityTests.java
@@ -0,0 +1,81 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.completion.AlibabaCloudSearchCompletionRequestEntity;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionTaskSettings;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString;
+
+/**
+ * Checks the JSON body produced by {@link AlibabaCloudSearchCompletionRequestEntity}.
+ */
+public class AlibabaCloudSearchCompletionRequestEntityTests extends ESTestCase {
+    /** A single input is written as one message with the "user" role. */
+    public void testToXContent_WritesSingleMessage() throws IOException {
+        String json = renderAsJson(List.of("input"));
+
+        assertThat(json, equalToIgnoringWhitespaceInJsonString("""
+            {
+                "messages": [
+                    {
+                        "role": "user",
+                        "content": "input"
+                    }
+                ]
+            }
+            """));
+    }
+
+    /** Several inputs are written as messages whose roles alternate between "user" and "assistant". */
+    public void testToXContent_WritesMultiMessages() throws IOException {
+        String json = renderAsJson(List.of("question1", "answer1", "question2", "answer2", "question3"));
+
+        assertThat(json, equalToIgnoringWhitespaceInJsonString("""
+            {
+                "messages": [
+                    {
+                        "role": "user",
+                        "content": "question1"
+                    },
+                    {
+                        "role": "assistant",
+                        "content": "answer1"
+                    },
+                    {
+                        "role": "user",
+                        "content": "question2"
+                    },
+                    {
+                        "role": "assistant",
+                        "content": "answer2"
+                    },
+                    {
+                        "role": "user",
+                        "content": "question3"
+                    }
+                ]
+            }
+            """));
+    }
+
+    /** Builds a request entity for {@code inputs} with null task parameters and serializes it to a JSON string. */
+    private static String renderAsJson(List<String> inputs) throws IOException {
+        var requestEntity = new AlibabaCloudSearchCompletionRequestEntity(inputs, AlibabaCloudSearchCompletionTaskSettings.of(null), null);
+        XContentBuilder jsonBuilder = XContentFactory.contentBuilder(XContentType.JSON);
+        requestEntity.toXContent(jsonBuilder, null);
+        return Strings.toString(jsonBuilder);
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchCompletionRequestTests.java
new file mode 100644
index 0000000000000..0584f44f08145
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchCompletionRequestTests.java
@@ -0,0 +1,64 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch;
+
+import org.apache.http.HttpHeaders;
+import org.apache.http.client.methods.HttpPost;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount;
+import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.completion.AlibabaCloudSearchCompletionRequest;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModel;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModelTests;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionServiceSettingsTests;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionTaskSettingsTests;
+import org.hamcrest.MatcherAssert;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
+import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.is;
+
+/**
+ * Checks the HTTP request built by {@link AlibabaCloudSearchCompletionRequest}.
+ */
+public class AlibabaCloudSearchCompletionRequestTests extends ESTestCase {
+    /** Verifies the HTTP method, target URI, content-type and authorization headers, and the JSON body. */
+    public void testCreateRequest() throws IOException {
+        var request = createRequest(
+            List.of("query"),
+            AlibabaCloudSearchCompletionModelTests.createModel(
+                "completion_test",
+                TaskType.COMPLETION,
+                AlibabaCloudSearchCompletionServiceSettingsTests.getServiceSettingsMap("completion_test", "host", "default"),
+                AlibabaCloudSearchCompletionTaskSettingsTests.getTaskSettingsMap(null),
+                getSecretSettingsMap("secret")
+            )
+        );
+
+        var httpRequest = request.createHttpRequest();
+        assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
+
+        var httpPost = (HttpPost) httpRequest.httpRequestBase();
+        MatcherAssert.assertThat(
+            httpPost.getURI().toString(),
+            is("https://host/v3/openapi/workspaces/default/text-generation/completion_test")
+        );
+        MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
+        MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
+
+        var requestMap = entityAsMap(httpPost.getEntity().getContent());
+        MatcherAssert.assertThat(requestMap, is(Map.of("messages", List.of(Map.of("role", "user", "content", "query")))));
+    }
+
+    /**
+     * Builds a completion request for {@code input}, using an account created from the model's API key.
+     *
+     * @param input the input strings for the request
+     * @param model the model supplying the secret and the settings
+     * @return the request under test
+     */
+    public static AlibabaCloudSearchCompletionRequest createRequest(List<String> input, AlibabaCloudSearchCompletionModel model) {
+        var account = new AlibabaCloudSearchAccount(model.getSecretSettings().apiKey());
+        return new AlibabaCloudSearchCompletionRequest(account, input, model);
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchCompletionResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchCompletionResponseEntityTests.java
new file mode 100644
index 0000000000000..23b1569b76f38
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/alibabacloudsearch/AlibabaCloudSearchCompletionResponseEntityTests.java
@@ -0,0 +1,53 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.response.alibabacloudsearch;
+
+import org.apache.http.HttpResponse;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchRequest;
+
+import java.io.IOException;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.nio.charset.StandardCharsets;
+
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * Checks that {@link AlibabaCloudSearchCompletionResponseEntity} extracts the generated text from a response body.
+ */
+public class AlibabaCloudSearchCompletionResponseEntityTests extends ESTestCase {
+    /** A response with a {@code result.text} field yields exactly one result holding that text. */
+    public void testFromResponse_CreatesResponseEntityForText() throws IOException, URISyntaxException {
+        // The fixture must be well-formed JSON: members are comma-separated and there is no trailing comma.
+        String responseJson = """
+            {
+                "request_id": "450fcb80-f796-****-8d69-e1e86d29aa9f",
+                "latency": 564.903929,
+                "result": {
+                    "text":"result"
+                },
+                "usage": {
+                    "output_tokens": 6320,
+                    "input_tokens": 35,
+                    "total_tokens": 6355
+                }
+            }
+            """;
+
+        AlibabaCloudSearchRequest request = mock(AlibabaCloudSearchRequest.class);
+        URI uri = new URI("mock_uri");
+        when(request.getURI()).thenReturn(uri);
+
+        ChatCompletionResults chatCompletionResults = AlibabaCloudSearchCompletionResponseEntity.fromResponse(
+            request,
+            new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
+        );
+        assertThat(chatCompletionResults.getResults().size(), is(1));
+        assertThat(chatCompletionResults.getResults().get(0).content(), is("result"));
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionModelTests.java
new file mode 100644
index 
0000000000000..57218a5cf45a9
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionModelTests.java
@@ -0,0 +1,70 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion;
+
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchUtils;
+import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
+import org.hamcrest.MatcherAssert;
+
+import java.util.Map;
+
+import static org.hamcrest.Matchers.is;
+
+/**
+ * Tests for {@link AlibabaCloudSearchCompletionModel}, plus factory helpers reused by other test classes.
+ */
+public class AlibabaCloudSearchCompletionModelTests extends ESTestCase {
+    /** Overriding with an empty task settings map must produce a model equal to the original. */
+    public void testOverride() {
+        AlibabaCloudSearchCompletionTaskSettings taskSettings = AlibabaCloudSearchCompletionTaskSettingsTests.createRandom();
+        var model = createModel(
+            "service",
+            TaskType.COMPLETION,
+            AlibabaCloudSearchCompletionServiceSettingsTests.createRandom(),
+            taskSettings,
+            null
+        );
+
+        var overriddenModel = AlibabaCloudSearchCompletionModel.of(model, Map.of());
+        MatcherAssert.assertThat(overriddenModel, is(model));
+    }
+
+    /**
+     * Creates a model from raw settings maps, using {@link AlibabaCloudSearchUtils#SERVICE_NAME} as the
+     * service and a {@code null} parse context.
+     */
+    public static AlibabaCloudSearchCompletionModel createModel(
+        String modelId,
+        TaskType taskType,
+        Map<String, Object> serviceSettings,
+        Map<String, Object> taskSettings,
+        @Nullable Map<String, Object> secrets
+    ) {
+        return new AlibabaCloudSearchCompletionModel(
+            modelId,
+            taskType,
+            AlibabaCloudSearchUtils.SERVICE_NAME,
+            serviceSettings,
+            taskSettings,
+            secrets,
+            null
+        );
+    }
+
+    /** Creates a model from already-parsed settings, using {@link AlibabaCloudSearchUtils#SERVICE_NAME} as the service. */
+    public static AlibabaCloudSearchCompletionModel createModel(
+        String modelId,
+        TaskType taskType,
+        AlibabaCloudSearchCompletionServiceSettings serviceSettings,
+        AlibabaCloudSearchCompletionTaskSettings taskSettings,
+        @Nullable DefaultSecretSettings secretSettings
+    ) {
+        return new AlibabaCloudSearchCompletionModel(
+            modelId,
+            taskType,
+            AlibabaCloudSearchUtils.SERVICE_NAME,
+            serviceSettings,
+            taskSettings,
+            secretSettings
+        );
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionServiceSettingsTests.java
new file mode 100644
index 0000000000000..3a5f56c0b4247
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionServiceSettingsTests.java
@@ -0,0 +1,97 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.inference.SimilarityMeasure;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.inference.services.ServiceFields;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettingsTests;
+import org.hamcrest.MatcherAssert;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.is;
+
+/**
+ * Parsing and wire-serialization tests for {@link AlibabaCloudSearchCompletionServiceSettings}.
+ */
+public class AlibabaCloudSearchCompletionServiceSettingsTests extends AbstractWireSerializingTestCase<
+    AlibabaCloudSearchCompletionServiceSettings> {
+    /** Wraps randomly generated common settings. */
+    public static AlibabaCloudSearchCompletionServiceSettings createRandom() {
+        var commonSettings = AlibabaCloudSearchServiceSettingsTests.createRandom();
+        return new AlibabaCloudSearchCompletionServiceSettings(commonSettings);
+    }
+
+    /** Only service id, host, workspace and http schema are expected in the parsed result. */
+    public void testFromMap() {
+        var url = "https://www.abc.com";
+        var similarity = SimilarityMeasure.DOT_PRODUCT.toString();
+        var dims = 1536;
+        var maxInputTokens = 512;
+        var model = "model";
+        var host = "host";
+        var workspaceName = "default";
+        var httpSchema = "https";
+        var serviceSettings = AlibabaCloudSearchCompletionServiceSettings.fromMap(
+            new HashMap<>(
+                Map.of(
+                    ServiceFields.URL,
+                    url,
+                    ServiceFields.SIMILARITY,
+                    similarity,
+                    ServiceFields.DIMENSIONS,
+                    dims,
+                    ServiceFields.MAX_INPUT_TOKENS,
+                    maxInputTokens,
+                    AlibabaCloudSearchServiceSettings.HOST,
+                    host,
+                    AlibabaCloudSearchServiceSettings.SERVICE_ID,
+                    model,
+                    AlibabaCloudSearchServiceSettings.WORKSPACE_NAME,
+                    workspaceName,
+                    AlibabaCloudSearchServiceSettings.HTTP_SCHEMA_NAME,
+                    httpSchema
+                )
+            ),
+            null
+        );
+
+        MatcherAssert.assertThat(
+            serviceSettings,
+            is(
+                new AlibabaCloudSearchCompletionServiceSettings(
+                    new AlibabaCloudSearchServiceSettings(model, host, workspaceName, httpSchema, null)
+                )
+            )
+        );
+    }
+
+    @Override
+    protected Writeable.Reader<AlibabaCloudSearchCompletionServiceSettings> instanceReader() {
+        return AlibabaCloudSearchCompletionServiceSettings::new;
+    }
+
+    @Override
+    protected AlibabaCloudSearchCompletionServiceSettings createTestInstance() {
+        return createRandom();
+    }
+
+    /** Returns a random instance guaranteed to differ from {@code instance}, as the mutation contract requires. */
+    @Override
+    protected AlibabaCloudSearchCompletionServiceSettings mutateInstance(AlibabaCloudSearchCompletionServiceSettings instance)
+        throws IOException {
+        return randomValueOtherThan(instance, AlibabaCloudSearchCompletionServiceSettingsTests::createRandom);
+    }
+
+    /** Builds a mutable raw service settings map holding the service id, host and workspace name. */
+    public static Map<String, Object> getServiceSettingsMap(String serviceId, String host, String workspaceName) {
+        var map = new HashMap<String, Object>();
+        map.put(AlibabaCloudSearchServiceSettings.SERVICE_ID, serviceId);
+        map.put(AlibabaCloudSearchServiceSettings.HOST, host);
+        map.put(AlibabaCloudSearchServiceSettings.WORKSPACE_NAME, workspaceName);
+        return map;
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionTaskSettingsTests.java
new file mode 100644
index 0000000000000..63fdb38b33df2
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionTaskSettingsTests.java
@@ -0,0 +1,61 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.hamcrest.MatcherAssert;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.is;
+
+/**
+ * Parsing and wire-serialization tests for {@link AlibabaCloudSearchCompletionTaskSettings}.
+ */
+public class AlibabaCloudSearchCompletionTaskSettingsTests extends AbstractWireSerializingTestCase<
+    AlibabaCloudSearchCompletionTaskSettings> {
+    /** Returns settings whose parameters are randomly either an empty map or {@code null}. */
+    public static AlibabaCloudSearchCompletionTaskSettings createRandom() {
+        Map<String, Object> parameters = randomBoolean() ? Map.of() : null;
+
+        return new AlibabaCloudSearchCompletionTaskSettings(parameters);
+    }
+
+    /** An empty map parses to settings without parameters. */
+    public void testFromMap() {
+        MatcherAssert.assertThat(
+            AlibabaCloudSearchCompletionTaskSettings.fromMap(Map.of()),
+            is(new AlibabaCloudSearchCompletionTaskSettings((Map<String, Object>) null))
+        );
+    }
+
+    @Override
+    protected Writeable.Reader<AlibabaCloudSearchCompletionTaskSettings> instanceReader() {
+        return AlibabaCloudSearchCompletionTaskSettings::new;
+    }
+
+    @Override
+    protected AlibabaCloudSearchCompletionTaskSettings createTestInstance() {
+        return createRandom();
+    }
+
+    /**
+     * Returns an instance that is not equal to {@code instance}: settings without parameters become
+     * settings with an empty map, and settings with any map become settings without parameters.
+     */
+    @Override
+    protected AlibabaCloudSearchCompletionTaskSettings mutateInstance(AlibabaCloudSearchCompletionTaskSettings instance)
+        throws IOException {
+        // A typed local keeps the constructor call unambiguous between the Map and StreamInput overloads.
+        Map<String, Object> mutatedParameters = instance.getParameters() == null ? Map.of() : null;
+        return new AlibabaCloudSearchCompletionTaskSettings(mutatedParameters);
+    }
+
+    /** Builds a mutable raw task settings map; the parameters entry is added only when {@code params} is non-null. */
+    public static Map<String, Object> getTaskSettingsMap(@Nullable Map<String, Object> params) {
+        var map = new HashMap<String, Object>();
+
+        if (params != null) {
+            map.put(AlibabaCloudSearchCompletionTaskSettings.PARAMETERS, params);
+        }
+
+        return map;
+    }
+}