diff --git a/server/src/main/java/org/elasticsearch/inference/TaskType.java b/server/src/main/java/org/elasticsearch/inference/TaskType.java
index fcb8ea7213795..42a6b45bb2caf 100644
--- a/server/src/main/java/org/elasticsearch/inference/TaskType.java
+++ b/server/src/main/java/org/elasticsearch/inference/TaskType.java
@@ -29,7 +29,8 @@ public enum TaskType implements Writeable {
         public boolean isAnyOrSame(TaskType other) {
             return true;
         }
-    };
+    },
+    CHAT_COMPLETION;
 
     public static String NAME = "task_type";
diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java
index 83500c5604e67..53bea0660c58b 100644
--- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java
+++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java
@@ -243,7 +243,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
     @SuppressWarnings("unchecked")
     public void testGetServicesWithCompletionTaskType() throws IOException {
         List<Object> services = getServices(TaskType.COMPLETION);
-        assertThat(services.size(), equalTo(10));
+        assertThat(services.size(), equalTo(9));
 
         var providers = new ArrayList<>();
         for (int i = 0; i < services.size(); i++) {
@@ -259,7 +259,6 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
             "azureaistudio",
             "azureopenai",
             "cohere",
-            "elastic",
             "googleaistudio",
             "openai",
             "streaming_completion_test_service"
@@ -269,6 +268,32 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
         assertThat(providers, containsInAnyOrder(providerList.toArray()));
     }
 
+    @SuppressWarnings("unchecked")
+    public void testGetServicesWithChatCompletionTaskType() throws IOException {
+        List<Object> services = getServices(TaskType.CHAT_COMPLETION);
+        if ((ElasticInferenceServiceFeature.DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()
+            || ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled())) {
+            assertThat(services.size(), equalTo(2));
+        } else {
+            assertThat(services.size(), equalTo(1));
+        }
+
+        String[] providers = new String[services.size()];
+        for (int i = 0; i < services.size(); i++) {
+            Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
+            providers[i] = (String) serviceConfig.get("service");
+        }
+
+        var providerList = new ArrayList<>(List.of("openai"));
+
+        if ((ElasticInferenceServiceFeature.DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()
+            || ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled())) {
+            providerList.add(0, "elastic");
+        }
+
+        assertArrayEquals(providers, providerList.toArray());
+    }
+
     @SuppressWarnings("unchecked")
     public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
         List<Object> services = getServices(TaskType.SPARSE_EMBEDDING);
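Editorial note, not part of the patch: the new `CHAT_COMPLETION` constant surfaces in REST paths and in `_inference/_services` responses as the string `chat_completion`. A minimal sketch of that mapping, assuming `TaskType` keeps the usual Elasticsearch convention of lower-casing `name()` for its wire form:

```java
import java.util.Locale;

// Stand-in for org.elasticsearch.inference.TaskType; illustrative only.
enum TaskTypeSketch {
    TEXT_EMBEDDING, SPARSE_EMBEDDING, RERANK, COMPLETION, CHAT_COMPLETION, ANY;

    @Override
    public String toString() {
        return name().toLowerCase(Locale.ROOT); // CHAT_COMPLETION -> "chat_completion"
    }

    static TaskTypeSketch fromString(String name) {
        return valueOf(name.trim().toUpperCase(Locale.ROOT)); // "chat_completion" -> CHAT_COMPLETION
    }
}
```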
"}/_stream"; - static final String UNIFIED_INFERENCE_ID_PATH = "_inference/{" + TASK_TYPE_OR_INFERENCE_ID + "}/_unified"; + + public static final String UNIFIED_SUFFIX = "_unified"; + static final String UNIFIED_INFERENCE_ID_PATH = "_inference/{" + TASK_TYPE_OR_INFERENCE_ID + "}/" + UNIFIED_SUFFIX; static final String UNIFIED_TASK_TYPE_INFERENCE_ID_PATH = "_inference/{" + TASK_TYPE_OR_INFERENCE_ID + "}/{" + INFERENCE_ID - + "}/_unified"; + + "}/" + + UNIFIED_SUFFIX; private Paths() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index 208744b40ce9d..7c28df4cc0dc4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -73,7 +73,7 @@ public void infer( private static InferenceInputs createInput(Model model, List input, @Nullable String query, boolean stream) { return switch (model.getTaskType()) { - case COMPLETION -> new ChatCompletionInput(input, stream); + case COMPLETION, CHAT_COMPLETION -> new ChatCompletionInput(input, stream); case RERANK -> new QueryAndDocsInputs(query, input, stream); case TEXT_EMBEDDING, SPARSE_EMBEDDING -> new DocumentsOnlyInput(input, stream); default -> throw new ElasticsearchStatusException( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 7d05bac363fb1..1ddae3cc8df95 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -42,6 +42,7 @@ import static org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings.ENABLED; import static org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings.MAX_NUMBER_OF_ALLOCATIONS; import static org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings.MIN_NUMBER_OF_ALLOCATIONS; +import static org.elasticsearch.xpack.inference.rest.Paths.UNIFIED_SUFFIX; import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; public final class ServiceUtils { @@ -780,5 +781,24 @@ public static void throwUnsupportedUnifiedCompletionOperation(String serviceName throw new UnsupportedOperationException(Strings.format("The %s service does not support unified completion", serviceName)); } + public static String unsupportedTaskTypeForInference(Model model, EnumSet supportedTaskTypes) { + return Strings.format( + "Inference entity [%s] does not support task type [%s] for inference, the task type must be one of %s.", + model.getInferenceEntityId(), + model.getTaskType(), + supportedTaskTypes + ); + } + + public static String useChatCompletionUrlMessage(Model model) { + return org.elasticsearch.common.Strings.format( + "The task type for the inference entity is %s, please use the _inference/%s/%s/%s URL.", + model.getTaskType(), + model.getTaskType(), + model.getInferenceEntityId(), + UNIFIED_SUFFIX + ); + } + private ServiceUtils() {} } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java 
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java
index b87675b961c0d..0cbc360797e97 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java
@@ -27,6 +27,7 @@
 import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.tasks.Task;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
@@ -41,6 +42,7 @@
 import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.SenderService;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.ServiceUtils;
 import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
 import org.elasticsearch.xpack.inference.telemetry.TraceContext;
@@ -61,6 +63,7 @@
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.useChatCompletionUrlMessage;
 
 public class ElasticInferenceService extends SenderService {
 
@@ -69,8 +72,16 @@ public class ElasticInferenceService extends SenderService {
 
     private final ElasticInferenceServiceComponents elasticInferenceServiceComponents;
 
-    private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.COMPLETION);
+    // The task types exposed via the _inference/_services API
+    private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES_FOR_SERVICES_API = EnumSet.of(
+        TaskType.SPARSE_EMBEDDING,
+        TaskType.CHAT_COMPLETION
+    );
     private static final String SERVICE_NAME = "Elastic";
+    /**
+     * The task types that the {@link InferenceAction.Request} can accept.
+     */
+    private static final EnumSet<TaskType> SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING);
 
     public ElasticInferenceService(
         HttpRequestSender.Factory factory,
@@ -83,7 +94,7 @@ public ElasticInferenceService(
 
     @Override
     public Set<TaskType> supportedStreamingTasks() {
-        return COMPLETION_ONLY;
+        return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY);
     }
 
     @Override
@@ -129,6 +140,16 @@ protected void doInfer(
         TimeValue timeout,
         ActionListener<InferenceServiceResults> listener
     ) {
+        if (SUPPORTED_INFERENCE_ACTION_TASK_TYPES.contains(model.getTaskType()) == false) {
+            var responseString = ServiceUtils.unsupportedTaskTypeForInference(model, SUPPORTED_INFERENCE_ACTION_TASK_TYPES);
+
+            if (model.getTaskType() == TaskType.CHAT_COMPLETION) {
+                responseString = responseString + " " + useChatCompletionUrlMessage(model);
+            }
+            listener.onFailure(new ElasticsearchStatusException(responseString, RestStatus.BAD_REQUEST));
+            return;
+        }
+
         if (model instanceof ElasticInferenceServiceExecutableActionModel == false) {
             listener.onFailure(createInvalidModelException(model));
             return;
@@ -207,7 +228,7 @@ public InferenceServiceConfiguration getConfiguration() {
 
     @Override
     public EnumSet<TaskType> supportedTaskTypes() {
-        return supportedTaskTypes;
+        return SUPPORTED_TASK_TYPES_FOR_SERVICES_API;
     }
 
     private static ElasticInferenceServiceModel createModel(
@@ -375,7 +396,7 @@ public static InferenceServiceConfiguration get() {
 
             return new InferenceServiceConfiguration.Builder().setService(NAME)
                 .setName(SERVICE_NAME)
-                .setTaskTypes(supportedTaskTypes)
+                .setTaskTypes(SUPPORTED_TASK_TYPES_FOR_SERVICES_API)
                 .setConfigurations(configurationMap)
                 .build();
         }
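Both services gate `doInfer` with the same pattern, and the guard needs an early `return` after failing the listener so the checks that follow cannot complete it a second time (added above). A condensed, self-contained sketch of that control flow, with stand-in types replacing `Model`, `ActionListener`, and `ElasticsearchStatusException`:

```java
import java.util.EnumSet;
import java.util.Locale;
import java.util.function.Consumer;

// Sketch of the doInfer() task-type guard; types and messages are stand-ins,
// the control flow mirrors the diff above.
public class DoInferGuardSketch {
    enum TaskType { SPARSE_EMBEDDING, CHAT_COMPLETION, TEXT_EMBEDDING }

    static final EnumSet<TaskType> SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING);

    static void doInfer(TaskType taskType, String inferenceEntityId, Consumer<String> onFailure) {
        if (SUPPORTED_INFERENCE_ACTION_TASK_TYPES.contains(taskType) == false) {
            var responseString = "Inference entity [" + inferenceEntityId + "] does not support task type ["
                + taskType.name().toLowerCase(Locale.ROOT) + "] for inference, ...";
            if (taskType == TaskType.CHAT_COMPLETION) {
                // Point chat_completion callers at the dedicated _unified route.
                responseString += " ... please use the _inference/chat_completion/" + inferenceEntityId + "/_unified URL.";
            }
            onFailure.accept(responseString); // BAD_REQUEST in the real code
            return; // stop here so the listener is completed exactly once
        }
        // ... validate the model class and run the normal inference path ...
    }

    public static void main(String[] args) {
        doInfer(TaskType.CHAT_COMPLETION, "model_id", System.err::println);
    }
}
```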
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java
index ba9dea8ace8ee..3efd7c44c3e97 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java
@@ -27,6 +27,7 @@
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
 import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
 import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
 import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
@@ -63,6 +64,7 @@
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.useChatCompletionUrlMessage;
 import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.EMBEDDING_MAX_BATCH_SIZE;
 import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.ORGANIZATION;
 
@@ -70,7 +72,16 @@ public class OpenAiService extends SenderService {
     public static final String NAME = "openai";
 
     private static final String SERVICE_NAME = "OpenAI";
-    private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION);
+    // The task types exposed via the _inference/_services API
+    private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES_FOR_SERVICES_API = EnumSet.of(
+        TaskType.TEXT_EMBEDDING,
+        TaskType.COMPLETION,
+        TaskType.CHAT_COMPLETION
+    );
+    /**
+     * The task types that the {@link InferenceAction.Request} can accept.
+     */
+    private static final EnumSet<TaskType> SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION);
 
     public OpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
         super(factory, serviceComponents);
@@ -164,7 +175,7 @@ private static OpenAiModel createModel(
                 secretSettings,
                 context
             );
-            case COMPLETION -> new OpenAiChatCompletionModel(
+            case COMPLETION, CHAT_COMPLETION -> new OpenAiChatCompletionModel(
                 inferenceEntityId,
                 taskType,
                 NAME,
@@ -236,7 +247,7 @@ public InferenceServiceConfiguration getConfiguration() {
 
     @Override
     public EnumSet<TaskType> supportedTaskTypes() {
-        return supportedTaskTypes;
+        return SUPPORTED_TASK_TYPES_FOR_SERVICES_API;
     }
 
     @Override
@@ -248,6 +259,16 @@ public void doInfer(
         TimeValue timeout,
         ActionListener<InferenceServiceResults> listener
     ) {
+        if (SUPPORTED_INFERENCE_ACTION_TASK_TYPES.contains(model.getTaskType()) == false) {
+            var responseString = ServiceUtils.unsupportedTaskTypeForInference(model, SUPPORTED_INFERENCE_ACTION_TASK_TYPES);
+
+            if (model.getTaskType() == TaskType.CHAT_COMPLETION) {
+                responseString = responseString + " " + useChatCompletionUrlMessage(model);
+            }
+            listener.onFailure(new ElasticsearchStatusException(responseString, RestStatus.BAD_REQUEST));
+            return;
+        }
+
         if (model instanceof OpenAiModel == false) {
             listener.onFailure(createInvalidModelException(model));
             return;
@@ -356,7 +377,7 @@ public TransportVersion getMinimalSupportedVersion() {
 
     @Override
     public Set<TaskType> supportedStreamingTasks() {
-        return COMPLETION_ONLY;
+        return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION, TaskType.ANY);
     }
 
     /**
@@ -444,7 +465,7 @@ public static InferenceServiceConfiguration get() {
 
             return new InferenceServiceConfiguration.Builder().setService(NAME)
                 .setName(SERVICE_NAME)
-                .setTaskTypes(supportedTaskTypes)
+                .setTaskTypes(SUPPORTED_TASK_TYPES_FOR_SERVICES_API)
                 .setConfigurations(configurationMap)
                 .build();
         }
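Only the REST layer distinguishes the two task types: `completion` stays on the existing `_inference` route while `chat_completion` is served by the `_unified` route, and both reuse the same chat-completion model and input classes. A sketch of that shared routing (cf. `SenderService.createInput` and `OpenAiService.createModel` above; the record types below are stand-ins, not the Elasticsearch classes):

```java
import java.util.List;

// Illustrative routing sketch: COMPLETION and CHAT_COMPLETION share one input type.
public class InputRoutingSketch {
    enum TaskType { TEXT_EMBEDDING, SPARSE_EMBEDDING, RERANK, COMPLETION, CHAT_COMPLETION }

    interface InferenceInputs {}
    record ChatCompletionInput(List<String> input, boolean stream) implements InferenceInputs {}
    record QueryAndDocsInputs(String query, List<String> input, boolean stream) implements InferenceInputs {}
    record DocumentsOnlyInput(List<String> input, boolean stream) implements InferenceInputs {}

    static InferenceInputs createInput(TaskType taskType, List<String> input, String query, boolean stream) {
        return switch (taskType) {
            case COMPLETION, CHAT_COMPLETION -> new ChatCompletionInput(input, stream);
            case RERANK -> new QueryAndDocsInputs(query, input, stream);
            case TEXT_EMBEDDING, SPARSE_EMBEDDING -> new DocumentsOnlyInput(input, stream);
        };
    }
}
```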
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java
index e02ac7b8853ad..5bb70db4a1e8d 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java
@@ -163,19 +163,24 @@ public List getInferenceServiceFactories() {
         }
     }
 
-    public static Model getInvalidModel(String inferenceEntityId, String serviceName) {
+    public static Model getInvalidModel(String inferenceEntityId, String serviceName, TaskType taskType) {
         var mockConfigs = mock(ModelConfigurations.class);
         when(mockConfigs.getInferenceEntityId()).thenReturn(inferenceEntityId);
         when(mockConfigs.getService()).thenReturn(serviceName);
-        when(mockConfigs.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
+        when(mockConfigs.getTaskType()).thenReturn(taskType);
 
         var mockModel = mock(Model.class);
+        when(mockModel.getInferenceEntityId()).thenReturn(inferenceEntityId);
         when(mockModel.getConfigurations()).thenReturn(mockConfigs);
-        when(mockModel.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
+        when(mockModel.getTaskType()).thenReturn(taskType);
 
         return mockModel;
     }
 
+    public static Model getInvalidModel(String inferenceEntityId, String serviceName) {
+        return getInvalidModel(inferenceEntityId, serviceName, TaskType.TEXT_EMBEDDING);
+    }
+
     public static SimilarityMeasure randomSimilarityMeasure() {
         return randomFrom(SimilarityMeasure.values());
     }
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java
index c9b46d130e5ee..1458c85a8bc85 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java
@@ -319,7 +319,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAValidModel() throws IOException
         var factory = mock(HttpRequestSender.Factory.class);
         when(factory.createSender()).thenReturn(sender);
 
-        var mockModel = getInvalidModel("model_id", "service_name");
+        var mockModel = getInvalidModel("model_id", "service_name", TaskType.SPARSE_EMBEDDING);
 
         try (
             var service = new ElasticInferenceService(
@@ -355,6 +355,98 @@ public void testInfer_ThrowsErrorWhenModelIsNotAValidModel() throws IOException
         verifyNoMoreInteractions(sender);
     }
 
+    public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid() throws IOException {
+        var sender = mock(Sender.class);
+
+        var factory = mock(HttpRequestSender.Factory.class);
+        when(factory.createSender()).thenReturn(sender);
+
+        var mockModel = getInvalidModel("model_id", "service_name", TaskType.TEXT_EMBEDDING);
+
+        try (
+            var service = new ElasticInferenceService(
+                factory,
+                createWithEmptySettings(threadPool),
+                new ElasticInferenceServiceComponents(null)
+            )
+        ) {
+            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+            service.infer(
+                mockModel,
+                null,
+                List.of(""),
+                false,
+                new HashMap<>(),
+                InputType.INGEST,
+                InferenceAction.Request.DEFAULT_TIMEOUT,
+                listener
+            );
+
+            var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
+            MatcherAssert.assertThat(
+                thrownException.getMessage(),
+                is(
+                    "Inference entity [model_id] does not support task type [text_embedding] "
+                        + "for inference, the task type must be one of [sparse_embedding]."
+                )
+            );
+
+            verify(factory, times(1)).createSender();
+            verify(sender, times(1)).start();
+        }
+
+        verify(sender, times(1)).close();
+        verifyNoMoreInteractions(factory);
+        verifyNoMoreInteractions(sender);
+    }
+
+    public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws IOException {
+        var sender = mock(Sender.class);
+
+        var factory = mock(HttpRequestSender.Factory.class);
+        when(factory.createSender()).thenReturn(sender);
+
+        var mockModel = getInvalidModel("model_id", "service_name", TaskType.CHAT_COMPLETION);
+
+        try (
+            var service = new ElasticInferenceService(
+                factory,
+                createWithEmptySettings(threadPool),
+                new ElasticInferenceServiceComponents(null)
+            )
+        ) {
+            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+            service.infer(
+                mockModel,
+                null,
+                List.of(""),
+                false,
+                new HashMap<>(),
+                InputType.INGEST,
+                InferenceAction.Request.DEFAULT_TIMEOUT,
+                listener
+            );
+
+            var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
+            MatcherAssert.assertThat(
+                thrownException.getMessage(),
+                is(
+                    "Inference entity [model_id] does not support task type [chat_completion] "
+                        + "for inference, the task type must be one of [sparse_embedding]. "
" + + "The task type for the inference entity is chat_completion, " + + "please use the _inference/chat_completion/model_id/_unified URL." + ) + ); + + verify(factory, times(1)).createSender(); + verify(sender, times(1)).start(); + } + + verify(sender, times(1)).close(); + verifyNoMoreInteractions(factory); + verifyNoMoreInteractions(sender); + } + public void testInfer_SendsEmbeddingsRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); var eisGatewayUrl = getUrl(webServer); @@ -481,7 +573,7 @@ public void testGetConfiguration() throws Exception { { "service": "elastic", "name": "Elastic", - "task_types": ["sparse_embedding" , "completion"], + "task_types": ["sparse_embedding", "chat_completion"], "configurations": { "rate_limit.requests_per_minute": { "description": "Minimize the number of rate limit errors.", diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index 9313969cbc15a..e935e2e188e37 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -864,6 +864,86 @@ public void testInfer_ThrowsErrorWhenModelIsNotOpenAiModel() throws IOException verifyNoMoreInteractions(sender); } + public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid() throws IOException { + var sender = mock(Sender.class); + + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + + var mockModel = getInvalidModel("model_id", "service_name", TaskType.SPARSE_EMBEDDING); + + try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool))) { + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + mockModel, + null, + List.of(""), + false, + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + thrownException.getMessage(), + is( + "Inference entity [model_id] does not support task type [sparse_embedding] " + + "for inference, the task type must be one of [text_embedding, completion]." 
+                )
+            );
+
+            verify(factory, times(1)).createSender();
+            verify(sender, times(1)).start();
+        }
+
+        verify(sender, times(1)).close();
+        verifyNoMoreInteractions(factory);
+        verifyNoMoreInteractions(sender);
+    }
+
+    public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws IOException {
+        var sender = mock(Sender.class);
+
+        var factory = mock(HttpRequestSender.Factory.class);
+        when(factory.createSender()).thenReturn(sender);
+
+        var mockModel = getInvalidModel("model_id", "service_name", TaskType.CHAT_COMPLETION);
+
+        try (var service = new OpenAiService(factory, createWithEmptySettings(threadPool))) {
+            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+            service.infer(
+                mockModel,
+                null,
+                List.of(""),
+                false,
+                new HashMap<>(),
+                InputType.INGEST,
+                InferenceAction.Request.DEFAULT_TIMEOUT,
+                listener
+            );
+
+            var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
+            assertThat(
+                thrownException.getMessage(),
+                is(
+                    "Inference entity [model_id] does not support task type [chat_completion] "
+                        + "for inference, the task type must be one of [text_embedding, completion]. "
+                        + "The task type for the inference entity is chat_completion, "
+                        + "please use the _inference/chat_completion/model_id/_unified URL."
+                )
+            );
+
+            verify(factory, times(1)).createSender();
+            verify(sender, times(1)).start();
+        }
+
+        verify(sender, times(1)).close();
+        verifyNoMoreInteractions(factory);
+        verifyNoMoreInteractions(sender);
+    }
+
     public void testInfer_SendsRequest() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
@@ -1660,7 +1740,7 @@ public void testGetConfiguration() throws Exception {
                     {
                         "service": "openai",
                         "name": "OpenAI",
-                        "task_types": ["text_embedding", "completion"],
+                        "task_types": ["text_embedding", "completion", "chat_completion"],
                         "configurations": {
                             "api_key": {
                                 "description": "The OpenAI API authentication key. For more details about generating OpenAI API keys, refer to the https://platform.openai.com/account/api-keys.",