Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
63f21de
Add Hugging Face Chat Completion support to Inference Plugin
Jan-Kazlouski-elastic Apr 23, 2025
6b7dd2e
Merge remote-tracking branch 'refs/remotes/origin/main' into feature/…
Jan-Kazlouski-elastic Apr 25, 2025
65e4060
Add support for streaming chat completion task for HuggingFace
Jan-Kazlouski-elastic Apr 25, 2025
404f640
[CI] Auto commit changes from spotless
Apr 25, 2025
ceebb9a
Add support for non-streaming completion task for HuggingFace
Jan-Kazlouski-elastic Apr 25, 2025
acaa35b
Remove RequestManager for HF Chat Completion Task
Jan-Kazlouski-elastic Apr 25, 2025
91fa92e
Merge remote-tracking branch 'refs/remotes/origin/main' into feature/…
Jan-Kazlouski-elastic Apr 28, 2025
ff3ef50
Refactored Hugging Face Completion Service Settings, removed Request …
Jan-Kazlouski-elastic Apr 28, 2025
965093b
Refactored Hugging Face Action Creator, added Unit Tests
Jan-Kazlouski-elastic Apr 29, 2025
6757b07
Add Hugging Face Server Test
Jan-Kazlouski-elastic Apr 29, 2025
58ea9fd
Merge remote-tracking branch 'origin/main' into feature/hugging-face-…
Jan-Kazlouski-elastic Apr 29, 2025
df845eb
[CI] Auto commit changes from spotless
Apr 29, 2025
cc24e68
Merge remote-tracking branch 'origin/main' into feature/hugging-face-…
Jan-Kazlouski-elastic May 2, 2025
5bbe3b7
Removed parameters from media type for Chat Completion Request and un…
Jan-Kazlouski-elastic May 2, 2025
3684816
Removed OpenAI default URL in HuggingFaceService's configuration, fix…
Jan-Kazlouski-elastic May 2, 2025
7670d2c
Refactor error message handling in HuggingFaceActionCreator and Huggi…
Jan-Kazlouski-elastic May 2, 2025
6630be7
Update minimal supported version and add Hugging Face transport versi…
Jan-Kazlouski-elastic May 2, 2025
1efb2ee
Made modelId field optional in HuggingFaceChatCompletionModel, update…
Jan-Kazlouski-elastic May 2, 2025
61537d0
Removed max input tokens field from HuggingFaceChatCompletionServiceS…
Jan-Kazlouski-elastic May 2, 2025
64c0685
Removed if statement checking TransportVersion for HuggingFaceChatCom…
Jan-Kazlouski-elastic May 2, 2025
4688901
Removed getFirst() method calls for backport compatibility
Jan-Kazlouski-elastic May 2, 2025
bfc8072
Made HuggingFaceChatCompletionServiceSettingsTests extend AbstractBWC…
Jan-Kazlouski-elastic May 2, 2025
13ef13b
Refactored tests to use stripWhitespace method for readability
Jan-Kazlouski-elastic May 2, 2025
129caaf
Refactored javadoc for HuggingFaceService
Jan-Kazlouski-elastic May 2, 2025
214de5f
Renamed HF chat completion TransportVersion constant names
Jan-Kazlouski-elastic May 2, 2025
d3411d6
Added random string generation in unit test
Jan-Kazlouski-elastic May 2, 2025
e170b96
Refactored javadocs for HuggingFace requests
Jan-Kazlouski-elastic May 2, 2025
473dee6
Refactored tests to reduce duplication
Jan-Kazlouski-elastic May 2, 2025
cb03100
Added changelog file
Jan-Kazlouski-elastic May 2, 2025
c856853
Merge remote-tracking branch 'origin/main' into feature/hugging-face-…
Jan-Kazlouski-elastic May 5, 2025
bd2e601
Merge remote-tracking branch 'refs/remotes/origin/main' into feature/…
Jan-Kazlouski-elastic May 5, 2025
aae528a
Add HuggingFaceChatCompletionResponseHandler and associated tests
Jan-Kazlouski-elastic May 5, 2025
82f8049
Refactor error handling in HuggingFaceServiceTests to standardize err…
Jan-Kazlouski-elastic May 5, 2025
b0679d5
Merge remote-tracking branch 'origin/main' into feature/hugging-face-…
Jan-Kazlouski-elastic May 6, 2025
2fa3dff
Merge remote-tracking branch 'origin/main' into feature/hugging-face-…
Jan-Kazlouski-elastic May 7, 2025
cdb3c1c
Refactor HuggingFace error handling to improve response structure and…
Jan-Kazlouski-elastic May 7, 2025
9370b57
Merge remote-tracking branch 'origin/main' into feature/hugging-face-…
Jan-Kazlouski-elastic May 11, 2025
9044bee
Allowing null function name for hugging face models
jonathan-buttner May 9, 2025
e72a312
Merge branch 'main' of https://github.com/Jan-Kazlouski-elastic/elast…
Jan-Kazlouski-elastic May 12, 2025
e2cb334
Merge branch 'main' of https://github.com/Jan-Kazlouski-elastic/elast…
Jan-Kazlouski-elastic May 13, 2025
a4b5d2c
Merge branch 'main' of https://github.com/Jan-Kazlouski-elastic/elast…
Jan-Kazlouski-elastic May 13, 2025
c5988ed
Merge branch 'main' of https://github.com/Jan-Kazlouski-elastic/elast…
Jan-Kazlouski-elastic May 14, 2025
1547559
Merge branch 'main' of https://github.com/Jan-Kazlouski-elastic/elast…
Jan-Kazlouski-elastic May 19, 2025
71c6057
Merge branch 'main' into feature/hugging-face-chat-completion-integra…
Jan-Kazlouski-elastic May 19, 2025
228fffa
Merge branch 'main' of https://github.com/Jan-Kazlouski-elastic/elast…
Jan-Kazlouski-elastic May 19, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings;
Expand Down Expand Up @@ -353,6 +354,13 @@ private static void addHuggingFaceNamedWriteables(List<NamedWriteableRegistry.En
namedWriteables.add(
new NamedWriteableRegistry.Entry(ServiceSettings.class, HuggingFaceServiceSettings.NAME, HuggingFaceServiceSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
HuggingFaceChatCompletionServiceSettings.NAME,
HuggingFaceChatCompletionServiceSettings::new
)
);
}

