Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/143081.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
area: Inference
issues: []
pr: 143081
summary: "[Inference API] Parse endpoint metadata from persisted endpoints"
type: enhancement
Original file line number Diff line number Diff line change
Expand Up @@ -50,29 +50,17 @@ default List<String> aliases() {
*/
void parseRequestConfig(String modelId, TaskType taskType, Map<String, Object> config, ActionListener<Model> parsedModelListener);

default Model parsePersistedConfigWithSecrets(UnparsedModel unparsedModel) {
return parsePersistedConfigWithSecrets(
unparsedModel.inferenceEntityId(),
unparsedModel.taskType(),
unparsedModel.settings(),
unparsedModel.secrets()
);
}

/**
* Parse model configuration from {@code config map} from persisted storage and return the parsed {@link Model}. This requires that
* secrets and service settings be in two separate maps.
* Parse model from an {@link UnparsedModel} and return the fully parsed {@link Model}.
* This function modifies {@code config map}, fields are removed from the map as they are read.
* <p>
* If the map contains unrecognized configuration option an
* {@code ElasticsearchStatusException} is thrown.
*
* If the map contains unrecognized configuration options, no error is thrown.
*
* @param modelId Model Id
* @param taskType The model task type
* @param config Configuration options
* @param secrets Sensitive configuration options (e.g. api key)
* @return The parsed {@link Model}
* @param unparsedModel the unparsed model
* @return the fully parsed model
*/
Model parsePersistedConfigWithSecrets(String modelId, TaskType taskType, Map<String, Object> config, Map<String, Object> secrets);
Model parsePersistedConfig(UnparsedModel unparsedModel);

/**
* Create a new model from {@link ModelConfigurations} and {@link ModelSecrets} objects.
Expand All @@ -83,23 +71,6 @@ default Model parsePersistedConfigWithSecrets(UnparsedModel unparsedModel) {
*/
Model buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSecrets secrets);

/**
* Parse model configuration from {@code config map} from persisted storage and return the parsed {@link Model}.
* This function modifies {@code config map}, fields are removed from the map as they are read.
*
* If the map contains unrecognized configuration options, no error is thrown.
*
* @param modelId Model Id
* @param taskType The model task type
* @param config Configuration options
* @return The parsed {@link Model}
*/
Model parsePersistedConfig(String modelId, TaskType taskType, Map<String, Object> config);

default Model parsePersistedConfig(UnparsedModel unparsedModel) {
return parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
}

InferenceServiceConfiguration getConfiguration();

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import org.elasticsearch.inference.metadata.EndpointMetadata;

import java.util.HashMap;
import java.util.Map;

