@@ -29,7 +29,8 @@ public enum TaskType implements Writeable {
         public boolean isAnyOrSame(TaskType other) {
             return true;
         }
-    };
+    },
+    CHAT_COMPLETION;
 
     public static String NAME = "task_type";
 
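The new constant slots in beside the existing task types and, like them, is referred to elsewhere by its lower-case wire name. A minimal sketch of resolving it, assuming TaskType.fromString maps lower-case names the way it does for the existing constants:

    import org.elasticsearch.inference.TaskType;

    class TaskTypeSketch {
        public static void main(String[] args) {
            // Sketch: resolving the new task type from its wire name; the
            // lower-case mapping is an assumption based on the existing constants.
            TaskType taskType = TaskType.fromString("chat_completion");
            System.out.println(taskType == TaskType.CHAT_COMPLETION); // true
            // ANY matches every task type, per isAnyOrSame above.
            System.out.println(TaskType.ANY.isAnyOrSame(taskType));   // true
        }
    }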

@@ -243,7 +243,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
     @SuppressWarnings("unchecked")
     public void testGetServicesWithCompletionTaskType() throws IOException {
         List<Object> services = getServices(TaskType.COMPLETION);
-        assertThat(services.size(), equalTo(10));
+        assertThat(services.size(), equalTo(9));
 
         var providers = new ArrayList<String>();
         for (int i = 0; i < services.size(); i++) {
@@ -259,7 +259,6 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
             "azureaistudio",
             "azureopenai",
             "cohere",
-            "elastic",
             "googleaistudio",
             "openai",
             "streaming_completion_test_service"
@@ -269,6 +268,32 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
         assertThat(providers, containsInAnyOrder(providerList.toArray()));
     }
 
+    @SuppressWarnings("unchecked")
+    public void testGetServicesWithChatCompletionTaskType() throws IOException {
+        List<Object> services = getServices(TaskType.CHAT_COMPLETION);
+        if ((ElasticInferenceServiceFeature.DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()
+            || ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled())) {
+            assertThat(services.size(), equalTo(2));
+        } else {
+            assertThat(services.size(), equalTo(1));
+        }
+
+        String[] providers = new String[services.size()];
+        for (int i = 0; i < services.size(); i++) {
+            Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
+            providers[i] = (String) serviceConfig.get("service");
+        }
+
+        var providerList = new ArrayList<>(List.of("openai"));
+
+        if ((ElasticInferenceServiceFeature.DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()
+            || ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled())) {
+            providerList.add(0, "elastic");
+        }
+
+        assertArrayEquals(providers, providerList.toArray());
+    }
+
     @SuppressWarnings("unchecked")
     public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
         List<Object> services = getServices(TaskType.SPARSE_EMBEDDING);

@@ -30,12 +30,15 @@ public final class Paths {
         + "}/{"
         + INFERENCE_ID
         + "}/_stream";
-    static final String UNIFIED_INFERENCE_ID_PATH = "_inference/{" + TASK_TYPE_OR_INFERENCE_ID + "}/_unified";
+
+    public static final String UNIFIED_SUFFIX = "_unified";
+    static final String UNIFIED_INFERENCE_ID_PATH = "_inference/{" + TASK_TYPE_OR_INFERENCE_ID + "}/" + UNIFIED_SUFFIX;
     static final String UNIFIED_TASK_TYPE_INFERENCE_ID_PATH = "_inference/{"
         + TASK_TYPE_OR_INFERENCE_ID
         + "}/{"
         + INFERENCE_ID
-        + "}/_unified";
+        + "}/"
+        + UNIFIED_SUFFIX;
 
     private Paths() {
 
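Factoring the suffix into UNIFIED_SUFFIX keeps the two route templates in sync. Resolved against a concrete endpoint, they map to URLs like the following (the endpoint id "my-chat-endpoint" is a made-up example, not a value from this PR):

    class UnifiedPathsExample {
        // Illustration only; "my-chat-endpoint" is a hypothetical endpoint id.
        // UNIFIED_INFERENCE_ID_PATH, with the single placeholder holding the id:
        static final String BY_ID = "_inference/my-chat-endpoint/_unified";
        // UNIFIED_TASK_TYPE_INFERENCE_ID_PATH, with task type then id:
        static final String BY_TASK_TYPE_AND_ID = "_inference/chat_completion/my-chat-endpoint/_unified";
    }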

@@ -73,7 +73,7 @@ public void infer(
 
     private static InferenceInputs createInput(Model model, List<String> input, @Nullable String query, boolean stream) {
         return switch (model.getTaskType()) {
-            case COMPLETION -> new ChatCompletionInput(input, stream);
+            case COMPLETION, CHAT_COMPLETION -> new ChatCompletionInput(input, stream);
             case RERANK -> new QueryAndDocsInputs(query, input, stream);
             case TEXT_EMBEDDING, SPARSE_EMBEDDING -> new DocumentsOnlyInput(input, stream);
             default -> throw new ElasticsearchStatusException(

@@ -42,6 +42,7 @@
 import static org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings.ENABLED;
 import static org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings.MAX_NUMBER_OF_ALLOCATIONS;
 import static org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings.MIN_NUMBER_OF_ALLOCATIONS;
+import static org.elasticsearch.xpack.inference.rest.Paths.UNIFIED_SUFFIX;
 import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY;
 
 public final class ServiceUtils {
@@ -780,5 +781,24 @@ public static void throwUnsupportedUnifiedCompletionOperation(String serviceName
         throw new UnsupportedOperationException(Strings.format("The %s service does not support unified completion", serviceName));
     }
 
+    public static String unsupportedTaskTypeForInference(Model model, EnumSet<TaskType> supportedTaskTypes) {
+        return Strings.format(
+            "Inference entity [%s] does not support task type [%s] for inference, the task type must be one of %s.",
+            model.getInferenceEntityId(),
+            model.getTaskType(),
+            supportedTaskTypes
+        );
+    }
+
+    public static String useChatCompletionUrlMessage(Model model) {
+        return org.elasticsearch.common.Strings.format(
+            "The task type for the inference entity is %s, please use the _inference/%s/%s/%s URL.",
+            model.getTaskType(),
+            model.getTaskType(),
+            model.getInferenceEntityId(),
+            UNIFIED_SUFFIX
+        );
+    }
+
     private ServiceUtils() {}
 }
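Together, the two helpers give every service a uniform rejection message plus a pointer at the correct URL. A sketch of the strings they produce for a hypothetical chat_completion endpoint with id "my-chat-endpoint", assuming task types render as their lower-case names:

    class ServiceUtilsMessagesSketch {
        public static void main(String[] args) {
            // Approximation of unsupportedTaskTypeForInference(model, EnumSet.of(TaskType.SPARSE_EMBEDDING))
            // for a hypothetical endpoint "my-chat-endpoint" of task type chat_completion:
            String unsupported = "Inference entity [my-chat-endpoint] does not support task type "
                + "[chat_completion] for inference, the task type must be one of [sparse_embedding].";
            // Approximation of useChatCompletionUrlMessage(model) for the same endpoint:
            String useUnified = "The task type for the inference entity is chat_completion, please use the "
                + "_inference/chat_completion/my-chat-endpoint/_unified URL.";
            System.out.println(unsupported + " " + useUnified);
        }
    }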

@@ -27,6 +27,7 @@
 import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.tasks.Task;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
@@ -41,6 +42,7 @@
 import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.SenderService;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.ServiceUtils;
 import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
 import org.elasticsearch.xpack.inference.telemetry.TraceContext;
@@ -61,6 +63,7 @@
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.useChatCompletionUrlMessage;
 
 public class ElasticInferenceService extends SenderService {
 
@@ -69,8 +72,16 @@ public class ElasticInferenceService extends SenderService {
 
     private final ElasticInferenceServiceComponents elasticInferenceServiceComponents;
 
-    private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.COMPLETION);
+    // The task types exposed via the _inference/_services API
+    private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES_FOR_SERVICES_API = EnumSet.of(
+        TaskType.SPARSE_EMBEDDING,
+        TaskType.CHAT_COMPLETION
+    );
     private static final String SERVICE_NAME = "Elastic";
+    /**
+     * The task types that the {@link InferenceAction.Request} can accept.
+     */
+    private static final EnumSet<TaskType> SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING);
 
     public ElasticInferenceService(
         HttpRequestSender.Factory factory,
@@ -83,7 +94,7 @@ public ElasticInferenceService(
 
     @Override
     public Set<TaskType> supportedStreamingTasks() {
-        return COMPLETION_ONLY;
+        return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY);
     }
 
     @Override
@@ -129,6 +140,15 @@ protected void doInfer(
         TimeValue timeout,
         ActionListener<InferenceServiceResults> listener
     ) {
+        if (SUPPORTED_INFERENCE_ACTION_TASK_TYPES.contains(model.getTaskType()) == false) {
+            var responseString = ServiceUtils.unsupportedTaskTypeForInference(model, SUPPORTED_INFERENCE_ACTION_TASK_TYPES);
+
+            if (model.getTaskType() == TaskType.CHAT_COMPLETION) {
+                responseString = responseString + " " + useChatCompletionUrlMessage(model);
+            }
+            listener.onFailure(new ElasticsearchStatusException(responseString, RestStatus.BAD_REQUEST));
+        }
+
         if (model instanceof ElasticInferenceServiceExecutableActionModel == false) {
             listener.onFailure(createInvalidModelException(model));
             return;
@@ -207,7 +227,7 @@ public InferenceServiceConfiguration getConfiguration() {
 
     @Override
     public EnumSet<TaskType> supportedTaskTypes() {
-        return supportedTaskTypes;
+        return SUPPORTED_TASK_TYPES_FOR_SERVICES_API;
     }
 
     private static ElasticInferenceServiceModel createModel(
@@ -375,7 +395,7 @@ public static InferenceServiceConfiguration get() {
 
         return new InferenceServiceConfiguration.Builder().setService(NAME)
             .setName(SERVICE_NAME)
-            .setTaskTypes(supportedTaskTypes)
+            .setTaskTypes(SUPPORTED_TASK_TYPES_FOR_SERVICES_API)
             .setConfigurations(configurationMap)
             .build();
     }
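With the doInfer guard in place, a plain inference call against an Elastic Inference Service chat_completion endpoint is rejected with 400 Bad Request and the combined message above; the unified route is the supported path. A sketch of the call a client would make instead, via the low-level REST client ("my-chat-endpoint" is a hypothetical id, and the body shape is an assumption based on the OpenAI-style schema the unified API accepts):

    import java.io.IOException;

    import org.elasticsearch.client.Request;
    import org.elasticsearch.client.Response;
    import org.elasticsearch.client.RestClient;

    class UnifiedCallSketch {
        // Sketch only; endpoint id and request body are illustrative.
        static Response callUnified(RestClient client) throws IOException {
            Request request = new Request("POST", "/_inference/chat_completion/my-chat-endpoint/_unified");
            request.setJsonEntity("""
                { "messages": [ { "role": "user", "content": "Hello" } ] }
                """);
            return client.performRequest(request);
        }
    }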

@@ -27,6 +27,7 @@
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
 import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
 import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
 import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
@@ -63,14 +64,24 @@
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.useChatCompletionUrlMessage;
 import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.EMBEDDING_MAX_BATCH_SIZE;
 import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.ORGANIZATION;
 
 public class OpenAiService extends SenderService {
     public static final String NAME = "openai";
 
     private static final String SERVICE_NAME = "OpenAI";
-    private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION);
+    // The task types exposed via the _inference/_services API
+    private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES_FOR_SERVICES_API = EnumSet.of(
+        TaskType.TEXT_EMBEDDING,
+        TaskType.COMPLETION,
+        TaskType.CHAT_COMPLETION
+    );
+    /**
+     * The task types that the {@link InferenceAction.Request} can accept.
+     */
+    private static final EnumSet<TaskType> SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION);
 
     public OpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
         super(factory, serviceComponents);
@@ -164,7 +175,7 @@ private static OpenAiModel createModel(
                 secretSettings,
                 context
             );
-            case COMPLETION -> new OpenAiChatCompletionModel(
+            case COMPLETION, CHAT_COMPLETION -> new OpenAiChatCompletionModel(
                 inferenceEntityId,
                 taskType,
                 NAME,
@@ -236,7 +247,7 @@ public InferenceServiceConfiguration getConfiguration() {
 
     @Override
     public EnumSet<TaskType> supportedTaskTypes() {
-        return supportedTaskTypes;
+        return SUPPORTED_TASK_TYPES_FOR_SERVICES_API;
     }
 
     @Override
@@ -248,6 +259,15 @@ public void doInfer(
         TimeValue timeout,
         ActionListener<InferenceServiceResults> listener
     ) {
+        if (SUPPORTED_INFERENCE_ACTION_TASK_TYPES.contains(model.getTaskType()) == false) {
+            var responseString = ServiceUtils.unsupportedTaskTypeForInference(model, SUPPORTED_INFERENCE_ACTION_TASK_TYPES);
+
+            if (model.getTaskType() == TaskType.CHAT_COMPLETION) {
+                responseString = responseString + " " + useChatCompletionUrlMessage(model);
+            }
+            listener.onFailure(new ElasticsearchStatusException(responseString, RestStatus.BAD_REQUEST));
+        }
+
         if (model instanceof OpenAiModel == false) {
             listener.onFailure(createInvalidModelException(model));
             return;
@@ -356,7 +376,7 @@ public TransportVersion getMinimalSupportedVersion() {
 
     @Override
     public Set<TaskType> supportedStreamingTasks() {
-        return COMPLETION_ONLY;
+        return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION, TaskType.ANY);
     }
 
     /**
@@ -444,7 +464,7 @@ public static InferenceServiceConfiguration get() {
 
         return new InferenceServiceConfiguration.Builder().setService(NAME)
            .setName(SERVICE_NAME)
-            .setTaskTypes(supportedTaskTypes)
+            .setTaskTypes(SUPPORTED_TASK_TYPES_FOR_SERVICES_API)
            .setConfigurations(configurationMap)
            .build();
    }
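For OpenAI, completion and chat_completion both resolve to OpenAiChatCompletionModel, so one model class backs both task types; only the services-API list and the inference-action guard distinguish them. A sketch of registering a chat_completion OpenAI endpoint (the endpoint id, model id, and API key are placeholders, not values from this PR):

    import java.io.IOException;

    import org.elasticsearch.client.Request;
    import org.elasticsearch.client.RestClient;

    class CreateOpenAiChatEndpointSketch {
        // Sketch only; ids and the key below are made up.
        static void createEndpoint(RestClient client) throws IOException {
            Request request = new Request("PUT", "/_inference/chat_completion/my-openai-chat");
            request.setJsonEntity("""
                {
                  "service": "openai",
                  "service_settings": { "api_key": "<api-key>", "model_id": "gpt-4o" }
                }
                """);
            client.performRequest(request);
        }
    }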

@@ -163,19 +163,24 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
         }
     }
 
-    public static Model getInvalidModel(String inferenceEntityId, String serviceName) {
+    public static Model getInvalidModel(String inferenceEntityId, String serviceName, TaskType taskType) {
         var mockConfigs = mock(ModelConfigurations.class);
         when(mockConfigs.getInferenceEntityId()).thenReturn(inferenceEntityId);
         when(mockConfigs.getService()).thenReturn(serviceName);
-        when(mockConfigs.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
+        when(mockConfigs.getTaskType()).thenReturn(taskType);
 
         var mockModel = mock(Model.class);
         when(mockModel.getInferenceEntityId()).thenReturn(inferenceEntityId);
         when(mockModel.getConfigurations()).thenReturn(mockConfigs);
-        when(mockModel.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
+        when(mockModel.getTaskType()).thenReturn(taskType);
 
         return mockModel;
     }
 
+    public static Model getInvalidModel(String inferenceEntityId, String serviceName) {
+        return getInvalidModel(inferenceEntityId, serviceName, TaskType.TEXT_EMBEDDING);
+    }
+
     public static SimilarityMeasure randomSimilarityMeasure() {
         return randomFrom(SimilarityMeasure.values());
     }
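The new three-argument overload lets service tests pin the mock model's task type, while the original two-argument form keeps its TEXT_EMBEDDING default for existing callers. A usage sketch (the ids are placeholders):

    // Existing callers keep the TEXT_EMBEDDING default:
    var embeddingModel = getInvalidModel("model-id", "service-name");
    // New chat-completion tests pin the task type explicitly:
    var chatModel = getInvalidModel("model-id", "service-name", TaskType.CHAT_COMPLETION);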