private static void addGoogleAiStudioNamedWritables(List<NamedWriteableRegistry.Entry> namedWriteables) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@
import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioService;
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiService;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService;
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionService;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserService;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService;
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService;
Expand Down Expand Up @@ -361,6 +362,7 @@ public void loadExtensions(ExtensionLoader loader) {
public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
return List.of(
context -> new HuggingFaceElserService(httpFactory.get(), serviceComponents.get()),
context -> new HuggingFaceChatCompletionService(httpFactory.get(), serviceComponents.get()),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't looked through the entire PR but just wanted to check. We should try to add the chat completion functionality to the existing HuggingFaceService logic.

For example the OpenAI service supports many task types: https://github.com/elastic/elasticsearch/blob/main/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java#L175-L197

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change is done. Now completion logic is in single HuggingFaceService class.

context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get()),
context -> new OpenAiService(httpFactory.get(), serviceComponents.get()),
context -> new CohereService(httpFactory.get(), serviceComponents.get()),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.huggingface;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel;
import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest;

import java.util.Objects;
import java.util.function.Supplier;

public class HuggingFaceCompletionRequestManager extends HuggingFaceRequestManager {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're trying to move away from the request manager pattern because it adds duplicate code. Could you look into following the pattern we started here (we haven't refactored all the services yet but if it's possible to do for hugging face it'd be great if we could do it now)?

#124144

One option would be to leave the other hugging face request managers as they are (if possible, it may not be though) and then use one of the generic request managers like shown in the PR above for this new functionality.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure thing. I will adopt the approach from the shared PR. Thanks Jonathan!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did the change that allowed us to move away from request manager for chat_completion and completion tasks.

private static final Logger logger = LogManager.getLogger(HuggingFaceCompletionRequestManager.class);

public static HuggingFaceCompletionRequestManager of(
HuggingFaceChatCompletionModel model,
ResponseHandler responseHandler,
ThreadPool threadPool
) {
return new HuggingFaceCompletionRequestManager(
Objects.requireNonNull(model),
Objects.requireNonNull(responseHandler),
Objects.requireNonNull(threadPool)
);
}

private final HuggingFaceChatCompletionModel model;
private final ResponseHandler responseHandler;

private HuggingFaceCompletionRequestManager(
HuggingFaceChatCompletionModel model,
ResponseHandler responseHandler,
ThreadPool threadPool
) {
super(model, threadPool);
this.model = model;
this.responseHandler = responseHandler;
}

@Override
public void execute(
InferenceInputs inferenceInputs,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
var chatCompletionInput = inferenceInputs.castTo(UnifiedChatInput.class);
HuggingFaceUnifiedChatCompletionRequest request = new HuggingFaceUnifiedChatCompletionRequest(chatCompletionInput, model);

execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.huggingface;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.services.huggingface.request.embeddings.HuggingFaceInferenceRequest;

import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;

import static org.elasticsearch.xpack.inference.common.Truncator.truncate;

public class HuggingFaceEmbeddingsRequestManager extends HuggingFaceRequestManager {
private static final Logger logger = LogManager.getLogger(HuggingFaceEmbeddingsRequestManager.class);

public static HuggingFaceEmbeddingsRequestManager of(
HuggingFaceModel model,
ResponseHandler responseHandler,
Truncator truncator,
ThreadPool threadPool
) {
return new HuggingFaceEmbeddingsRequestManager(
Objects.requireNonNull(model),
Objects.requireNonNull(responseHandler),
Objects.requireNonNull(truncator),
Objects.requireNonNull(threadPool)
);
}

private final HuggingFaceModel model;
private final ResponseHandler responseHandler;
private final Truncator truncator;

private HuggingFaceEmbeddingsRequestManager(
HuggingFaceModel model,
ResponseHandler responseHandler,
Truncator truncator,
ThreadPool threadPool
) {
super(model, threadPool);
this.model = model;
this.responseHandler = responseHandler;
this.truncator = truncator;
}

@Override
public void execute(
InferenceInputs inferenceInputs,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
List<String> docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs();
var truncatedInput = truncate(docsInput, model.getTokenLimit());
var request = new HuggingFaceInferenceRequest(truncator, truncatedInput, model);

execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,66 +7,12 @@

package org.elasticsearch.xpack.inference.services.huggingface;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.BaseRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.services.huggingface.request.HuggingFaceInferenceRequest;

import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;

import static org.elasticsearch.xpack.inference.common.Truncator.truncate;

public class HuggingFaceRequestManager extends BaseRequestManager {
private static final Logger logger = LogManager.getLogger(HuggingFaceRequestManager.class);

public static HuggingFaceRequestManager of(
HuggingFaceModel model,
ResponseHandler responseHandler,
Truncator truncator,
ThreadPool threadPool
) {
return new HuggingFaceRequestManager(
Objects.requireNonNull(model),
Objects.requireNonNull(responseHandler),
Objects.requireNonNull(truncator),
Objects.requireNonNull(threadPool)
);
}

private final HuggingFaceModel model;
private final ResponseHandler responseHandler;
private final Truncator truncator;

private HuggingFaceRequestManager(HuggingFaceModel model, ResponseHandler responseHandler, Truncator truncator, ThreadPool threadPool) {
public abstract class HuggingFaceRequestManager extends BaseRequestManager {
protected HuggingFaceRequestManager(HuggingFaceModel model, ThreadPool threadPool) {
super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
this.model = model;
this.responseHandler = responseHandler;
this.truncator = truncator;
}

@Override
public void execute(
InferenceInputs inferenceInputs,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
List<String> docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs();
var truncatedInput = truncate(docsInput, model.getTokenLimit());
var request = new HuggingFaceInferenceRequest(truncator, truncatedInput, model);

execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener));
}

record RateLimitGrouping(int accountHash) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionCreator;
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
Expand All @@ -51,7 +52,11 @@ public class HuggingFaceService extends HuggingFaceBaseService {
public static final String NAME = "hugging_face";

private static final String SERVICE_NAME = "Hugging Face";
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING);
private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES = EnumSet.of(
TaskType.TEXT_EMBEDDING,
TaskType.SPARSE_EMBEDDING,
TaskType.COMPLETION
);

public HuggingFaceService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents);
Expand All @@ -78,6 +83,14 @@ protected HuggingFaceModel createModel(
context
);
case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings, context);
case CHAT_COMPLETION, COMPLETION -> new HuggingFaceChatCompletionModel(
inferenceEntityId,
taskType,
NAME,
serviceSettings,
secretSettings,
context
);
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
};
}
Expand Down Expand Up @@ -149,7 +162,7 @@ public InferenceServiceConfiguration getConfiguration() {

@Override
public EnumSet<TaskType> supportedTaskTypes() {
return supportedTaskTypes;
return SUPPORTED_TASK_TYPES;
}

@Override
Expand All @@ -173,7 +186,7 @@ public static InferenceServiceConfiguration get() {

configurationMap.put(
URL,
new SettingsConfiguration.Builder(supportedTaskTypes).setDefaultValue("https://api.openai.com/v1/embeddings")
new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDefaultValue("https://api.openai.com/v1/embeddings")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops looks like we have an existing bug here (unrelated to your changes). Can you remove the setDefaultValue that shouldn't be pointing to openai 😅

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I initially assumed it is there for some internal configuration and didn't want to introduce any risks by changing it. Removed.

.setDescription("The URL endpoint to use for the requests.")
.setLabel("URL")
.setRequired(true)
Expand All @@ -183,12 +196,12 @@ public static InferenceServiceConfiguration get() {
.build()
);

configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(supportedTaskTypes));
configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(supportedTaskTypes));
configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES));
configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES));

