Skip to content

Commit 65e4060

Browse files
Add support for streaming chat completion task for HuggingFace
1 parent 6b7dd2e commit 65e4060

File tree

13 files changed

+127
-258
lines changed

13 files changed

+127
-258
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,6 @@
127127
import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioService;
128128
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiService;
129129
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService;
130-
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionService;
131130
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserService;
132131
import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService;
133132
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService;
@@ -362,7 +361,6 @@ public void loadExtensions(ExtensionLoader loader) {
362361
public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
363362
return List.of(
364363
context -> new HuggingFaceElserService(httpFactory.get(), serviceComponents.get()),
365-
context -> new HuggingFaceChatCompletionService(httpFactory.get(), serviceComponents.get()),
366364
context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get()),
367365
context -> new OpenAiService(httpFactory.get(), serviceComponents.get()),
368366
context -> new CohereService(httpFactory.get(), serviceComponents.get()),
Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,22 @@
2323
import java.util.Objects;
2424
import java.util.function.Supplier;
2525

26-
public class HuggingFaceCompletionRequestManager extends HuggingFaceRequestManager {
27-
private static final Logger logger = LogManager.getLogger(HuggingFaceCompletionRequestManager.class);
26+
/**
27+
* Manages the execution of chat completion requests for Hugging Face models.
28+
* <p>
29+
* This class is responsible for creating and executing requests to Hugging Face's chat completion API.
30+
* It extends {@link HuggingFaceRequestManager} to provide specific functionality for chat completion models.
31+
* </p>
32+
*/
33+
public class HuggingFaceChatCompletionRequestManager extends HuggingFaceRequestManager {
34+
private static final Logger logger = LogManager.getLogger(HuggingFaceChatCompletionRequestManager.class);
2835

29-
public static HuggingFaceCompletionRequestManager of(
36+
public static HuggingFaceChatCompletionRequestManager of(
3037
HuggingFaceChatCompletionModel model,
3138
ResponseHandler responseHandler,
3239
ThreadPool threadPool
3340
) {
34-
return new HuggingFaceCompletionRequestManager(
41+
return new HuggingFaceChatCompletionRequestManager(
3542
Objects.requireNonNull(model),
3643
Objects.requireNonNull(responseHandler),
3744
Objects.requireNonNull(threadPool)
@@ -41,7 +48,7 @@ public static HuggingFaceCompletionRequestManager of(
4148
private final HuggingFaceChatCompletionModel model;
4249
private final ResponseHandler responseHandler;
4350

44-
private HuggingFaceCompletionRequestManager(
51+
private HuggingFaceChatCompletionRequestManager(
4552
HuggingFaceChatCompletionModel model,
4653
ResponseHandler responseHandler,
4754
ThreadPool threadPool

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceEmbeddingsRequestManager.java

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,30 @@
1818
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
1919
import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest;
2020
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
21-
import org.elasticsearch.xpack.inference.services.huggingface.request.embeddings.HuggingFaceInferenceRequest;
21+
import org.elasticsearch.xpack.inference.services.huggingface.request.embeddings.HuggingFaceEmbeddingsRequest;
2222

2323
import java.util.List;
2424
import java.util.Objects;
2525
import java.util.function.Supplier;
2626

2727
import static org.elasticsearch.xpack.inference.common.Truncator.truncate;
2828

29+
/**
30+
* This class is responsible for managing requests to the Hugging Face API for generating embeddings.
31+
* It handles the execution of requests, including truncation of input data and response handling.
32+
*/
2933
public class HuggingFaceEmbeddingsRequestManager extends HuggingFaceRequestManager {
3034
private static final Logger logger = LogManager.getLogger(HuggingFaceEmbeddingsRequestManager.class);
3135

36+
/**
37+
* Creates a new instance of HuggingFaceEmbeddingsRequestManager.
38+
*
39+
* @param model The Hugging Face model to be used for generating embeddings.
40+
* @param responseHandler The response handler for processing the API responses.
41+
* @param truncator The truncator for handling input data truncation.
42+
* @param threadPool The thread pool for executing requests.
43+
* @return A new instance of HuggingFaceEmbeddingsRequestManager.
44+
*/
3245
public static HuggingFaceEmbeddingsRequestManager of(
3346
HuggingFaceModel model,
3447
ResponseHandler responseHandler,
@@ -68,7 +81,7 @@ public void execute(
6881
) {
6982
List<String> docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs();
7083
var truncatedInput = truncate(docsInput, model.getTokenLimit());
71-
var request = new HuggingFaceInferenceRequest(truncator, truncatedInput, model);
84+
var request = new HuggingFaceEmbeddingsRequest(truncator, truncatedInput, model);
7285

7386
execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener));
7487
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,19 +43,24 @@
4343
import java.util.HashMap;
4444
import java.util.List;
4545
import java.util.Map;
46+
import java.util.Set;
4647

4748
import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
4849
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
49-
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;
5050

51+
/**
52+
* This class is responsible for managing the Hugging Face inference service.
53+
* It handles the creation of models, chunked inference, and unified completion inference.
54+
*/
5155
public class HuggingFaceService extends HuggingFaceBaseService {
5256
public static final String NAME = "hugging_face";
5357

5458
private static final String SERVICE_NAME = "Hugging Face";
5559
private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES = EnumSet.of(
5660
TaskType.TEXT_EMBEDDING,
5761
TaskType.SPARSE_EMBEDDING,
58-
TaskType.COMPLETION
62+
TaskType.COMPLETION,
63+
TaskType.CHAT_COMPLETION
5964
);
6065

6166
public HuggingFaceService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
@@ -152,7 +157,21 @@ protected void doUnifiedCompletionInfer(
152157
TimeValue timeout,
153158
ActionListener<InferenceServiceResults> listener
154159
) {
155-
throwUnsupportedUnifiedCompletionOperation(NAME);
160+
if (model instanceof HuggingFaceChatCompletionModel == false) {
161+
listener.onFailure(createInvalidModelException(model));
162+
return;
163+
}
164+
HuggingFaceChatCompletionModel huggingFaceChatCompletionModel = (HuggingFaceChatCompletionModel) model;
165+
var actionCreator = new HuggingFaceActionCreator(getSender(), getServiceComponents());
166+
var overriddenModel = HuggingFaceChatCompletionModel.of(huggingFaceChatCompletionModel, inputs.getRequest());
167+
var action = overriddenModel.accept(actionCreator);
168+
169+
action.execute(inputs, timeout, listener);
170+
}
171+
172+
@Override
173+
public Set<TaskType> supportedStreamingTasks() {
174+
return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
156175
}
157176

158177
@Override
@@ -180,6 +199,9 @@ public static InferenceServiceConfiguration get() {
180199
return configuration.getOrCompute();
181200
}
182201

202+
private Configuration() {
203+
}
204+
183205
private static final LazyInitializable<InferenceServiceConfiguration, RuntimeException> configuration = new LazyInitializable<>(
184206
() -> {
185207
var configurationMap = new HashMap<String, SettingsConfiguration>();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,16 @@
1111
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
1212
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
1313
import org.elasticsearch.xpack.inference.services.ServiceComponents;
14-
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceCompletionRequestManager;
14+
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceChatCompletionRequestManager;
1515
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceEmbeddingsRequestManager;
1616
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceResponseHandler;
1717
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel;
1818
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
1919
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
20-
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceChatCompletionResponseEntity;
2120
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceElserResponseEntity;
2221
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceEmbeddingsResponseEntity;
22+
import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler;
23+
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
2324

2425
import java.util.Objects;
2526

@@ -71,12 +72,12 @@ public ExecutableAction create(HuggingFaceElserModel model) {
7172

7273
@Override
7374
public ExecutableAction create(HuggingFaceChatCompletionModel model) {
74-
var responseHandler = new HuggingFaceResponseHandler(
75+
var responseHandler = new OpenAiUnifiedChatCompletionResponseHandler(
7576
"hugging face chat completion",
76-
HuggingFaceChatCompletionResponseEntity::fromResponse
77+
OpenAiChatCompletionResponseEntity::fromResponse
7778
);
7879

79-
var requestCreator = HuggingFaceCompletionRequestManager.of(model, responseHandler, serviceComponents.threadPool());
80+
var requestCreator = HuggingFaceChatCompletionRequestManager.of(model, responseHandler, serviceComponents.threadPool());
8081
var errorMessage = format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, "CHAT COMPLETION", model.getInferenceEntityId());
8182
return new SenderExecutableAction(sender, requestCreator, errorMessage);
8283
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionService.java

Lines changed: 0 additions & 167 deletions
This file was deleted.

0 commit comments

Comments
 (0)