-
Notifications
You must be signed in to change notification settings - Fork 25.6k
Add Hugging Face Chat Completion support to Inference Plugin #127254
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
63f21de
6b7dd2e
65e4060
404f640
ceebb9a
acaa35b
91fa92e
ff3ef50
965093b
6757b07
58ea9fd
df845eb
cc24e68
5bbe3b7
3684816
7670d2c
6630be7
1efb2ee
61537d0
64c0685
4688901
bfc8072
13ef13b
129caaf
214de5f
d3411d6
e170b96
473dee6
cb03100
c856853
bd2e601
aae528a
82f8049
b0679d5
2fa3dff
cdb3c1c
9370b57
9044bee
e72a312
e2cb334
a4b5d2c
c5988ed
1547559
71c6057
228fffa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,66 @@ | ||
| /* | ||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
| * or more contributor license agreements. Licensed under the Elastic License | ||
| * 2.0; you may not use this file except in compliance with the Elastic License | ||
| * 2.0. | ||
| */ | ||
|
|
||
| package org.elasticsearch.xpack.inference.services.huggingface; | ||
|
|
||
| import org.apache.logging.log4j.LogManager; | ||
| import org.apache.logging.log4j.Logger; | ||
| import org.elasticsearch.action.ActionListener; | ||
| import org.elasticsearch.inference.InferenceServiceResults; | ||
| import org.elasticsearch.threadpool.ThreadPool; | ||
| import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; | ||
| import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; | ||
| import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest; | ||
| import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; | ||
| import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; | ||
| import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel; | ||
| import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest; | ||
|
|
||
| import java.util.Objects; | ||
| import java.util.function.Supplier; | ||
|
|
||
| public class HuggingFaceCompletionRequestManager extends HuggingFaceRequestManager { | ||
|
||
| private static final Logger logger = LogManager.getLogger(HuggingFaceCompletionRequestManager.class); | ||
|
|
||
| public static HuggingFaceCompletionRequestManager of( | ||
| HuggingFaceChatCompletionModel model, | ||
| ResponseHandler responseHandler, | ||
| ThreadPool threadPool | ||
| ) { | ||
| return new HuggingFaceCompletionRequestManager( | ||
| Objects.requireNonNull(model), | ||
| Objects.requireNonNull(responseHandler), | ||
| Objects.requireNonNull(threadPool) | ||
| ); | ||
| } | ||
|
|
||
| private final HuggingFaceChatCompletionModel model; | ||
| private final ResponseHandler responseHandler; | ||
|
|
||
| private HuggingFaceCompletionRequestManager( | ||
| HuggingFaceChatCompletionModel model, | ||
| ResponseHandler responseHandler, | ||
| ThreadPool threadPool | ||
| ) { | ||
| super(model, threadPool); | ||
| this.model = model; | ||
| this.responseHandler = responseHandler; | ||
| } | ||
|
|
||
| @Override | ||
| public void execute( | ||
| InferenceInputs inferenceInputs, | ||
| RequestSender requestSender, | ||
| Supplier<Boolean> hasRequestCompletedFunction, | ||
| ActionListener<InferenceServiceResults> listener | ||
| ) { | ||
| var chatCompletionInput = inferenceInputs.castTo(UnifiedChatInput.class); | ||
| HuggingFaceUnifiedChatCompletionRequest request = new HuggingFaceUnifiedChatCompletionRequest(chatCompletionInput, model); | ||
|
|
||
| execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener)); | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,75 @@ | ||
| /* | ||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
| * or more contributor license agreements. Licensed under the Elastic License | ||
| * 2.0; you may not use this file except in compliance with the Elastic License | ||
| * 2.0. | ||
| */ | ||
|
|
||
| package org.elasticsearch.xpack.inference.services.huggingface; | ||
|
|
||
| import org.apache.logging.log4j.LogManager; | ||
| import org.apache.logging.log4j.Logger; | ||
| import org.elasticsearch.action.ActionListener; | ||
| import org.elasticsearch.inference.InferenceServiceResults; | ||
| import org.elasticsearch.threadpool.ThreadPool; | ||
| import org.elasticsearch.xpack.inference.common.Truncator; | ||
| import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; | ||
| import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; | ||
| import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; | ||
| import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest; | ||
| import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; | ||
| import org.elasticsearch.xpack.inference.services.huggingface.request.embeddings.HuggingFaceInferenceRequest; | ||
|
|
||
| import java.util.List; | ||
| import java.util.Objects; | ||
| import java.util.function.Supplier; | ||
|
|
||
| import static org.elasticsearch.xpack.inference.common.Truncator.truncate; | ||
|
|
||
| public class HuggingFaceEmbeddingsRequestManager extends HuggingFaceRequestManager { | ||
| private static final Logger logger = LogManager.getLogger(HuggingFaceEmbeddingsRequestManager.class); | ||
|
|
||
| public static HuggingFaceEmbeddingsRequestManager of( | ||
| HuggingFaceModel model, | ||
| ResponseHandler responseHandler, | ||
| Truncator truncator, | ||
| ThreadPool threadPool | ||
| ) { | ||
| return new HuggingFaceEmbeddingsRequestManager( | ||
| Objects.requireNonNull(model), | ||
| Objects.requireNonNull(responseHandler), | ||
| Objects.requireNonNull(truncator), | ||
| Objects.requireNonNull(threadPool) | ||
| ); | ||
| } | ||
|
|
||
| private final HuggingFaceModel model; | ||
| private final ResponseHandler responseHandler; | ||
| private final Truncator truncator; | ||
|
|
||
| private HuggingFaceEmbeddingsRequestManager( | ||
| HuggingFaceModel model, | ||
| ResponseHandler responseHandler, | ||
| Truncator truncator, | ||
| ThreadPool threadPool | ||
| ) { | ||
| super(model, threadPool); | ||
| this.model = model; | ||
| this.responseHandler = responseHandler; | ||
| this.truncator = truncator; | ||
| } | ||
|
|
||
| @Override | ||
| public void execute( | ||
| InferenceInputs inferenceInputs, | ||
| RequestSender requestSender, | ||
| Supplier<Boolean> hasRequestCompletedFunction, | ||
| ActionListener<InferenceServiceResults> listener | ||
| ) { | ||
| List<String> docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs(); | ||
| var truncatedInput = truncate(docsInput, model.getTokenLimit()); | ||
| var request = new HuggingFaceInferenceRequest(truncator, truncatedInput, model); | ||
|
|
||
| execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener)); | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -33,6 +33,7 @@ | |
| import org.elasticsearch.xpack.inference.services.ServiceComponents; | ||
| import org.elasticsearch.xpack.inference.services.ServiceUtils; | ||
| import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionCreator; | ||
| import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel; | ||
| import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel; | ||
| import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel; | ||
| import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; | ||
|
|
@@ -51,7 +52,11 @@ public class HuggingFaceService extends HuggingFaceBaseService { | |
| public static final String NAME = "hugging_face"; | ||
|
|
||
| private static final String SERVICE_NAME = "Hugging Face"; | ||
| private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING); | ||
| private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES = EnumSet.of( | ||
| TaskType.TEXT_EMBEDDING, | ||
| TaskType.SPARSE_EMBEDDING, | ||
| TaskType.COMPLETION | ||
| ); | ||
|
|
||
| public HuggingFaceService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { | ||
| super(factory, serviceComponents); | ||
|
|
@@ -78,6 +83,14 @@ protected HuggingFaceModel createModel( | |
| context | ||
| ); | ||
| case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings, context); | ||
| case CHAT_COMPLETION, COMPLETION -> new HuggingFaceChatCompletionModel( | ||
| inferenceEntityId, | ||
| taskType, | ||
| NAME, | ||
| serviceSettings, | ||
| secretSettings, | ||
| context | ||
| ); | ||
| default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); | ||
| }; | ||
| } | ||
|
|
@@ -149,7 +162,7 @@ public InferenceServiceConfiguration getConfiguration() { | |
|
|
||
| @Override | ||
| public EnumSet<TaskType> supportedTaskTypes() { | ||
| return supportedTaskTypes; | ||
| return SUPPORTED_TASK_TYPES; | ||
| } | ||
|
|
||
| @Override | ||
|
|
@@ -173,7 +186,7 @@ public static InferenceServiceConfiguration get() { | |
|
|
||
| configurationMap.put( | ||
| URL, | ||
| new SettingsConfiguration.Builder(supportedTaskTypes).setDefaultValue("https://api.openai.com/v1/embeddings") | ||
| new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDefaultValue("https://api.openai.com/v1/embeddings") | ||
|
||
| .setDescription("The URL endpoint to use for the requests.") | ||
| .setLabel("URL") | ||
| .setRequired(true) | ||
|
|
@@ -183,12 +196,12 @@ public static InferenceServiceConfiguration get() { | |
| .build() | ||
| ); | ||
|
|
||
| configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(supportedTaskTypes)); | ||
| configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(supportedTaskTypes)); | ||
| configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES)); | ||
| configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES)); | ||
|
|
||
| return new InferenceServiceConfiguration.Builder().setService(NAME) | ||
| .setName(SERVICE_NAME) | ||
| .setTaskTypes(supportedTaskTypes) | ||
| .setTaskTypes(SUPPORTED_TASK_TYPES) | ||
| .setConfigurations(configurationMap) | ||
| .build(); | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,10 +11,13 @@ | |
| import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; | ||
| import org.elasticsearch.xpack.inference.external.http.sender.Sender; | ||
| import org.elasticsearch.xpack.inference.services.ServiceComponents; | ||
| import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceRequestManager; | ||
| import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceCompletionRequestManager; | ||
| import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceEmbeddingsRequestManager; | ||
| import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceResponseHandler; | ||
| import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel; | ||
| import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel; | ||
| import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel; | ||
| import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceChatCompletionResponseEntity; | ||
| import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceElserResponseEntity; | ||
| import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceEmbeddingsResponseEntity; | ||
|
|
||
|
|
@@ -26,6 +29,9 @@ | |
| * Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the hugging face model type. | ||
| */ | ||
| public class HuggingFaceActionCreator implements HuggingFaceActionVisitor { | ||
|
|
||
| private static final String FAILED_TO_SEND_REQUEST_ERROR_MESSAGE = | ||
| "Failed to send Hugging Face %s request from inference entity id [%s]"; | ||
| private final Sender sender; | ||
| private final ServiceComponents serviceComponents; | ||
|
|
||
|
|
@@ -40,34 +46,38 @@ public ExecutableAction create(HuggingFaceEmbeddingsModel model) { | |
| "hugging face text embeddings", | ||
| HuggingFaceEmbeddingsResponseEntity::fromResponse | ||
| ); | ||
| var requestCreator = HuggingFaceRequestManager.of( | ||
| var requestCreator = HuggingFaceEmbeddingsRequestManager.of( | ||
| model, | ||
| responseHandler, | ||
| serviceComponents.truncator(), | ||
| serviceComponents.threadPool() | ||
| ); | ||
| var errorMessage = format( | ||
| "Failed to send Hugging Face %s request from inference entity id [%s]", | ||
| "text embeddings", | ||
| model.getInferenceEntityId() | ||
| ); | ||
| var errorMessage = format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, "text embeddings", model.getInferenceEntityId()); | ||
|
||
| return new SenderExecutableAction(sender, requestCreator, errorMessage); | ||
| } | ||
|
|
||
| @Override | ||
| public ExecutableAction create(HuggingFaceElserModel model) { | ||
| var responseHandler = new HuggingFaceResponseHandler("hugging face elser", HuggingFaceElserResponseEntity::fromResponse); | ||
| var requestCreator = HuggingFaceRequestManager.of( | ||
| var requestCreator = HuggingFaceEmbeddingsRequestManager.of( | ||
| model, | ||
| responseHandler, | ||
| serviceComponents.truncator(), | ||
| serviceComponents.threadPool() | ||
| ); | ||
| var errorMessage = format( | ||
| "Failed to send Hugging Face %s request from inference entity id [%s]", | ||
| "ELSER", | ||
| model.getInferenceEntityId() | ||
| var errorMessage = format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, "ELSER", model.getInferenceEntityId()); | ||
| return new SenderExecutableAction(sender, requestCreator, errorMessage); | ||
| } | ||
|
|
||
| @Override | ||
| public ExecutableAction create(HuggingFaceChatCompletionModel model) { | ||
| var responseHandler = new HuggingFaceResponseHandler( | ||
| "hugging face chat completion", | ||
| HuggingFaceChatCompletionResponseEntity::fromResponse | ||
| ); | ||
|
|
||
| var requestCreator = HuggingFaceCompletionRequestManager.of(model, responseHandler, serviceComponents.threadPool()); | ||
| var errorMessage = format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, "CHAT COMPLETION", model.getInferenceEntityId()); | ||
| return new SenderExecutableAction(sender, requestCreator, errorMessage); | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I haven't looked through the entire PR but just wanted to check. We should try to add the chat completion functionality to the existing
HuggingFaceServicelogic.For example the OpenAI service supports many task types: https://github.com/elastic/elasticsearch/blob/main/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java#L175-L197
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Change is done. Now completion logic is in single HuggingFaceService class.