Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/114457.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 114457
summary: "[Inference API] Introduce Update API to change some aspects of existing\
\ inference endpoints"
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Map;

/**
* This class defines an empty secret settings object. This is useful for services that do not have any secret settings.
Expand Down Expand Up @@ -48,4 +49,9 @@ public TransportVersion getMinimalSupportedVersion() {

@Override
public void writeTo(StreamOutput out) throws IOException {}

@Override
public SecretSettings newSecretSettings(Map<String, Object> newSecrets) {
return INSTANCE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Map;

/**
* This class defines an empty task settings object. This is useful for services that do not have any task settings.
Expand Down Expand Up @@ -53,4 +54,9 @@ public TransportVersion getMinimalSupportedVersion() {

@Override
public void writeTo(StreamOutput out) throws IOException {}

@Override
public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
return INSTANCE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
import org.elasticsearch.xcontent.ToXContentObject;

import java.util.Map;

public interface SecretSettings extends ToXContentObject, VersionedNamedWriteable {

SecretSettings newSecretSettings(Map<String, Object> newSecrets);
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
import org.elasticsearch.xcontent.ToXContentObject;

import java.util.Map;

public interface TaskSettings extends ToXContentObject, VersionedNamedWriteable {

boolean isEmpty();

TaskSettings updatedTaskSettings(Map<String, Object> newSettings);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,278 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.inference.action;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.master.AcknowledgedRequest;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.MlStrings;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.inference.ModelConfigurations.SERVICE_SETTINGS;
import static org.elasticsearch.inference.ModelConfigurations.TASK_SETTINGS;

public class UpdateInferenceModelAction extends ActionType<UpdateInferenceModelAction.Response> {

public static final UpdateInferenceModelAction INSTANCE = new UpdateInferenceModelAction();
public static final String NAME = "cluster:admin/xpack/inference/update";

public UpdateInferenceModelAction() {
super(NAME);
}

public record Settings(
@Nullable Map<String, Object> serviceSettings,
@Nullable Map<String, Object> taskSettings,
@Nullable TaskType taskType
) {}

public static class Request extends AcknowledgedRequest<Request> {

private final String inferenceEntityId;
private final BytesReference content;
private final XContentType contentType;
private final TaskType taskType;
private Settings settings;

public Request(String inferenceEntityId, BytesReference content, XContentType contentType, TaskType taskType, TimeValue timeout) {
super(timeout, DEFAULT_ACK_TIMEOUT);
this.inferenceEntityId = inferenceEntityId;
this.content = content;
this.contentType = contentType;
this.taskType = taskType;
}

public Request(StreamInput in) throws IOException {
super(in);
this.inferenceEntityId = in.readString();
this.content = in.readBytesReference();
this.taskType = TaskType.fromStream(in);
this.contentType = in.readEnum(XContentType.class);
}

public String getInferenceEntityId() {
return inferenceEntityId;
}

public TaskType getTaskType() {
return taskType;
}

/**
* The body of the request.
* For in-cluster models, this is expected to contain some of the following:
* "number_of_allocations": `an integer`
*
* For third-party services, this is expected to contain:
* "service_settings": {
* "api_key": `a string` // service settings can only contain an api key
* }
* "task_settings": { a map of settings }
*
*/
public BytesReference getContent() {
return content;
}

/**
* The body of the request as a map.
* The map is validated such that only allowed fields are present.
* If any fields in the body are not on the allow list, this function will throw an exception.
*/
public Settings getContentAsSettings() {
if (settings == null) { // settings is deterministic on content, so we only need to compute it once
Map<String, Object> unvalidatedMap = XContentHelper.convertToMap(content, false, contentType).v2();
Map<String, Object> serviceSettings = new HashMap<>();
Map<String, Object> taskSettings = new HashMap<>();
TaskType taskType = null;

if (unvalidatedMap.isEmpty()) {
throw new ElasticsearchStatusException("Request body is empty", RestStatus.BAD_REQUEST);
}

if (unvalidatedMap.containsKey("task_type")) {
if (unvalidatedMap.get("task_type") instanceof String taskTypeString) {
taskType = TaskType.fromStringOrStatusException(taskTypeString);
} else {
throw new ElasticsearchStatusException(
"Failed to parse [task_type] in update request [{}]",
RestStatus.INTERNAL_SERVER_ERROR,
unvalidatedMap.toString()
);
}
unvalidatedMap.remove("task_type");
}

if (unvalidatedMap.containsKey(SERVICE_SETTINGS)) {
if (unvalidatedMap.get(SERVICE_SETTINGS) instanceof Map<?, ?> tempMap) {
for (Map.Entry<?, ?> entry : (tempMap).entrySet()) {
if (entry.getKey() instanceof String key) {
serviceSettings.put(key, entry.getValue());
} else {
throw new ElasticsearchStatusException(
"Failed to parse update request [{}]",
RestStatus.INTERNAL_SERVER_ERROR,
unvalidatedMap.toString()
);
}
}
unvalidatedMap.remove(SERVICE_SETTINGS);
} else {
throw new ElasticsearchStatusException(
"Unable to parse service settings in the request [{}]",
RestStatus.BAD_REQUEST,
unvalidatedMap.toString()
);
}
}

if (unvalidatedMap.containsKey(TASK_SETTINGS)) {
if (unvalidatedMap.get(TASK_SETTINGS) instanceof Map<?, ?> tempMap) {
for (Map.Entry<?, ?> entry : (tempMap).entrySet()) {
if (entry.getKey() instanceof String key) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to change this line and the similar line above to only use instanceOf on the key

taskSettings.put(key, entry.getValue());
} else {
throw new ElasticsearchStatusException(
"Failed to parse update request [{}]",
RestStatus.INTERNAL_SERVER_ERROR,
unvalidatedMap.toString()
);
}
}
unvalidatedMap.remove(TASK_SETTINGS);
} else {
throw new ElasticsearchStatusException(
"Unable to parse task settings in the request [{}]",
RestStatus.BAD_REQUEST,
unvalidatedMap.toString()
);
}
}

if (unvalidatedMap.isEmpty() == false) {
throw new ElasticsearchStatusException(
"Request contained fields which cannot be updated, remove these fields and try again [{}]",
RestStatus.BAD_REQUEST,
unvalidatedMap.toString()
);
}

this.settings = new Settings(
serviceSettings.isEmpty() == false ? Collections.unmodifiableMap(serviceSettings) : null,
taskSettings.isEmpty() == false ? Collections.unmodifiableMap(taskSettings) : null,
taskType
);
}
return this.settings;
}

