Skip to content
5 changes: 5 additions & 0 deletions docs/changelog/112512.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 112512
summary: Add Completion Inference API for Alibaba Cloud AI Search Model
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionTaskSettings;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankServiceSettings;
Expand Down Expand Up @@ -543,6 +545,20 @@ private static void addAlibabaCloudSearchNamedWriteables(List<NamedWriteableRegi
AlibabaCloudSearchRerankTaskSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
AlibabaCloudSearchCompletionServiceSettings.NAME,
AlibabaCloudSearchCompletionServiceSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
TaskSettings.class,
AlibabaCloudSearchCompletionTaskSettings.NAME,
AlibabaCloudSearchCompletionTaskSettings::new
)
);

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModel;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankModel;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel;
Expand Down Expand Up @@ -50,4 +51,11 @@ public ExecutableAction create(AlibabaCloudSearchRerankModel model, Map<String,

return new AlibabaCloudSearchRerankAction(sender, overriddenModel, serviceComponents);
}

/**
 * Builds the completion action for the given model.
 * <p>
 * The model handed to the action is the one returned by
 * {@link AlibabaCloudSearchCompletionModel#of}, called with the supplied model and task settings.
 *
 * @param model        the configured completion model
 * @param taskSettings task settings passed to {@code AlibabaCloudSearchCompletionModel.of} together with the model
 * @return a new {@link AlibabaCloudSearchCompletionAction} using this creator's sender and service components
 */