/**
Expand All @@ -34,4 +35,25 @@ public UnparsedModel(
) {
this(inferenceEntityId, taskType, service, settings, secrets, EndpointMetadata.EMPTY_INSTANCE);
}

public UnparsedModel(
String inferenceEntityId,
TaskType taskType,
String service,
Map<String, Object> settings,
Map<String, Object> secrets,
EndpointMetadata endpointMetadata
) {
this.inferenceEntityId = inferenceEntityId;
this.taskType = taskType;
this.service = service;

// We ensure that settings and secrets maps are modifiable because during parsing we are removing from them
this.settings = settings == null ? null : new HashMap<>(settings);
// Additionally, an empty secrets map is treated as null in order to skip potential validations for missing keys
// which should not be necessary when parsing a persisted model.
this.secrets = secrets == null || secrets.isEmpty() ? null : new HashMap<>(secrets);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you talk a little more about the advantage of using null here? It was possible to get null before right? We're just changing empty to also be null?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If secrets here ends being an empty map, validations will fail complaining for missing fields. In my understanding, we should never have validations failing when parsing a persisted config. They should only apply when we parse from the request. Thus, here, I'm taking care of this potential issue in a single place.

Previously this was not an issue because there was the variant for parsing without secrets which resulted in null. As now we have a single parse method, it is good defense I think to handle this here. Happy to revise though if you think otherwise.


this.endpointMetadata = endpointMetadata;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.inference;

import org.elasticsearch.test.ESTestCase;

import java.util.Map;

import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;

public class UnparsedModelTests extends ESTestCase {

public void testNullSecrets() {
UnparsedModel model = new UnparsedModel("id", randomFrom(TaskType.values()), "test_service", Map.of(), null);
assertThat(model.secrets(), is(nullValue()));
}

public void testEmptySecrets_SetToNull() {
UnparsedModel model = new UnparsedModel("id", randomFrom(TaskType.values()), "test_service", Map.of(), Map.of());
assertThat(model.secrets(), is(nullValue()));
}

public void testSettingsIsModifiable_GivenUnmodifiableMap() {
UnparsedModel model = new UnparsedModel("id", randomFrom(TaskType.values()), "test_service", Map.of("key", "value"), Map.of());
model.settings().remove("key");
assertThat(model.settings().isEmpty(), is(true));
}

public void testSecretsIsModifiable_GivenUnmodifiableMap() {
UnparsedModel model = new UnparsedModel("id", randomFrom(TaskType.values()), "test_service", Map.of(), Map.of("key", "value"));
model.secrets().remove("key");
assertThat(model.secrets().isEmpty(), is(true));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.inference.chunking.NoopChunker;
import org.elasticsearch.xpack.core.inference.chunking.WordBoundaryChunker;
Expand Down Expand Up @@ -75,42 +76,28 @@ protected static Map<String, Object> getTaskSettingsMap(Map<String, Object> sett

@Override
@SuppressWarnings("unchecked")
public TestServiceModel parsePersistedConfigWithSecrets(
String modelId,
TaskType taskType,
Map<String, Object> config,
Map<String, Object> secrets
) {
public TestServiceModel parsePersistedConfig(UnparsedModel unparsedModel) {
var config = unparsedModel.settings();
var secrets = unparsedModel.secrets();
var taskType = unparsedModel.taskType();

var serviceSettingsMap = (Map<String, Object>) config.remove(ModelConfigurations.SERVICE_SETTINGS);
var secretSettingsMap = (Map<String, Object>) secrets.remove(ModelSecrets.SECRET_SETTINGS);
var secretSettingsMap = secrets == null ? null : (Map<String, Object>) secrets.remove(ModelSecrets.SECRET_SETTINGS);

var serviceSettings = getServiceSettingsFromMap(serviceSettingsMap);
var secretSettings = TestSecretSettings.fromMap(secretSettingsMap);

var taskSettingsMap = getTaskSettingsMap(config);
var taskSettings = getTasksSettingsFromMap(taskSettingsMap);

return new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings);
return new TestServiceModel(unparsedModel.inferenceEntityId(), taskType, name(), serviceSettings, taskSettings, secretSettings);
}

@Override
public Model buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSecrets secrets) {
return new TestServiceModel(config, secrets);
}

@Override
@SuppressWarnings("unchecked")
public Model parsePersistedConfig(String modelId, TaskType taskType, Map<String, Object> config) {
var serviceSettingsMap = (Map<String, Object>) config.remove(ModelConfigurations.SERVICE_SETTINGS);

var serviceSettings = getServiceSettingsFromMap(serviceSettingsMap);

var taskSettingsMap = getTaskSettingsMap(config);
var taskSettings = getTasksSettingsFromMap(taskSettingsMap);

return new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, null);
}

protected TaskSettings getTasksSettingsFromMap(Map<String, Object> taskSettingsMap) {
return TestTaskSettings.fromMap(taskSettingsMap);
}
Expand Down Expand Up @@ -250,6 +237,10 @@ public record TestSecretSettings(String apiKey) implements SecretSettings {
static final String NAME = "test_secret_settings";

public static TestSecretSettings fromMap(Map<String, Object> map) {
if (map == null) {
return null;
}

ValidationException validationException = new ValidationException();

String apiKey = (String) map.remove("api_key");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
Expand Down Expand Up @@ -170,12 +169,7 @@ public void testGetModel() throws Exception {
);

// When we parse the persisted config, if the chunking settings were null they will be defaulted to OLD_DEFAULT_SETTINGS
ElasticsearchInternalModel roundTripModel = (ElasticsearchInternalModel) elserService.parsePersistedConfigWithSecrets(
modelHolder.get().inferenceEntityId(),
modelHolder.get().taskType(),
modelHolder.get().settings(),
modelHolder.get().secrets()
);
ElasticsearchInternalModel roundTripModel = (ElasticsearchInternalModel) elserService.parsePersistedConfig(modelHolder.get());

assertElserModelsEqual(roundTripModel, model);
}
Expand Down Expand Up @@ -308,7 +302,7 @@ public void testGetModelsByTaskType() throws InterruptedException {
.collect(Collectors.toSet());
modelHolder.get().forEach(m -> {
assertTrue(sparseIds.contains(m.inferenceEntityId()));
assertThat(m.secrets().keySet(), empty());
assertThat(m.secrets(), is(nullValue()));
});

blockingCall(listener -> modelRegistry.getModelsByTaskType(TaskType.TEXT_EMBEDDING, listener), modelHolder, exceptionHolder);
Expand All @@ -319,7 +313,7 @@ public void testGetModelsByTaskType() throws InterruptedException {
.collect(Collectors.toSet());
modelHolder.get().forEach(m -> {
assertTrue(denseIds.contains(m.inferenceEntityId()));
assertThat(m.secrets().keySet(), empty());
assertThat(m.secrets(), is(nullValue()));
});
}

Expand All @@ -328,7 +322,6 @@ public void testGetAllModels() throws InterruptedException {
var createdModels = new ArrayList<Model>();
int modelCount = randomIntBetween(30, 100);

AtomicReference<Boolean> putModelHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();

for (int i = 0; i < modelCount; i++) {
Expand All @@ -349,7 +342,7 @@ public void testGetAllModels() throws InterruptedException {
assertEquals(createdModels.get(i).getInferenceEntityId(), getAllModels.get(i).inferenceEntityId());
assertEquals(createdModels.get(i).getTaskType(), getAllModels.get(i).taskType());
assertEquals(createdModels.get(i).getConfigurations().getService(), getAllModels.get(i).service());
assertThat(getAllModels.get(i).secrets().keySet(), empty());
assertThat(getAllModels.get(i).secrets(), is(nullValue()));
}
}

Expand All @@ -372,7 +365,7 @@ public void testGetModelWithSecrets() throws InterruptedException {

// get model without secrets
blockingCall(listener -> modelRegistry.getModel(inferenceEntityId, listener), modelHolder, exceptionHolder);
assertThat(modelHolder.get().secrets().keySet(), empty());
assertThat(modelHolder.get().secrets(), is(nullValue()));
assertReturnModelIsModifiable(modelHolder.get());
}

Expand Down Expand Up @@ -1093,7 +1086,7 @@ public void testGetModelNoSecrets() {
assertEquals("foo", modelConfig.service());
assertEquals(TaskType.SPARSE_EMBEDDING, modelConfig.taskType());
assertNotNull(modelConfig.settings().keySet());
assertThat(modelConfig.secrets().keySet(), empty());
assertThat(modelConfig.secrets(), is(nullValue()));
}

public void testStoreModel_ReturnsTrue_WhenNoFailuresOccur() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,7 @@ private void doExecuteForked(
Model model;
if (service.isPresent()) {
try {
model = service.get()
.parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
model = service.get().parsePersistedConfig(unparsedModel);
} catch (Exception e) {
if (request.isForceDelete()) {
listener.onResponse(true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,7 @@ private void getSingleModel(
return;
}

var model = service.get()
.parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
var model = service.get().parsePersistedConfig(unparsedModel);

service.get()
.updateModelsWithDynamicFields(
Expand Down Expand Up @@ -142,10 +141,7 @@ private void parseModels(List<UnparsedModel> unparsedModels, ActionListener<GetI
throw serviceNotFoundException(unparsedModel.service(), unparsedModel.inferenceEntityId());
}
var list = parsedModelsByService.computeIfAbsent(service.get().name(), s -> new ArrayList<>());
list.add(
service.get()
.parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings())
);
list.add(service.get().parsePersistedConfig(unparsedModel));
}

var groupedListener = new GroupedActionListener<List<Model>>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ protected void doExecute(
}

if (service.get() instanceof RerankingInferenceService rerankingInferenceService) {
var model = service.get()
.parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
var model = service.get().parsePersistedConfig(unparsedModel);

l.onResponse(
new GetRerankerWindowSizeAction.Response(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings;
import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -142,13 +141,7 @@ protected void masterOperation(
})
.<Boolean>andThen((listener, existingUnparsedModel) -> {

Model existingParsedModel = service.get()
.parsePersistedConfigWithSecrets(
existingUnparsedModel.inferenceEntityId(),
existingUnparsedModel.taskType(),
new HashMap<>(existingUnparsedModel.settings()),
new HashMap<>(existingUnparsedModel.secrets())
);
Model existingParsedModel = service.get().parsePersistedConfig(existingUnparsedModel);

validateResolvedTaskType(existingParsedModel, resolvedTaskType);

Expand Down Expand Up @@ -191,11 +184,7 @@ protected void masterOperation(
)
);
} else {
listener.onResponse(
service.get()
.parsePersistedConfig(inferenceEntityId, resolvedTaskType, new HashMap<>(unparsedModel.settings()))
.getConfigurations()
);
listener.onResponse(service.get().parsePersistedConfig(unparsedModel).getConfigurations());
}
}, listener::onFailure));
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -337,16 +337,7 @@ private void executeChunkedInferenceAsync(
ActionListener<UnparsedModel> modelLoadingListener = ActionListener.wrap(unparsedModel -> {
var service = inferenceServiceRegistry.getService(unparsedModel.service());
if (service.isEmpty() == false) {
var provider = new InferenceProvider(
service.get(),
service.get()
.parsePersistedConfigWithSecrets(
inferenceId,
unparsedModel.taskType(),
unparsedModel.settings(),
unparsedModel.secrets()
)
);
var provider = new InferenceProvider(service.get(), service.get().parsePersistedConfig(unparsedModel));
executeChunkedInferenceAsync(inferenceId, provider, requests, onFinish);
} else {
try (onFinish) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,7 @@ private void loadFromIndex(InferenceIdAndProject idAndProject, ActionListener<Mo
)
);

var model = service.parsePersistedConfigWithSecrets(
unparsedModel.inferenceEntityId(),
unparsedModel.taskType(),
unparsedModel.settings(),
unparsedModel.secrets()
);
var model = service.parsePersistedConfig(unparsedModel);

if (cacheEnabled()) {
cache.put(idAndProject, model);
Expand Down
Loading