Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/112512.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 112512
summary: Add Completion Inference API for Alibaba Cloud AI Search Model
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionTaskSettings;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankServiceSettings;
Expand Down Expand Up @@ -543,6 +545,20 @@ private static void addAlibabaCloudSearchNamedWriteables(List<NamedWriteableRegi
AlibabaCloudSearchRerankTaskSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
AlibabaCloudSearchCompletionServiceSettings.NAME,
AlibabaCloudSearchCompletionServiceSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
TaskSettings.class,
AlibabaCloudSearchCompletionTaskSettings.NAME,
AlibabaCloudSearchCompletionTaskSettings::new
)
);

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModel;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankModel;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel;
Expand Down Expand Up @@ -50,4 +51,11 @@ public ExecutableAction create(AlibabaCloudSearchRerankModel model, Map<String,

return new AlibabaCloudSearchRerankAction(sender, overriddenModel, serviceComponents);
}

/**
 * Creates the action that executes completion requests for the given model.
 * <p>
 * The model handed to the action is the one returned by
 * {@link AlibabaCloudSearchCompletionModel#of} for {@code model} and {@code taskSettings}.
 *
 * @param model the completion model the action is created for
 * @param taskSettings task settings passed to {@link AlibabaCloudSearchCompletionModel#of} together with the model
 * @return a new {@link AlibabaCloudSearchCompletionAction} using this creator's sender and service components
 */