@Override
public ExecutableAction create(AlibabaCloudSearchCompletionModel model, Map<String, Object> taskSettings) {
    final AlibabaCloudSearchCompletionModel modelForRequest = AlibabaCloudSearchCompletionModel.of(model, taskSettings);
    return new AlibabaCloudSearchCompletionAction(sender, modelForRequest, serviceComponents);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.elasticsearch.inference.InputType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModel;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankModel;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel;
Expand All @@ -21,4 +22,6 @@ public interface AlibabaCloudSearchActionVisitor {
ExecutableAction create(AlibabaCloudSearchSparseModel model, Map<String, Object> taskSettings, InputType inputType);

ExecutableAction create(AlibabaCloudSearchRerankModel model, Map<String, Object> taskSettings);

/**
 * Creates the {@link ExecutableAction} for an AlibabaCloud Search completion model.
 *
 * @param model the completion model the action is created for
 * @param taskSettings task settings supplied together with the model
 * @return the action created for the model
 */
ExecutableAction create(AlibabaCloudSearchCompletionModel model, Map<String, Object> taskSettings);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.action.alibabacloudsearch;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount;
import org.elasticsearch.xpack.inference.external.http.sender.AlibabaCloudSearchCompletionRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModel;

import java.util.Objects;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;

/**
 * {@link ExecutableAction} that sends completion requests for an {@link AlibabaCloudSearchCompletionModel}
 * through a {@link Sender}.
 * <p>
 * Inputs are validated before anything is sent: they must be a {@link DocumentsOnlyInput} holding an odd
 * number of entries. Validation and send failures are reported to the listener.
 */
public class AlibabaCloudSearchCompletionAction implements ExecutableAction {
    private static final Logger logger = LogManager.getLogger(AlibabaCloudSearchCompletionAction.class);

    private final AlibabaCloudSearchCompletionModel model;
    private final Sender sender;
    private final AlibabaCloudSearchAccount account;
    private final String failedToSendRequestErrorMessage;
    private final AlibabaCloudSearchCompletionRequestManager requestCreator;

    /**
     * @param sender            used to dispatch the request; must not be null
     * @param model             completion model, also the source of the API key; must not be null
     * @param serviceComponents supplies the thread pool handed to the request manager
     */
    public AlibabaCloudSearchCompletionAction(Sender sender, AlibabaCloudSearchCompletionModel model, ServiceComponents serviceComponents) {
        // Null checks keep their original order: model first, then sender.
        this.model = Objects.requireNonNull(model);
        this.sender = Objects.requireNonNull(sender);

        this.account = new AlibabaCloudSearchAccount(this.model.getSecretSettings().apiKey());
        this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(null, "AlibabaCloud Search completion");
        this.requestCreator = AlibabaCloudSearchCompletionRequestManager.of(this.account, model, serviceComponents.threadPool());
    }

    /**
     * Validates the inputs and, when they are acceptable, sends them with the configured request manager.
     *
     * @param inferenceInputs inputs to run completion on; anything other than a {@link DocumentsOnlyInput} is rejected
     * @param timeout         timeout forwarded to the sender
     * @param listener        receives the results or the failure
     */
    @Override
    public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
        if ((inferenceInputs instanceof DocumentsOnlyInput) == false) {
            listener.onFailure(unsupportedInputTypeFailure());
            return;
        }

        final int inputCount = ((DocumentsOnlyInput) inferenceInputs).getInputs().size();
        if (inputCount % 2 == 0) {
            // An even count (including zero) cannot end with a standalone current query.
            listener.onFailure(evenInputCountFailure());
            return;
        }

        try {
            sender.send(
                requestCreator,
                inferenceInputs,
                timeout,
                wrapFailuresInElasticsearchException(failedToSendRequestErrorMessage, listener)
            );
        } catch (ElasticsearchException e) {
            // Already an Elasticsearch exception: hand it over unchanged.
            listener.onFailure(e);
        } catch (Exception e) {
            listener.onFailure(createInternalServerError(e, failedToSendRequestErrorMessage));
        }
    }

    /** Failure reported when the inputs are not a {@link DocumentsOnlyInput}. */
    private static ElasticsearchStatusException unsupportedInputTypeFailure() {
        return new ElasticsearchStatusException(
            format("Invalid inference input type, task type [%s] do not support Field [query]", TaskType.COMPLETION),
            RestStatus.INTERNAL_SERVER_ERROR
        );
    }

    /** Failure reported when the number of inputs is even. */
    private static ElasticsearchStatusException evenInputCountFailure() {
        return new ElasticsearchStatusException(
            "Alibaba Completion's inputs must be an odd number. The last input is the current query, "
                + "all preceding inputs are the completion history as pairs of user input and the assistant's response.",
            RestStatus.BAD_REQUEST
        );
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount;
import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.completion.AlibabaCloudSearchCompletionRequest;
import org.elasticsearch.xpack.inference.external.response.alibabacloudsearch.AlibabaCloudSearchCompletionResponseEntity;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModel;

import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;

/**
 * Request manager for AlibabaCloud Search completion. It turns inference inputs into an
 * {@link AlibabaCloudSearchCompletionRequest} and passes it on for execution together with the
 * completion response handler.
 */
public class AlibabaCloudSearchCompletionRequestManager extends AlibabaCloudSearchRequestManager {
    private static final Logger logger = LogManager.getLogger(AlibabaCloudSearchCompletionRequestManager.class);

    // One handler instance for every manager; it is built with AlibabaCloudSearchCompletionResponseEntity::fromResponse.
    private static final ResponseHandler HANDLER = new AlibabaCloudSearchResponseHandler(
        "alibaba cloud search completion",
        AlibabaCloudSearchCompletionResponseEntity::fromResponse
    );

    private final AlibabaCloudSearchCompletionModel model;

    private final AlibabaCloudSearchAccount account;

    /**
     * Creates a request manager after checking that none of the arguments is null.
     *
     * @param account    account used when building requests
     * @param model      completion model used when building requests
     * @param threadPool thread pool passed to the parent request manager
     * @return a new request manager
     * @throws NullPointerException if any argument is null
     */
    public static AlibabaCloudSearchCompletionRequestManager of(
        AlibabaCloudSearchAccount account,
        AlibabaCloudSearchCompletionModel model,
        ThreadPool threadPool
    ) {
        Objects.requireNonNull(account);
        Objects.requireNonNull(model);
        Objects.requireNonNull(threadPool);
        return new AlibabaCloudSearchCompletionRequestManager(account, model, threadPool);
    }

    private AlibabaCloudSearchCompletionRequestManager(
        AlibabaCloudSearchAccount account,
        AlibabaCloudSearchCompletionModel model,
        ThreadPool threadPool
    ) {
        super(threadPool, model);
        this.account = Objects.requireNonNull(account);
        this.model = Objects.requireNonNull(model);
    }

    /**
     * Builds a completion request from the inputs and executes it as an {@link ExecutableInferenceRequest}.
     *
     * @param inferenceInputs             inputs converted with {@link DocumentsOnlyInput#of}
     * @param requestSender               sender given to the executable request
     * @param hasRequestCompletedFunction completion check given to the executable request
     * @param listener                    listener given to the executable request
     */
    @Override
    public void execute(
        InferenceInputs inferenceInputs,
        RequestSender requestSender,
        Supplier<Boolean> hasRequestCompletedFunction,
        ActionListener<InferenceServiceResults> listener
    ) {
        final List<String> documents = DocumentsOnlyInput.of(inferenceInputs).getInputs();
        final AlibabaCloudSearchCompletionRequest completionRequest = new AlibabaCloudSearchCompletionRequest(account, documents, model);
        final ExecutableInferenceRequest inferenceRequest = new ExecutableInferenceRequest(
            requestSender,
            logger,
            completionRequest,
            HANDLER,
            hasRequestCompletedFunction,
            listener
        );
        execute(inferenceRequest);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ public class AlibabaCloudSearchUtils {
public static final String TEXT_EMBEDDING_PATH = "text-embedding";
public static final String SPARSE_EMBEDDING_PATH = "text-sparse-embedding";
public static final String RERANK_PATH = "ranker";
// URI path segment used for completion (text generation) requests.
public static final String COMPLETION_PATH = "text-generation";
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.completion;

import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.utils.URIBuilder;
import org.apache.http.entity.ByteArrayEntity;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchRequest;
import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchUtils;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModel;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionTaskSettings;

import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.external.request.RequestUtils.buildUri;
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;

/**
 * HTTP request for AlibabaCloud Search completion.
 * <p>
 * Everything needed to build the request, including the target URI, is captured from the account, the inputs
 * and the model at construction time. The request is never truncated: {@link #truncate()} returns this instance.
 */
public class AlibabaCloudSearchCompletionRequest extends AlibabaCloudSearchRequest {
    // Scheme used when the model's common settings do not provide one.
    private static final String DEFAULT_HTTP_SCHEMA = "https";

    private final AlibabaCloudSearchAccount account;
    private final List<String> input;
    private final URI uri;
    private final AlibabaCloudSearchCompletionTaskSettings taskSettings;
    private final String model;
    private final String host;
    private final String workspaceName;
    private final String httpSchema;
    private final String inferenceEntityId;

    /**
     * @param account         account whose API key is used for the authorization header; must not be null
     * @param input           inputs placed in the request entity; must not be null
     * @param completionModel source of task settings, model id, host, workspace name, HTTP schema and
     *                        inference entity id; must not be null
     */
    public AlibabaCloudSearchCompletionRequest(
        AlibabaCloudSearchAccount account,
        List<String> input,
        AlibabaCloudSearchCompletionModel completionModel
    ) {
        Objects.requireNonNull(completionModel);

        this.account = Objects.requireNonNull(account);
        this.input = Objects.requireNonNull(input);
        this.taskSettings = completionModel.getTaskSettings();

        var commonSettings = completionModel.getServiceSettings().getCommonSettings();
        this.model = commonSettings.modelId();
        this.host = commonSettings.getHost();
        this.workspaceName = commonSettings.getWorkspaceName();

        String configuredSchema = commonSettings.getHttpSchema();
        if (configuredSchema != null) {
            this.httpSchema = configuredSchema;
        } else {
            this.httpSchema = DEFAULT_HTTP_SCHEMA;
        }

        // buildDefaultUri() reads the fields assigned above, so the URI is built only after them.
        this.uri = buildUri(null, AlibabaCloudSearchUtils.SERVICE_NAME, this::buildDefaultUri);
        this.inferenceEntityId = completionModel.getInferenceEntityId();
    }

    /**
     * Creates the POST request: the body is the UTF-8 encoded string form of the completion request entity,
     * the content type is JSON, and the authorization header is derived from the account's API key.
     */
    @Override
    public HttpRequest createHttpRequest() {
        AlibabaCloudSearchCompletionRequestEntity requestEntity = new AlibabaCloudSearchCompletionRequestEntity(
            input,
            taskSettings,
            model
        );
        byte[] payload = Strings.toString(requestEntity).getBytes(StandardCharsets.UTF_8);

        HttpPost post = new HttpPost(uri);
        post.setEntity(new ByteArrayEntity(payload));
        post.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
        post.setHeader(createAuthBearerHeader(account.apiKey()));

        return new HttpRequest(post, getInferenceEntityId());
    }

    @Override
    public String getInferenceEntityId() {
        return inferenceEntityId;
    }

    @Override
    public URI getURI() {
        return uri;
    }

    /** Truncation is not applied to completion requests; the same instance is returned. */
    @Override
    public Request truncate() {
        return this;
    }

    /** Always {@code null}: this request carries no truncation information. */
    @Override
    public boolean[] getTruncationInfo() {
        return null;
    }

    /**
     * Assembles the URI from the configured scheme and host, followed by the path segments: version, OpenAPI path,
     * workspace path, workspace name, completion path and model id.
     *
     * @throws URISyntaxException if the builder cannot produce a valid URI
     */
    URI buildDefaultUri() throws URISyntaxException {
        URIBuilder builder = new URIBuilder();
        builder.setScheme(httpSchema);
        builder.setHost(host);
        builder.setPathSegments(
            AlibabaCloudSearchUtils.VERSION_3,
            AlibabaCloudSearchUtils.OPENAPI_PATH,
            AlibabaCloudSearchUtils.WORKSPACE_PATH,
            workspaceName,
            AlibabaCloudSearchUtils.COMPLETION_PATH,
            model
        );
        return builder.build();
    }
}
Loading