public XContentType getContentType() {
return contentType;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(inferenceEntityId);
taskType.writeTo(out);
out.writeBytesReference(content);
XContentHelper.writeTo(out, contentType);
}

@Override
public ActionRequestValidationException validate() {
ActionRequestValidationException validationException = new ActionRequestValidationException();
if (MlStrings.isValidId(this.inferenceEntityId) == false) {
validationException.addValidationError(Messages.getMessage(Messages.INVALID_ID, "inference_id", this.inferenceEntityId));
}

if (validationException.validationErrors().isEmpty() == false) {
return validationException;
} else {
return null;
}
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Request request = (Request) o;
return Objects.equals(inferenceEntityId, request.inferenceEntityId)
&& Objects.equals(content, request.content)
&& contentType == request.contentType
&& taskType == request.taskType;
}

@Override
public int hashCode() {
return Objects.hash(inferenceEntityId, content, contentType, taskType);
}
}

public static class Response extends ActionResponse implements ToXContentObject {

private final ModelConfigurations model;

public Response(ModelConfigurations model) {
this.model = model;
}

public Response(StreamInput in) throws IOException {
super(in);
model = new ModelConfigurations(in);
}

public ModelConfigurations getModel() {
return model;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
model.writeTo(out);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
return model.toFilteredXContent(builder, params);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Response response = (Response) o;
return Objects.equals(model, response.model);
}

@Override
public int hashCode() {
return Objects.hash(model);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,9 @@ public final class Messages {
public static final String FIELD_CANNOT_BE_NULL = "Field [{0}] cannot be null";
public static final String MODEL_ID_MATCHES_EXISTING_MODEL_IDS_BUT_MUST_NOT =
"Model IDs must be unique. Requested model ID [{}] matches existing model IDs but must not.";
public static final String MODEL_ID_DOES_NOT_MATCH_EXISTING_MODEL_IDS_BUT_MUST_FOR_IN_CLUSTER_SERVICE =
"Requested model ID [{}] does not have a matching trained model and thus cannot be updated.";
public static final String INFERENCE_ENTITY_NON_EXISTANT_NO_UPDATE = "The inference endpoint [{}] does not exist and cannot be updated";

private Messages() {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ public static ElasticsearchStatusException badRequestException(String msg, Objec
return new ElasticsearchStatusException(msg, RestStatus.BAD_REQUEST, args);
}

public static ElasticsearchStatusException entityNotFoundException(String msg, Object... args) {
return new ElasticsearchStatusException(msg, RestStatus.NOT_FOUND, args);
}

public static ElasticsearchStatusException taskOperationFailureToStatusException(TaskOperationFailure failure) {
return new ElasticsearchStatusException(failure.getCause().getMessage(), failure.getStatus(), failure.getCause());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,21 @@ static String mockSparseServiceModelConfig(@Nullable TaskType taskTypeInBody) {
""", taskType);
}

static String updateConfig(@Nullable TaskType taskTypeInBody, String apiKey, int temperature) {
var taskType = taskTypeInBody == null ? "" : "\"task_type\": \"" + taskTypeInBody + "\",";
return Strings.format("""
{
%s
"service_settings": {
"api_key": "%s"
},
"task_settings": {
"temperature": %d
}
}
""", taskType, apiKey, temperature);
}

static String mockCompletionServiceModelConfig(@Nullable TaskType taskTypeInBody) {
var taskType = taskTypeInBody == null ? "" : "\"task_type\": \"" + taskTypeInBody + "\",";
return Strings.format("""
Expand Down Expand Up @@ -196,6 +211,11 @@ protected Map<String, Object> putModel(String modelId, String modelConfig, TaskT
return putRequest(endpoint, modelConfig);
}

protected Map<String, Object> updateEndpoint(String inferenceID, String modelConfig, TaskType taskType) throws IOException {
String endpoint = Strings.format("_inference/%s/%s/_update", taskType, inferenceID);
return putRequest(endpoint, modelConfig);
}

protected Map<String, Object> putPipeline(String pipelineId, String modelId) throws IOException {
String endpoint = Strings.format("_ingest/pipeline/%s", pipelineId);
String body = """
Expand Down
Loading