@Override
public ExecutableAction create(AlibabaCloudSearchCompletionModel model, Map<String, Object> taskSettings) {
    return new AlibabaCloudSearchCompletionAction(
        sender,
        AlibabaCloudSearchCompletionModel.of(model, taskSettings),
        serviceComponents
    );
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.elasticsearch.inference.InputType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModel;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankModel;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel;
Expand All @@ -21,4 +22,6 @@ public interface AlibabaCloudSearchActionVisitor {
/**
 * Creates the {@link ExecutableAction} for an {@link AlibabaCloudSearchSparseModel}.
 *
 * @param model the sparse model the action is created for
 * @param taskSettings task settings supplied together with the model
 * @param inputType the input type supplied together with the model
 * @return the action for the given model
 */
ExecutableAction create(AlibabaCloudSearchSparseModel model, Map<String, Object> taskSettings, InputType inputType);

/**
 * Creates the {@link ExecutableAction} for an {@link AlibabaCloudSearchRerankModel}.
 *
 * @param model the rerank model the action is created for
 * @param taskSettings task settings supplied together with the model
 * @return the action for the given model
 */
ExecutableAction create(AlibabaCloudSearchRerankModel model, Map<String, Object> taskSettings);

/**
 * Creates the {@link ExecutableAction} for an {@link AlibabaCloudSearchCompletionModel}.
 *
 * @param model the completion model the action is created for
 * @param taskSettings task settings supplied together with the model
 * @return the action for the given model
 */
ExecutableAction create(AlibabaCloudSearchCompletionModel model, Map<String, Object> taskSettings);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.action.alibabacloudsearch;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount;
import org.elasticsearch.xpack.inference.external.http.sender.AlibabaCloudSearchCompletionRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModel;

import java.util.Objects;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;

/**
 * {@link ExecutableAction} that submits completion requests for an {@link AlibabaCloudSearchCompletionModel}
 * through a {@link Sender}.
 * <p>
 * The inputs are validated before anything is sent:
 * <ul>
 *   <li>inputs that are not a {@link DocumentsOnlyInput} fail with {@link RestStatus#INTERNAL_SERVER_ERROR};</li>
 *   <li>an even number of inputs fails with {@link RestStatus#BAD_REQUEST}.</li>
 * </ul>
 * Both failures are delivered to the listener; nothing is thrown to the caller for them.
 */
public class AlibabaCloudSearchCompletionAction implements ExecutableAction {
    private static final Logger logger = LogManager.getLogger(AlibabaCloudSearchCompletionAction.class);

    private final AlibabaCloudSearchAccount account;
    private final AlibabaCloudSearchCompletionModel model;
    private final String failedToSendRequestErrorMessage;
    private final Sender sender;
    private final AlibabaCloudSearchCompletionRequestManager requestCreator;

    /**
     * @param sender used to send the request; must not be {@code null}
     * @param model the completion model, also the source of the API key; must not be {@code null}
     * @param serviceComponents supplies the thread pool handed to the request manager
     */
    public AlibabaCloudSearchCompletionAction(
        Sender sender,
        AlibabaCloudSearchCompletionModel model,
        ServiceComponents serviceComponents
    ) {
        this.model = Objects.requireNonNull(model);
        this.sender = Objects.requireNonNull(sender);

        var apiKey = this.model.getSecretSettings().apiKey();
        this.account = new AlibabaCloudSearchAccount(apiKey);
        this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(null, "AlibabaCloud Search completion");

        var threadPool = serviceComponents.threadPool();
        this.requestCreator = AlibabaCloudSearchCompletionRequestManager.of(account, model, threadPool);
    }

    /**
     * Validates the inputs and, when they are acceptable, passes them to the sender.
     * <p>
     * Exceptions thrown while wrapping the listener or handing the request to the sender are caught and
     * forwarded to {@code listener}: an {@link ElasticsearchException} as-is, anything else wrapped as an
     * internal server error.
     *
     * @param inferenceInputs the inputs of the request; expected to be a {@link DocumentsOnlyInput}
     * @param timeout passed through to the sender
     * @param listener receives the result or the failure
     */
    @Override
    public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
        if (inferenceInputs instanceof DocumentsOnlyInput == false) {
            listener.onFailure(unsupportedInputTypeFailure());
            return;
        }

        int inputCount = ((DocumentsOnlyInput) inferenceInputs).getInputs().size();
        boolean evenNumberOfInputs = inputCount % 2 == 0;
        if (evenNumberOfInputs) {
            listener.onFailure(evenInputCountFailure());
            return;
        }

        try {
            sender.send(
                requestCreator,
                inferenceInputs,
                timeout,
                wrapFailuresInElasticsearchException(failedToSendRequestErrorMessage, listener)
            );
        } catch (ElasticsearchException e) {
            listener.onFailure(e);
        } catch (Exception e) {
            listener.onFailure(createInternalServerError(e, failedToSendRequestErrorMessage));
        }
    }

    /** Failure reported when the inputs are not a {@link DocumentsOnlyInput}. */
    private static ElasticsearchStatusException unsupportedInputTypeFailure() {
        String message = format("Invalid inference input type, task type [%s] do not support Field [query]", TaskType.COMPLETION);
        return new ElasticsearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR);
    }

    /** Failure reported when the number of inputs is even. */
    private static ElasticsearchStatusException evenInputCountFailure() {
        return new ElasticsearchStatusException(
            "Alibaba Completion's inputs must be an odd number. The last input is the current query, "
                + "all preceding inputs are the completion history as pairs of user input and the assistant's response.",
            RestStatus.BAD_REQUEST
        );
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount;
import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.completion.AlibabaCloudSearchCompletionRequest;
import org.elasticsearch.xpack.inference.external.response.alibabacloudsearch.AlibabaCloudSearchCompletionResponseEntity;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModel;

import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;

/**
 * Request manager for completion requests: builds an {@link AlibabaCloudSearchCompletionRequest} from the
 * inference inputs and submits it for execution with the shared completion {@link ResponseHandler}.
 * <p>
 * Instances are obtained through {@link #of}.
 */
public class AlibabaCloudSearchCompletionRequestManager extends AlibabaCloudSearchRequestManager {
    private static final Logger logger = LogManager.getLogger(AlibabaCloudSearchCompletionRequestManager.class);

    // One handler shared by all requests, configured with AlibabaCloudSearchCompletionResponseEntity::fromResponse.
    private static final ResponseHandler HANDLER = new AlibabaCloudSearchResponseHandler(
        "alibaba cloud search completion",
        AlibabaCloudSearchCompletionResponseEntity::fromResponse
    );

    private final AlibabaCloudSearchCompletionModel model;

    private final AlibabaCloudSearchAccount account;

    /**
     * Creates a request manager after checking that no argument is {@code null}.
     *
     * @param account the account passed to every request that is built
     * @param model the completion model passed to every request that is built
     * @param threadPool passed to the superclass constructor
     * @return a new request manager
     * @throws NullPointerException if any argument is {@code null}
     */
    public static AlibabaCloudSearchCompletionRequestManager of(
        AlibabaCloudSearchAccount account,
        AlibabaCloudSearchCompletionModel model,
        ThreadPool threadPool
    ) {
        Objects.requireNonNull(account);
        Objects.requireNonNull(model);
        Objects.requireNonNull(threadPool);
        return new AlibabaCloudSearchCompletionRequestManager(account, model, threadPool);
    }

    private AlibabaCloudSearchCompletionRequestManager(
        AlibabaCloudSearchAccount account,
        AlibabaCloudSearchCompletionModel model,
        ThreadPool threadPool
    ) {
        super(threadPool, model);
        this.account = Objects.requireNonNull(account);
        this.model = Objects.requireNonNull(model);
    }

    /**
     * Builds the completion request for the given inputs and submits it for execution.
     *
     * @param inferenceInputs converted with {@link DocumentsOnlyInput#of}; its inputs become the request input
     * @param requestSender passed through to the {@link ExecutableInferenceRequest}
     * @param hasRequestCompletedFunction passed through to the {@link ExecutableInferenceRequest}
     * @param listener passed through to the {@link ExecutableInferenceRequest}
     */
    @Override
    public void execute(
        InferenceInputs inferenceInputs,
        RequestSender requestSender,
        Supplier<Boolean> hasRequestCompletedFunction,
        ActionListener<InferenceServiceResults> listener
    ) {
        DocumentsOnlyInput documents = DocumentsOnlyInput.of(inferenceInputs);
        List<String> texts = documents.getInputs();
        AlibabaCloudSearchCompletionRequest completionRequest = new AlibabaCloudSearchCompletionRequest(account, texts, model);

        ExecutableInferenceRequest task = new ExecutableInferenceRequest(
            requestSender,
            logger,
            completionRequest,
            HANDLER,
            hasRequestCompletedFunction,
            listener
        );
        execute(task);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ public class AlibabaCloudSearchUtils {
public static final String TEXT_EMBEDDING_PATH = "text-embedding";
public static final String SPARSE_EMBEDDING_PATH = "text-sparse-embedding";
public static final String RERANK_PATH = "ranker";
public static final String COMPLETION_PATH = "text-generation";
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.completion;

import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.utils.URIBuilder;
import org.apache.http.entity.ByteArrayEntity;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchRequest;
import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchUtils;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModel;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionTaskSettings;

import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.external.request.RequestUtils.buildUri;
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;

/**
 * HTTP request for the completion endpoint, built from an {@link AlibabaCloudSearchCompletionModel}.
 * <p>
 * The target URI is computed once in the constructor. The HTTP request itself (JSON body, content type and
 * bearer authorization header) is created on demand by {@link #createHttpRequest()}.
 */
public class AlibabaCloudSearchCompletionRequest extends AlibabaCloudSearchRequest {
    private final AlibabaCloudSearchAccount account;
    private final List<String> input;
    private final URI uri;
    private final AlibabaCloudSearchCompletionTaskSettings taskSettings;
    private final String model;
    private final String host;
    private final String workspaceName;
    private final String httpSchema;
    private final String inferenceEntityId;

    /**
     * @param account provides the API key used for the authorization header; must not be {@code null}
     * @param input the inputs placed in the request entity; must not be {@code null}
     * @param completionModel source of the task settings, model id, host, workspace name, HTTP schema and
     *                        inference entity id; must not be {@code null}
     */
    public AlibabaCloudSearchCompletionRequest(
        AlibabaCloudSearchAccount account,
        List<String> input,
        AlibabaCloudSearchCompletionModel completionModel
    ) {
        Objects.requireNonNull(completionModel);

        this.account = Objects.requireNonNull(account);
        this.input = Objects.requireNonNull(input);
        this.taskSettings = completionModel.getTaskSettings();

        var commonSettings = completionModel.getServiceSettings().getCommonSettings();
        this.model = commonSettings.modelId();
        this.host = commonSettings.getHost();
        this.workspaceName = commonSettings.getWorkspaceName();

        // Use "https" when the settings do not specify a schema.
        String configuredSchema = commonSettings.getHttpSchema();
        if (configuredSchema != null) {
            this.httpSchema = configuredSchema;
        } else {
            this.httpSchema = "https";
        }

        // buildDefaultUri reads httpSchema, host, workspaceName and model, so those are assigned above this line.
        this.uri = buildUri(null, AlibabaCloudSearchUtils.SERVICE_NAME, this::buildDefaultUri);
        this.inferenceEntityId = completionModel.getInferenceEntityId();
    }

    /**
     * Creates a POST request to {@link #getURI()} whose body is the UTF-8 encoded string form of an
     * {@link AlibabaCloudSearchCompletionRequestEntity}, with a JSON content type header and a bearer
     * authorization header built from the account's API key.
     */
    @Override
    public HttpRequest createHttpRequest() {
        AlibabaCloudSearchCompletionRequestEntity requestEntity = new AlibabaCloudSearchCompletionRequestEntity(
            input,
            taskSettings,
            model
        );
        byte[] payload = Strings.toString(requestEntity).getBytes(StandardCharsets.UTF_8);

        HttpPost post = new HttpPost(uri);
        post.setEntity(new ByteArrayEntity(payload));
        post.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
        post.setHeader(createAuthBearerHeader(account.apiKey()));

        return new HttpRequest(post, getInferenceEntityId());
    }

    @Override
    public String getInferenceEntityId() {
        return inferenceEntityId;
    }

    @Override
    public URI getURI() {
        return uri;
    }

    /** Returns this same request; no truncation is applied. */
    @Override
    public Request truncate() {
        return this;
    }

    /** Always returns {@code null}; this request reports no truncation information. */
    @Override
    public boolean[] getTruncationInfo() {
        return null;
    }

    /**
     * Builds the URI from the schema, the host and the path segments
     * {@code VERSION_3 / OPENAPI_PATH / WORKSPACE_PATH / <workspaceName> / COMPLETION_PATH / <model>}.
     *
     * @throws URISyntaxException if the parts do not form a valid URI
     */
    URI buildDefaultUri() throws URISyntaxException {
        URIBuilder builder = new URIBuilder();
        builder.setScheme(httpSchema);
        builder.setHost(host);
        builder.setPathSegments(
            AlibabaCloudSearchUtils.VERSION_3,
            AlibabaCloudSearchUtils.OPENAPI_PATH,
            AlibabaCloudSearchUtils.WORKSPACE_PATH,
            workspaceName,
            AlibabaCloudSearchUtils.COMPLETION_PATH,
            model
        );
        return builder.build();
    }
}
Loading