return new InferenceServiceConfiguration.Builder().setService(NAME)
.setName(SERVICE_NAME)
.setTaskTypes(supportedTaskTypes)
.setTaskTypes(SUPPORTED_TASK_TYPES)
.setConfigurations(configurationMap)
.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceRequestManager;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceCompletionRequestManager;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceEmbeddingsRequestManager;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceResponseHandler;
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceChatCompletionResponseEntity;
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceElserResponseEntity;
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceEmbeddingsResponseEntity;

Expand All @@ -26,6 +29,9 @@
* Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the hugging face model type.
*/
public class HuggingFaceActionCreator implements HuggingFaceActionVisitor {

private static final String FAILED_TO_SEND_REQUEST_ERROR_MESSAGE =
"Failed to send Hugging Face %s request from inference entity id [%s]";
private final Sender sender;
private final ServiceComponents serviceComponents;

Expand All @@ -40,34 +46,38 @@ public ExecutableAction create(HuggingFaceEmbeddingsModel model) {
"hugging face text embeddings",
HuggingFaceEmbeddingsResponseEntity::fromResponse
);
var requestCreator = HuggingFaceRequestManager.of(
var requestCreator = HuggingFaceEmbeddingsRequestManager.of(
model,
responseHandler,
serviceComponents.truncator(),
serviceComponents.threadPool()
);
var errorMessage = format(
"Failed to send Hugging Face %s request from inference entity id [%s]",
"text embeddings",
model.getInferenceEntityId()
);
var errorMessage = format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, "text embeddings", model.getInferenceEntityId());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Same comment as above suggesting making this a function.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did the change described in my comment above.

return new SenderExecutableAction(sender, requestCreator, errorMessage);
}

@Override
public ExecutableAction create(HuggingFaceElserModel model) {
var responseHandler = new HuggingFaceResponseHandler("hugging face elser", HuggingFaceElserResponseEntity::fromResponse);
var requestCreator = HuggingFaceRequestManager.of(
var requestCreator = HuggingFaceEmbeddingsRequestManager.of(
model,
responseHandler,
serviceComponents.truncator(),
serviceComponents.threadPool()
);
var errorMessage = format(
"Failed to send Hugging Face %s request from inference entity id [%s]",
"ELSER",
model.getInferenceEntityId()
var errorMessage = format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, "ELSER", model.getInferenceEntityId());
return new SenderExecutableAction(sender, requestCreator, errorMessage);
}

@Override
public ExecutableAction create(HuggingFaceChatCompletionModel model) {
var responseHandler = new HuggingFaceResponseHandler(
"hugging face chat completion",
HuggingFaceChatCompletionResponseEntity::fromResponse
);

var requestCreator = HuggingFaceCompletionRequestManager.of(model, responseHandler, serviceComponents.threadPool());
var errorMessage = format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, "CHAT COMPLETION", model.getInferenceEntityId());
return new SenderExecutableAction(sender, requestCreator, errorMessage);
}
}
Loading