diff --git a/docs/changelog/143081.yaml b/docs/changelog/143081.yaml new file mode 100644 index 0000000000000..7a688681fd27a --- /dev/null +++ b/docs/changelog/143081.yaml @@ -0,0 +1,5 @@ +area: Inference +issues: [] +pr: 143081 +summary: "[Inference API] Parse endpoint metadata from persisted endpoints" +type: enhancement diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index 7b1d1c0423035..54267f07558c0 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -50,29 +50,17 @@ default List aliases() { */ void parseRequestConfig(String modelId, TaskType taskType, Map config, ActionListener parsedModelListener); - default Model parsePersistedConfigWithSecrets(UnparsedModel unparsedModel) { - return parsePersistedConfigWithSecrets( - unparsedModel.inferenceEntityId(), - unparsedModel.taskType(), - unparsedModel.settings(), - unparsedModel.secrets() - ); - } - /** - * Parse model configuration from {@code config map} from persisted storage and return the parsed {@link Model}. This requires that - * secrets and service settings be in two separate maps. + * Parse model from an {@link UnparsedModel} and return the fully parsed {@link Model}. * This function modifies {@code config map}, fields are removed from the map as they are read. + *

+ * If the map contains unrecognized configuration option an + * {@code ElasticsearchStatusException} is thrown. * - * If the map contains unrecognized configuration options, no error is thrown. - * - * @param modelId Model Id - * @param taskType The model task type - * @param config Configuration options - * @param secrets Sensitive configuration options (e.g. api key) - * @return The parsed {@link Model} + * @param unparsedModel the unparsed model + * @return the fully parsed model */ - Model parsePersistedConfigWithSecrets(String modelId, TaskType taskType, Map config, Map secrets); + Model parsePersistedConfig(UnparsedModel unparsedModel); /** * Create a new model from {@link ModelConfigurations} and {@link ModelSecrets} objects. @@ -83,23 +71,6 @@ default Model parsePersistedConfigWithSecrets(UnparsedModel unparsedModel) { */ Model buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSecrets secrets); - /** - * Parse model configuration from {@code config map} from persisted storage and return the parsed {@link Model}. - * This function modifies {@code config map}, fields are removed from the map as they are read. - * - * If the map contains unrecognized configuration options, no error is thrown. - * - * @param modelId Model Id - * @param taskType The model task type - * @param config Configuration options - * @return The parsed {@link Model} - */ - Model parsePersistedConfig(String modelId, TaskType taskType, Map config); - - default Model parsePersistedConfig(UnparsedModel unparsedModel) { - return parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings()); - } - InferenceServiceConfiguration getConfiguration(); /** diff --git a/server/src/main/java/org/elasticsearch/inference/UnparsedModel.java b/server/src/main/java/org/elasticsearch/inference/UnparsedModel.java index cf8f49ccbd89c..8d01e04fcde63 100644 --- a/server/src/main/java/org/elasticsearch/inference/UnparsedModel.java +++ b/server/src/main/java/org/elasticsearch/inference/UnparsedModel.java @@ -11,6 +11,7 @@ import org.elasticsearch.inference.metadata.EndpointMetadata; +import java.util.HashMap; import java.util.Map; /** @@ -34,4 +35,25 @@ public UnparsedModel( ) { this(inferenceEntityId, taskType, service, settings, secrets, EndpointMetadata.EMPTY_INSTANCE); } + + public UnparsedModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map settings, + Map secrets, + EndpointMetadata endpointMetadata + ) { + this.inferenceEntityId = inferenceEntityId; + this.taskType = taskType; + this.service = service; + + // We ensure that settings and secrets maps are modifiable because during parsing we are removing from them + this.settings = settings == null ? null : new HashMap<>(settings); + // Additionally, an empty secrets map is treated as null in order to skip potential validations for missing keys + // which should not be necessary when parsing a persisted model. + this.secrets = secrets == null || secrets.isEmpty() ? null : new HashMap<>(secrets); + + this.endpointMetadata = endpointMetadata; + } } diff --git a/server/src/test/java/org/elasticsearch/inference/UnparsedModelTests.java b/server/src/test/java/org/elasticsearch/inference/UnparsedModelTests.java new file mode 100644 index 0000000000000..5a8468b76906a --- /dev/null +++ b/server/src/test/java/org/elasticsearch/inference/UnparsedModelTests.java @@ -0,0 +1,42 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.inference; + +import org.elasticsearch.test.ESTestCase; + +import java.util.Map; + +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; + +public class UnparsedModelTests extends ESTestCase { + + public void testNullSecrets() { + UnparsedModel model = new UnparsedModel("id", randomFrom(TaskType.values()), "test_service", Map.of(), null); + assertThat(model.secrets(), is(nullValue())); + } + + public void testEmptySecrets_SetToNull() { + UnparsedModel model = new UnparsedModel("id", randomFrom(TaskType.values()), "test_service", Map.of(), Map.of()); + assertThat(model.secrets(), is(nullValue())); + } + + public void testSettingsIsModifiable_GivenUnmodifiableMap() { + UnparsedModel model = new UnparsedModel("id", randomFrom(TaskType.values()), "test_service", Map.of("key", "value"), Map.of()); + model.settings().remove("key"); + assertThat(model.settings().isEmpty(), is(true)); + } + + public void testSecretsIsModifiable_GivenUnmodifiableMap() { + UnparsedModel model = new UnparsedModel("id", randomFrom(TaskType.values()), "test_service", Map.of(), Map.of("key", "value")); + model.secrets().remove("key"); + assertThat(model.secrets().isEmpty(), is(true)); + } +} diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java index ffc745e130d8a..502f1d71f459a 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java @@ -25,6 +25,7 @@ import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.chunking.NoopChunker; import org.elasticsearch.xpack.core.inference.chunking.WordBoundaryChunker; @@ -75,14 +76,13 @@ protected static Map getTaskSettingsMap(Map sett @Override @SuppressWarnings("unchecked") - public TestServiceModel parsePersistedConfigWithSecrets( - String modelId, - TaskType taskType, - Map config, - Map secrets - ) { + public TestServiceModel parsePersistedConfig(UnparsedModel unparsedModel) { + var config = unparsedModel.settings(); + var secrets = unparsedModel.secrets(); + var taskType = unparsedModel.taskType(); + var serviceSettingsMap = (Map) config.remove(ModelConfigurations.SERVICE_SETTINGS); - var secretSettingsMap = (Map) secrets.remove(ModelSecrets.SECRET_SETTINGS); + var secretSettingsMap = secrets == null ? null : (Map) secrets.remove(ModelSecrets.SECRET_SETTINGS); var serviceSettings = getServiceSettingsFromMap(serviceSettingsMap); var secretSettings = TestSecretSettings.fromMap(secretSettingsMap); @@ -90,7 +90,7 @@ public TestServiceModel parsePersistedConfigWithSecrets( var taskSettingsMap = getTaskSettingsMap(config); var taskSettings = getTasksSettingsFromMap(taskSettingsMap); - return new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings); + return new TestServiceModel(unparsedModel.inferenceEntityId(), taskType, name(), serviceSettings, taskSettings, secretSettings); } @Override @@ -98,19 +98,6 @@ public Model buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSec return new TestServiceModel(config, secrets); } - @Override - @SuppressWarnings("unchecked") - public Model parsePersistedConfig(String modelId, TaskType taskType, Map config) { - var serviceSettingsMap = (Map) config.remove(ModelConfigurations.SERVICE_SETTINGS); - - var serviceSettings = getServiceSettingsFromMap(serviceSettingsMap); - - var taskSettingsMap = getTaskSettingsMap(config); - var taskSettings = getTasksSettingsFromMap(taskSettingsMap); - - return new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, null); - } - protected TaskSettings getTasksSettingsFromMap(Map taskSettingsMap) { return TestTaskSettings.fromMap(taskSettingsMap); } @@ -250,6 +237,10 @@ public record TestSecretSettings(String apiKey) implements SecretSettings { static final String NAME = "test_secret_settings"; public static TestSecretSettings fromMap(Map map) { + if (map == null) { + return null; + } + ValidationException validationException = new ValidationException(); String apiKey = (String) map.remove("api_key"); diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index dddb38cd06ae7..4fcf0648eb0da 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -92,7 +92,6 @@ import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.containsString; -import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; @@ -170,12 +169,7 @@ public void testGetModel() throws Exception { ); // When we parse the persisted config, if the chunking settings were null they will be defaulted to OLD_DEFAULT_SETTINGS - ElasticsearchInternalModel roundTripModel = (ElasticsearchInternalModel) elserService.parsePersistedConfigWithSecrets( - modelHolder.get().inferenceEntityId(), - modelHolder.get().taskType(), - modelHolder.get().settings(), - modelHolder.get().secrets() - ); + ElasticsearchInternalModel roundTripModel = (ElasticsearchInternalModel) elserService.parsePersistedConfig(modelHolder.get()); assertElserModelsEqual(roundTripModel, model); } @@ -308,7 +302,7 @@ public void testGetModelsByTaskType() throws InterruptedException { .collect(Collectors.toSet()); modelHolder.get().forEach(m -> { assertTrue(sparseIds.contains(m.inferenceEntityId())); - assertThat(m.secrets().keySet(), empty()); + assertThat(m.secrets(), is(nullValue())); }); blockingCall(listener -> modelRegistry.getModelsByTaskType(TaskType.TEXT_EMBEDDING, listener), modelHolder, exceptionHolder); @@ -319,7 +313,7 @@ public void testGetModelsByTaskType() throws InterruptedException { .collect(Collectors.toSet()); modelHolder.get().forEach(m -> { assertTrue(denseIds.contains(m.inferenceEntityId())); - assertThat(m.secrets().keySet(), empty()); + assertThat(m.secrets(), is(nullValue())); }); } @@ -328,7 +322,6 @@ public void testGetAllModels() throws InterruptedException { var createdModels = new ArrayList(); int modelCount = randomIntBetween(30, 100); - AtomicReference putModelHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); for (int i = 0; i < modelCount; i++) { @@ -349,7 +342,7 @@ public void testGetAllModels() throws InterruptedException { assertEquals(createdModels.get(i).getInferenceEntityId(), getAllModels.get(i).inferenceEntityId()); assertEquals(createdModels.get(i).getTaskType(), getAllModels.get(i).taskType()); assertEquals(createdModels.get(i).getConfigurations().getService(), getAllModels.get(i).service()); - assertThat(getAllModels.get(i).secrets().keySet(), empty()); + assertThat(getAllModels.get(i).secrets(), is(nullValue())); } } @@ -372,7 +365,7 @@ public void testGetModelWithSecrets() throws InterruptedException { // get model without secrets blockingCall(listener -> modelRegistry.getModel(inferenceEntityId, listener), modelHolder, exceptionHolder); - assertThat(modelHolder.get().secrets().keySet(), empty()); + assertThat(modelHolder.get().secrets(), is(nullValue())); assertReturnModelIsModifiable(modelHolder.get()); } @@ -1093,7 +1086,7 @@ public void testGetModelNoSecrets() { assertEquals("foo", modelConfig.service()); assertEquals(TaskType.SPARSE_EMBEDDING, modelConfig.taskType()); assertNotNull(modelConfig.settings().keySet()); - assertThat(modelConfig.secrets().keySet(), empty()); + assertThat(modelConfig.secrets(), is(nullValue())); } public void testStoreModel_ReturnsTrue_WhenNoFailuresOccur() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java index 23ef7481166d9..0f598472ea477 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java @@ -132,8 +132,7 @@ private void doExecuteForked( Model model; if (service.isPresent()) { try { - model = service.get() - .parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings()); + model = service.get().parsePersistedConfig(unparsedModel); } catch (Exception e) { if (request.isForceDelete()) { listener.onResponse(true); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java index 0dbbc7715fbfa..16e0118c7f185 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java @@ -97,8 +97,7 @@ private void getSingleModel( return; } - var model = service.get() - .parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings()); + var model = service.get().parsePersistedConfig(unparsedModel); service.get() .updateModelsWithDynamicFields( @@ -142,10 +141,7 @@ private void parseModels(List unparsedModels, ActionListener new ArrayList<>()); - list.add( - service.get() - .parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings()) - ); + list.add(service.get().parsePersistedConfig(unparsedModel)); } var groupedListener = new GroupedActionListener>( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetRerankerWindowSizeAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetRerankerWindowSizeAction.java index 8e0a0d6696167..1bdb491132616 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetRerankerWindowSizeAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetRerankerWindowSizeAction.java @@ -80,8 +80,7 @@ protected void doExecute( } if (service.get() instanceof RerankingInferenceService rerankingInferenceService) { - var model = service.get() - .parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings()); + var model = service.get().parsePersistedConfig(unparsedModel); l.onResponse( new GetRerankerWindowSizeAction.Response( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUpdateInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUpdateInferenceModelAction.java index 8c8c31c5cea42..34a2f7fd6a3d9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUpdateInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUpdateInferenceModelAction.java @@ -53,7 +53,6 @@ import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings; import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -142,13 +141,7 @@ protected void masterOperation( }) .andThen((listener, existingUnparsedModel) -> { - Model existingParsedModel = service.get() - .parsePersistedConfigWithSecrets( - existingUnparsedModel.inferenceEntityId(), - existingUnparsedModel.taskType(), - new HashMap<>(existingUnparsedModel.settings()), - new HashMap<>(existingUnparsedModel.secrets()) - ); + Model existingParsedModel = service.get().parsePersistedConfig(existingUnparsedModel); validateResolvedTaskType(existingParsedModel, resolvedTaskType); @@ -191,11 +184,7 @@ protected void masterOperation( ) ); } else { - listener.onResponse( - service.get() - .parsePersistedConfig(inferenceEntityId, resolvedTaskType, new HashMap<>(unparsedModel.settings())) - .getConfigurations() - ); + listener.onResponse(service.get().parsePersistedConfig(unparsedModel).getConfigurations()); } }, listener::onFailure)); } else { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 17c2ad161ac6e..0ddeb18efc371 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -337,16 +337,7 @@ private void executeChunkedInferenceAsync( ActionListener modelLoadingListener = ActionListener.wrap(unparsedModel -> { var service = inferenceServiceRegistry.getService(unparsedModel.service()); if (service.isEmpty() == false) { - var provider = new InferenceProvider( - service.get(), - service.get() - .parsePersistedConfigWithSecrets( - inferenceId, - unparsedModel.taskType(), - unparsedModel.settings(), - unparsedModel.secrets() - ) - ); + var provider = new InferenceProvider(service.get(), service.get().parsePersistedConfig(unparsedModel)); executeChunkedInferenceAsync(inferenceId, provider, requests, onFinish); } else { try (onFinish) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/InferenceEndpointRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/InferenceEndpointRegistry.java index e9d1705f50d13..381e6aac7371a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/InferenceEndpointRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/InferenceEndpointRegistry.java @@ -126,12 +126,7 @@ private void loadFromIndex(InferenceIdAndProject idAndProject, ActionListener implements InferenceService { + protected static final Set COMPLETION_ONLY = EnumSet.of(TaskType.COMPLETION); + + /** + * The task types that support chunking settings + */ + protected static final EnumSet CHUNKING_TASK_TYPES = EnumSet.of(SPARSE_EMBEDDING, TEXT_EMBEDDING, EMBEDDING); + private final Sender sender; private final ServiceComponents serviceComponents; private final ClusterService clusterService; + private final Map> modelCreators; - public SenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + public SenderService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + ClusterService clusterService, + Map> modelCreators + ) { Objects.requireNonNull(factory); sender = factory.createSender(); this.serviceComponents = Objects.requireNonNull(serviceComponents); this.clusterService = Objects.requireNonNull(clusterService); + this.modelCreators = Objects.requireNonNull(modelCreators); } public Sender getSender() { @@ -88,8 +113,47 @@ public void infer( }).addListener(listener); } + public M parsePersistedConfig(UnparsedModel unparsedModel) { + var config = unparsedModel.settings(); + var secrets = unparsedModel.secrets(); + var taskType = unparsedModel.taskType(); + + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + Map secretSettingsMap = secrets == null ? null : removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); + + ChunkingSettings chunkingSettings = null; + if (CHUNKING_TASK_TYPES.contains(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + + migrateBetweenTaskAndServiceSettings(serviceSettingsMap, taskSettingsMap); + + return retrieveModelCreatorFromMapOrThrow( + modelCreators, + unparsedModel.inferenceEntityId(), + taskType, + name(), + ConfigurationParseContext.PERSISTENT + ).createFromMaps( + unparsedModel.inferenceEntityId(), + taskType, + name(), + serviceSettingsMap, + taskSettingsMap, + chunkingSettings, + secretSettingsMap, + ConfigurationParseContext.PERSISTENT + ); + } + + /** + * Allows for implementations to perform migration for the cases where settings were moved between service and task settings. + */ + protected void migrateBetweenTaskAndServiceSettings(Map serviceSettings, Map taskSettings) {} + private static InferenceInputs createInput( - SenderService service, + SenderService service, Model model, List input, InputType inputType, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java index be84fb015eea9..96eb031d9f5b8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java @@ -64,7 +64,7 @@ * using AI21 models. It supports completion and chat completion tasks. * The service uses Ai21ActionCreator to create actions for executing inference requests. */ -public class Ai21Service extends SenderService { +public class Ai21Service extends SenderService { public static final String NAME = "ai21"; private static final String SERVICE_NAME = "AI21"; @@ -94,7 +94,7 @@ public Ai21Service( } public Ai21Service(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { - super(factory, serviceComponents, clusterService); + super(factory, serviceComponents, clusterService, MODEL_CREATORS); } @Override @@ -202,20 +202,6 @@ public void parseRequestConfig( } } - @Override - public Ai21Model parsePersistedConfigWithSecrets( - String modelId, - TaskType taskType, - Map config, - Map secrets - ) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); - - return createModelFromPersistent(modelId, taskType, serviceSettingsMap, secretSettingsMap); - } - @Override public Ai21Model buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSecrets secrets) { return retrieveModelCreatorFromMapOrThrow( @@ -227,14 +213,6 @@ public Ai21Model buildModelFromConfigAndSecrets(ModelConfigurations config, Mode ).createFromModelConfigurationsAndSecrets(config, secrets); } - @Override - public Ai21Model parsePersistedConfig(String modelId, TaskType taskType, Map config) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - - return createModelFromPersistent(modelId, taskType, serviceSettingsMap, null); - } - @Override public TransportVersion getMinimalSupportedVersion() { return ML_INFERENCE_AI21_COMPLETION_ADDED; @@ -264,15 +242,6 @@ private static Ai21Model createModel( ); } - private Ai21Model createModelFromPersistent( - String inferenceEntityId, - TaskType taskType, - Map serviceSettings, - Map secretSettings - ) { - return createModel(inferenceEntityId, taskType, serviceSettings, secretSettings, ConfigurationParseContext.PERSISTENT); - } - /** * Configuration class for the AI21 inference service. * It provides the settings and configurations required for the service. diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index 5772176e4c0e6..f8d41c7b6613f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -58,7 +58,6 @@ import java.util.Map; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; @@ -69,7 +68,7 @@ import static org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings.SERVICE_ID; import static org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings.WORKSPACE_NAME; -public class AlibabaCloudSearchService extends SenderService implements RerankingInferenceService { +public class AlibabaCloudSearchService extends SenderService implements RerankingInferenceService { public static final String NAME = AlibabaCloudSearchUtils.SERVICE_NAME; private static final String SERVICE_NAME = "AlibabaCloud AI Search"; @@ -110,7 +109,7 @@ public AlibabaCloudSearchService( ServiceComponents serviceComponents, ClusterService clusterService ) { - super(factory, serviceComponents, clusterService); + super(factory, serviceComponents, clusterService, MODEL_CREATORS); } @Override @@ -166,25 +165,6 @@ public EnumSet supportedTaskTypes() { return SUPPORTED_TASK_TYPES; } - private static AlibabaCloudSearchModel createModelWithoutLoggingDeprecations( - String inferenceEntityId, - TaskType taskType, - Map serviceSettings, - Map taskSettings, - ChunkingSettings chunkingSettings, - @Nullable Map secretSettings - ) { - return createModel( - inferenceEntityId, - taskType, - serviceSettings, - taskSettings, - chunkingSettings, - secretSettings, - ConfigurationParseContext.PERSISTENT - ); - } - private static AlibabaCloudSearchModel createModel( String inferenceEntityId, TaskType taskType, @@ -206,32 +186,6 @@ private static AlibabaCloudSearchModel createModel( ); } - @Override - public AlibabaCloudSearchModel parsePersistedConfigWithSecrets( - String inferenceEntityId, - TaskType taskType, - Map config, - Map secrets - ) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - Map secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); - - ChunkingSettings chunkingSettings = null; - if (List.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING).contains(taskType)) { - chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); - } - - return createModelWithoutLoggingDeprecations( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - chunkingSettings, - secretSettingsMap - ); - } - @Override public AlibabaCloudSearchModel buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSecrets secrets) { return retrieveModelCreatorFromMapOrThrow( @@ -243,26 +197,6 @@ public AlibabaCloudSearchModel buildModelFromConfigAndSecrets(ModelConfiguration ).createFromModelConfigurationsAndSecrets(config, secrets); } - @Override - public AlibabaCloudSearchModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - - ChunkingSettings chunkingSettings = null; - if (List.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING).contains(taskType)) { - chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); - } - - return createModelWithoutLoggingDeprecations( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - chunkingSettings, - null - ); - } - @Override protected void doUnifiedCompletionInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index ba7f481b92f3c..255616b1b78c4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -63,7 +63,6 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUnsupportedTaskTypeStatusException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; @@ -80,7 +79,7 @@ * * https://github.com/elastic/ml-team/issues/1706 */ -public class AmazonBedrockService extends SenderService { +public class AmazonBedrockService extends SenderService { public static final String NAME = "amazonbedrock"; private static final String SERVICE_NAME = "Amazon Bedrock"; public static final String CHAT_COMPLETION_ERROR_PREFIX = "Amazon Bedrock chat completion"; @@ -132,7 +131,7 @@ public AmazonBedrockService( ServiceComponents serviceComponents, ClusterService clusterService ) { - super(httpSenderFactory, serviceComponents, clusterService); + super(httpSenderFactory, serviceComponents, clusterService, MODEL_CREATORS); this.amazonBedrockSender = amazonBedrockFactory.createSender(); } @@ -260,33 +259,6 @@ public void parseRequestConfig( } } - @Override - public Model parsePersistedConfigWithSecrets( - String modelId, - TaskType taskType, - Map config, - Map secrets - ) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); - - ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); - } - - return createModel( - modelId, - taskType, - serviceSettingsMap, - taskSettingsMap, - chunkingSettings, - secretSettingsMap, - ConfigurationParseContext.PERSISTENT - ); - } - @Override public AmazonBedrockModel buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSecrets secrets) { var model = retrieveModelCreatorFromMapOrThrow( @@ -300,27 +272,6 @@ public AmazonBedrockModel buildModelFromConfigAndSecrets(ModelConfigurations con return model; } - @Override - public Model parsePersistedConfig(String modelId, TaskType taskType, Map config) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - - ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); - } - - return createModel( - modelId, - taskType, - serviceSettingsMap, - taskSettingsMap, - chunkingSettings, - null, - ConfigurationParseContext.PERSISTENT - ); - } - @Override public InferenceServiceConfiguration getConfiguration() { return Configuration.get(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java index fc931324de202..4b5f90fe43518 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java @@ -51,7 +51,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; -public class AnthropicService extends SenderService { +public class AnthropicService extends SenderService { public static final String NAME = "anthropic"; private static final String SERVICE_NAME = "Anthropic"; @@ -70,7 +70,7 @@ public AnthropicService( } public AnthropicService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { - super(factory, serviceComponents, clusterService); + super(factory, serviceComponents, clusterService, MODEL_CREATORS); } @Override @@ -108,23 +108,6 @@ public void parseRequestConfig( } } - private static AnthropicModel createModelFromPersistent( - String inferenceEntityId, - TaskType taskType, - Map serviceSettings, - Map taskSettings, - @Nullable Map secretSettings - ) { - return createModel( - inferenceEntityId, - taskType, - serviceSettings, - taskSettings, - secretSettings, - ConfigurationParseContext.PERSISTENT - ); - } - private static AnthropicModel createModel( String inferenceEntityId, TaskType taskType, @@ -145,20 +128,6 @@ private static AnthropicModel createModel( ); } - @Override - public AnthropicModel parsePersistedConfigWithSecrets( - String inferenceEntityId, - TaskType taskType, - Map config, - Map secrets - ) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); - - return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, secretSettingsMap); - } - @Override public AnthropicModel buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSecrets secrets) { return retrieveModelCreatorFromMapOrThrow( @@ -170,14 +139,6 @@ public AnthropicModel buildModelFromConfigAndSecrets(ModelConfigurations config, ).createFromModelConfigurationsAndSecrets(config, secrets); } - @Override - public AnthropicModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - - return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, null); - } - @Override public InferenceServiceConfiguration getConfiguration() { return Configuration.get(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index 039deabd32750..2efb260f08259 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -62,7 +62,6 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; @@ -75,7 +74,7 @@ import static org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionTaskSettings.DEFAULT_MAX_NEW_TOKENS; import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.EMBEDDING_MAX_BATCH_SIZE; -public class AzureAiStudioService extends SenderService implements RerankingInferenceService { +public class AzureAiStudioService extends SenderService implements RerankingInferenceService { public static final String NAME = "azureaistudio"; @@ -106,7 +105,7 @@ public AzureAiStudioService( } public AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { - super(factory, serviceComponents, clusterService); + super(factory, serviceComponents, clusterService, MODEL_CREATORS); } @Override @@ -207,32 +206,6 @@ public void parseRequestConfig( } } - @Override - public AzureAiStudioModel parsePersistedConfigWithSecrets( - String inferenceEntityId, - TaskType taskType, - Map config, - Map secrets - ) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); - - ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); - } - - return createModelFromPersistent( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - chunkingSettings, - secretSettingsMap - ); - } - @Override public AzureAiStudioModel buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSecrets secrets) { var model = retrieveModelCreatorFromMapOrThrow( @@ -246,19 +219,6 @@ public AzureAiStudioModel buildModelFromConfigAndSecrets(ModelConfigurations con return model; } - @Override - public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - - ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); - } - - return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null); - } - @Override public InferenceServiceConfiguration getConfiguration() { return Configuration.get(); @@ -332,25 +292,6 @@ private static void checkProviderAndEndpointTypeForTask( } } - private AzureAiStudioModel createModelFromPersistent( - String inferenceEntityId, - TaskType taskType, - Map serviceSettings, - Map taskSettings, - ChunkingSettings chunkingSettings, - Map secretSettings - ) { - return createModel( - inferenceEntityId, - taskType, - serviceSettings, - taskSettings, - chunkingSettings, - secretSettings, - ConfigurationParseContext.PERSISTENT - ); - } - @Override public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { if (model instanceof AzureAiStudioEmbeddingsModel embeddingsModel) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index 99b03b3e4cd13..4d06b6ff2fcb5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -62,7 +62,6 @@ import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; @@ -71,7 +70,7 @@ import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.RESOURCE_NAME; import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.EMBEDDING_MAX_BATCH_SIZE; -public class AzureOpenAiService extends SenderService { +public class AzureOpenAiService extends SenderService { public static final String NAME = "azureopenai"; private static final String SERVICE_NAME = "Azure OpenAI"; @@ -104,7 +103,7 @@ public AzureOpenAiService( } public AzureOpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { - super(factory, serviceComponents, clusterService); + super(factory, serviceComponents, clusterService, MODEL_CREATORS); } @Override @@ -150,25 +149,6 @@ public void parseRequestConfig( } } - private static AzureOpenAiModel createModelFromPersistent( - String inferenceEntityId, - TaskType taskType, - Map serviceSettings, - Map taskSettings, - ChunkingSettings chunkingSettings, - @Nullable Map secretSettings - ) { - return createModel( - inferenceEntityId, - taskType, - serviceSettings, - taskSettings, - chunkingSettings, - secretSettings, - ConfigurationParseContext.PERSISTENT - ); - } - private static AzureOpenAiModel createModel( String inferenceEntityId, TaskType taskType, @@ -190,32 +170,6 @@ private static AzureOpenAiModel createModel( ); } - @Override - public AzureOpenAiModel parsePersistedConfigWithSecrets( - String inferenceEntityId, - TaskType taskType, - Map config, - Map secrets - ) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); - - ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); - } - - return createModelFromPersistent( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - chunkingSettings, - secretSettingsMap - ); - } - @Override public AzureOpenAiModel buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSecrets secrets) { return retrieveModelCreatorFromMapOrThrow( @@ -227,19 +181,6 @@ public AzureOpenAiModel buildModelFromConfigAndSecrets(ModelConfigurations confi ).createFromModelConfigurationsAndSecrets(config, secrets); } - @Override - public AzureOpenAiModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - - ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); - } - - return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null); - } - @Override public InferenceServiceConfiguration getConfiguration() { return Configuration.get(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index f3057c7e15173..964c60296cb5b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -59,14 +59,13 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields.EMBEDDING_MAX_BATCH_SIZE; -public class CohereService extends SenderService implements RerankingInferenceService { +public class CohereService extends SenderService implements RerankingInferenceService { public static final String NAME = "cohere"; private static final String SERVICE_NAME = "Cohere"; @@ -103,7 +102,7 @@ public CohereService( } public CohereService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { - super(factory, serviceComponents, clusterService); + super(factory, serviceComponents, clusterService, MODEL_CREATORS); } @Override @@ -149,25 +148,6 @@ public void parseRequestConfig( } } - private static CohereModel createModelWithoutLoggingDeprecations( - String inferenceEntityId, - TaskType taskType, - Map serviceSettings, - Map taskSettings, - ChunkingSettings chunkingSettings, - @Nullable Map secretSettings - ) { - return createModel( - inferenceEntityId, - taskType, - serviceSettings, - taskSettings, - chunkingSettings, - secretSettings, - ConfigurationParseContext.PERSISTENT - ); - } - private static CohereModel createModel( String inferenceEntityId, TaskType taskType, @@ -189,32 +169,6 @@ private static CohereModel createModel( ); } - @Override - public CohereModel parsePersistedConfigWithSecrets( - String inferenceEntityId, - TaskType taskType, - Map config, - Map secrets - ) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - Map secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); - - ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); - } - - return createModelWithoutLoggingDeprecations( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - chunkingSettings, - secretSettingsMap - ); - } - @Override public CohereModel buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSecrets secrets) { return retrieveModelCreatorFromMapOrThrow( @@ -226,26 +180,6 @@ public CohereModel buildModelFromConfigAndSecrets(ModelConfigurations config, Mo ).createFromModelConfigurationsAndSecrets(config, secrets); } - @Override - public CohereModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - - ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); - } - - return createModelWithoutLoggingDeprecations( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - chunkingSettings, - null - ); - } - @Override public InferenceServiceConfiguration getConfiguration() { return Configuration.get(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiService.java index 9ec2b177b38e7..3fe13b5215e08 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiService.java @@ -52,7 +52,7 @@ * Contextual AI inference service for reranking tasks. * This service uses the Contextual AI REST API to perform document reranking. */ -public class ContextualAiService extends SenderService implements RerankingInferenceService { +public class ContextualAiService extends SenderService implements RerankingInferenceService { public static final String NAME = "contextualai"; private static final String SERVICE_NAME = "Contextual AI"; @@ -73,7 +73,7 @@ public ContextualAiService( } public ContextualAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { - super(factory, serviceComponents, clusterService); + super(factory, serviceComponents, clusterService, MODEL_CREATORS); } @Override @@ -136,27 +136,6 @@ private static ContextualAiModel createModel( ); } - @Override - public ContextualAiModel parsePersistedConfigWithSecrets( - String inferenceEntityId, - TaskType taskType, - Map config, - Map secrets - ) { - Map serviceSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = ServiceUtils.removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - Map secretSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); - - return createModel( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - secretSettingsMap, - ConfigurationParseContext.PERSISTENT - ); - } - @Override public ContextualAiModel buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSecrets secrets) { return retrieveModelCreatorFromMapOrThrow( @@ -168,14 +147,6 @@ public ContextualAiModel buildModelFromConfigAndSecrets(ModelConfigurations conf ).createFromModelConfigurationsAndSecrets(config, secrets); } - @Override - public ContextualAiModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { - Map serviceSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = ServiceUtils.removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - - return createModel(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, null, ConfigurationParseContext.PERSISTENT); - } - @Override public void doInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java index 3816f6ee741fc..37a065e7301d9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java @@ -66,7 +66,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; -public class CustomService extends SenderService implements RerankingInferenceService { +public class CustomService extends SenderService implements RerankingInferenceService { public static final String NAME = "custom"; private static final String SERVICE_NAME = "Custom"; @@ -100,7 +100,7 @@ public CustomService( } public CustomService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { - super(factory, serviceComponents, clusterService); + super(factory, serviceComponents, clusterService, MODEL_CREATORS); } @Override @@ -195,25 +195,6 @@ protected void doUnifiedCompletionInfer( throwUnsupportedUnifiedCompletionOperation(NAME); } - private static CustomModel createModelWithoutLoggingDeprecations( - String inferenceEntityId, - TaskType taskType, - Map serviceSettings, - Map taskSettings, - @Nullable Map secretSettings, - @Nullable ChunkingSettings chunkingSettings - ) { - return createModel( - inferenceEntityId, - taskType, - serviceSettings, - taskSettings, - secretSettings, - chunkingSettings, - ConfigurationParseContext.PERSISTENT - ); - } - private static CustomModel createModel( String inferenceEntityId, TaskType taskType, @@ -235,29 +216,6 @@ private static CustomModel createModel( ); } - @Override - public CustomModel parsePersistedConfigWithSecrets( - String inferenceEntityId, - TaskType taskType, - Map config, - Map secrets - ) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); - Map secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); - - var chunkingSettings = extractPersistentChunkingSettings(config, taskType); - - return createModelWithoutLoggingDeprecations( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - secretSettingsMap, - chunkingSettings - ); - } - @Override public CustomModel buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSecrets secrets) { return retrieveModelCreatorFromMapOrThrow( @@ -284,23 +242,6 @@ private static ChunkingSettings extractPersistentChunkingSettings(Map config) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); - - var chunkingSettings = extractPersistentChunkingSettings(config, taskType); - - return createModelWithoutLoggingDeprecations( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - null, - chunkingSettings - ); - } - @Override public void doInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java index d704d0f4c75dd..c3c640bbe6aa9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java @@ -49,8 +49,8 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; -public class DeepSeekService extends SenderService { - private static final String NAME = "deepseek"; +public class DeepSeekService extends SenderService { + public static final String NAME = "deepseek"; private static final String CHAT_COMPLETION_ERROR_PREFIX = "deepseek chat completions"; private static final String COMPLETION_ERROR_PREFIX = "deepseek completions"; private static final String SERVICE_NAME = "DeepSeek"; @@ -78,7 +78,7 @@ public DeepSeekService( } public DeepSeekService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { - super(factory, serviceComponents, clusterService); + super(factory, serviceComponents, clusterService, MODEL_CREATORS); } @Override @@ -181,18 +181,6 @@ private static DeepSeekChatCompletionModel createModel( ); } - @Override - public Model parsePersistedConfigWithSecrets( - String inferenceEntityId, - TaskType taskType, - Map config, - Map secrets - ) { - var serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - var secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); - return createModelFromStorage(inferenceEntityId, taskType, serviceSettingsMap, secretSettingsMap); - } - @Override public Model buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSecrets secrets) { return retrieveModelCreatorFromMapOrThrow( @@ -204,21 +192,6 @@ public Model buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSec ).createFromModelConfigurationsAndSecrets(config, secrets); } - @Override - public Model parsePersistedConfig(String modelId, TaskType taskType, Map config) { - var serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - return createModelFromStorage(modelId, taskType, serviceSettingsMap, null); - } - - private static DeepSeekChatCompletionModel createModelFromStorage( - String inferenceEntityId, - TaskType taskType, - Map serviceSettingsMap, - Map secrets - ) { - return createModel(inferenceEntityId, taskType, serviceSettingsMap, secrets, ConfigurationParseContext.PERSISTENT); - } - @Override public InferenceServiceConfiguration getConfiguration() { return Configuration.get(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 631a941e8fe12..6d536a86d5a94 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -78,7 +78,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; import static org.elasticsearch.xpack.inference.services.openai.action.OpenAiActionCreator.USER_ROLE; -public class ElasticInferenceService extends SenderService { +public class ElasticInferenceService extends SenderService { public static final String NAME = "elastic"; public static final String ELASTIC_INFERENCE_SERVICE_IDENTIFIER = "Elastic Inference Service"; @@ -113,11 +113,6 @@ public class ElasticInferenceService extends SenderService { TEXT_EMBEDDING ); - /** - * The task types that support chunking settings - */ - private static final EnumSet CHUNKING_TASK_TYPES = EnumSet.of(SPARSE_EMBEDDING, TEXT_EMBEDDING, EMBEDDING); - private final Map> modelCreators; private final CCMAuthenticationApplierFactory ccmAuthenticationApplierFactory; private ElasticInferenceServiceActionCreator actionCreator; @@ -139,7 +134,7 @@ public ElasticInferenceService( ClusterService clusterService, CCMAuthenticationApplierFactory ccmAuthApplierFactory ) { - super(factory, serviceComponents, clusterService); + super(factory, serviceComponents, clusterService, Map.of()); this.ccmAuthenticationApplierFactory = ccmAuthApplierFactory; var elasticInferenceServiceComponents = new ElasticInferenceServiceComponents( elasticInferenceServiceSettings.getElasticInferenceServiceUrl() @@ -412,7 +407,18 @@ private ElasticInferenceServiceModel createModel( } @Override - public Model parsePersistedConfigWithSecrets(UnparsedModel unparsedModel) { + public ElasticInferenceServiceModel buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSecrets secrets) { + return retrieveModelCreatorFromMapOrThrow( + modelCreators, + config.getInferenceEntityId(), + config.getTaskType(), + config.getService(), + ConfigurationParseContext.PERSISTENT + ).createFromModelConfigurationsAndSecrets(config, secrets); + } + + @Override + public ElasticInferenceServiceModel parsePersistedConfig(UnparsedModel unparsedModel) { var config = unparsedModel.settings(); var secrets = unparsedModel.secrets(); var taskType = unparsedModel.taskType(); @@ -420,7 +426,9 @@ public Model parsePersistedConfigWithSecrets(UnparsedModel unparsedModel) { Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); // These aren't used by EIS endpoints so we'll remove them to avoid potential validation issues removeFromMap(config, ModelConfigurations.TASK_SETTINGS); - removeFromMap(secrets, ModelSecrets.SECRET_SETTINGS); + if (secrets != null) { + removeFromMap(secrets, ModelSecrets.SECRET_SETTINGS); + } ChunkingSettings chunkingSettings = null; if (CHUNKING_TASK_TYPES.contains(taskType)) { @@ -436,36 +444,6 @@ public Model parsePersistedConfigWithSecrets(UnparsedModel unparsedModel) { ); } - @Override - public Model parsePersistedConfigWithSecrets( - String inferenceEntityId, - TaskType taskType, - Map config, - Map secrets - ) { - // Once the inference api logic is switched to using the UnparsedModel variants of methods, this method can simply throw - // an exception. Then once all services use the UnparsedModel we can remove this method entirely. - return parsePersistedConfigWithSecrets(new UnparsedModel(inferenceEntityId, taskType, NAME, config, secrets)); - } - - @Override - public ElasticInferenceServiceModel buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSecrets secrets) { - return retrieveModelCreatorFromMapOrThrow( - modelCreators, - config.getInferenceEntityId(), - config.getTaskType(), - config.getService(), - ConfigurationParseContext.PERSISTENT - ).createFromModelConfigurationsAndSecrets(config, secrets); - } - - @Override - public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { - // Once the inference api logic is switched to using the UnparsedModel variants of methods, this method can simply throw - // an exception. Then once all services use the UnparsedModel we can remove this method entirely. - return parsePersistedConfigWithSecrets(inferenceEntityId, taskType, config, new HashMap<>()); - } - @Override public TransportVersion getMinimalSupportedVersion() { return TransportVersion.minimumCompatible(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index cae486c67db74..7944bba4f121a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -36,6 +36,7 @@ import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.XPackSettings; @@ -502,16 +503,6 @@ private void elserCase( ); } - @Override - public Model parsePersistedConfigWithSecrets( - String inferenceEntityId, - TaskType taskType, - Map config, - Map secrets - ) { - return parsePersistedConfig(inferenceEntityId, taskType, config); - } - @Override public Model buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSecrets secrets) { String modelId = config.getServiceSettings().modelId(); @@ -520,28 +511,31 @@ public Model buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSec } @Override - public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + public Model parsePersistedConfig(UnparsedModel unparsedModel) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(unparsedModel.settings(), ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(unparsedModel.settings(), ModelConfigurations.TASK_SETTINGS); migrateModelVersionToModelId(serviceSettingsMap); ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType) || TaskType.SPARSE_EMBEDDING.equals(taskType)) { - chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); + if (TaskType.TEXT_EMBEDDING.equals(unparsedModel.taskType()) || TaskType.SPARSE_EMBEDDING.equals(unparsedModel.taskType())) { + chunkingSettings = ChunkingSettingsBuilder.fromMap( + removeFromMap(unparsedModel.settings(), ModelConfigurations.CHUNKING_SETTINGS) + ); } String modelId = (String) serviceSettingsMap.get(MODEL_ID); - return retrieveModelCreatorFromListOrThrow(inferenceEntityId, taskType, modelId, NAME).createFromMaps( - inferenceEntityId, - taskType, - NAME, - serviceSettingsMap, - taskSettingsMap, - chunkingSettings, - null, - ConfigurationParseContext.PERSISTENT - ); + return retrieveModelCreatorFromListOrThrow(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), modelId, NAME) + .createFromMaps( + unparsedModel.inferenceEntityId(), + unparsedModel.taskType(), + NAME, + serviceSettingsMap, + taskSettingsMap, + chunkingSettings, + null, + ConfigurationParseContext.PERSISTENT + ); } /** diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/fireworksai/FireworksAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/fireworksai/FireworksAiService.java index 563db81fadfa7..26d2cd83e15f7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/fireworksai/FireworksAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/fireworksai/FireworksAiService.java @@ -65,7 +65,6 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; @@ -73,7 +72,7 @@ /** * FireworksAI inference service for text embeddings and chat completions. */ -public class FireworksAiService extends SenderService { +public class FireworksAiService extends SenderService { public static final String NAME = "fireworksai"; private static final String SERVICE_NAME = "FireworksAI"; @@ -120,7 +119,7 @@ public FireworksAiService( } public FireworksAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { - super(factory, serviceComponents, clusterService); + super(factory, serviceComponents, clusterService, MODEL_CREATORS); } @Override @@ -197,33 +196,6 @@ private static FireworksAiModel createModel( ); } - @Override - public FireworksAiModel parsePersistedConfigWithSecrets( - String inferenceEntityId, - TaskType taskType, - Map config, - Map secrets - ) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - Map secretSettingsMap = secrets != null ? removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS) : null; - - ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); - } - - return createModel( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - chunkingSettings, - secretSettingsMap, - ConfigurationParseContext.PERSISTENT - ); - } - @Override public FireworksAiModel buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSecrets secrets) { return retrieveModelCreatorFromMapOrThrow( @@ -235,11 +207,6 @@ public FireworksAiModel buildModelFromConfigAndSecrets(ModelConfigurations confi ).createFromModelConfigurationsAndSecrets(config, secrets); } - @Override - public FireworksAiModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { - return parsePersistedConfigWithSecrets(inferenceEntityId, taskType, config, null); - } - @Override public void doInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index a1ae10778097a..f7389a1945875 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -60,14 +60,13 @@ import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioServiceFields.EMBEDDING_MAX_BATCH_SIZE; -public class GoogleAiStudioService extends SenderService { +public class GoogleAiStudioService extends SenderService { public static final String NAME = "googleaistudio"; @@ -99,7 +98,7 @@ public GoogleAiStudioService( } public GoogleAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { - super(factory, serviceComponents, clusterService); + super(factory, serviceComponents, clusterService, MODEL_CREATORS); } @Override @@ -167,32 +166,6 @@ private static GoogleAiStudioModel createModel( ); } - @Override - public GoogleAiStudioModel parsePersistedConfigWithSecrets( - String inferenceEntityId, - TaskType taskType, - Map config, - Map secrets - ) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); - - ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); - } - - return createModelFromPersistent( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - chunkingSettings, - secretSettingsMap - ); - } - @Override public GoogleAiStudioModel buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSecrets secrets) { return retrieveModelCreatorFromMapOrThrow( @@ -204,38 +177,6 @@ public GoogleAiStudioModel buildModelFromConfigAndSecrets(ModelConfigurations co ).createFromModelConfigurationsAndSecrets(config, secrets); } - private static GoogleAiStudioModel createModelFromPersistent( - String inferenceEntityId, - TaskType taskType, - Map serviceSettings, - Map taskSettings, - ChunkingSettings chunkingSettings, - Map secretSettings - ) { - return createModel( - inferenceEntityId, - taskType, - serviceSettings, - taskSettings, - chunkingSettings, - secretSettings, - ConfigurationParseContext.PERSISTENT - ); - } - - @Override - public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - - ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); - } - - return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null); - } - @Override public InferenceServiceConfiguration getConfiguration() { return Configuration.get(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index c300a640567f1..7af9d8e2e5e16 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -61,7 +61,6 @@ import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; @@ -70,7 +69,7 @@ import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.PROJECT_ID; import static org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator.COMPLETION_ERROR_PREFIX; -public class GoogleVertexAiService extends SenderService implements RerankingInferenceService { +public class GoogleVertexAiService extends SenderService implements RerankingInferenceService { public static final String NAME = "googlevertexai"; @@ -117,7 +116,7 @@ public GoogleVertexAiService( } public GoogleVertexAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { - super(factory, serviceComponents, clusterService); + super(factory, serviceComponents, clusterService, MODEL_CREATORS); } @Override @@ -163,32 +162,6 @@ public void parseRequestConfig( } } - @Override - public Model parsePersistedConfigWithSecrets( - String inferenceEntityId, - TaskType taskType, - Map config, - Map secrets - ) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); - - ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); - } - - return createModelFromPersistent( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - chunkingSettings, - secretSettingsMap - ); - } - @Override public GoogleVertexAiModel buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSecrets secrets) { return retrieveModelCreatorFromMapOrThrow( @@ -200,19 +173,6 @@ public GoogleVertexAiModel buildModelFromConfigAndSecrets(ModelConfigurations co ).createFromModelConfigurationsAndSecrets(config, secrets); } - @Override - public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - - ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); - } - - return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null); - } - @Override public InferenceServiceConfiguration getConfiguration() { return Configuration.get(); @@ -343,25 +303,6 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { } } - private static GoogleVertexAiModel createModelFromPersistent( - String inferenceEntityId, - TaskType taskType, - Map serviceSettings, - Map taskSettings, - ChunkingSettings chunkingSettings, - Map secretSettings - ) { - return createModel( - inferenceEntityId, - taskType, - serviceSettings, - taskSettings, - chunkingSettings, - secretSettings, - ConfigurationParseContext.PERSISTENT - ); - } - private static GoogleVertexAiModel createModel( String inferenceEntityId, TaskType taskType, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/groq/GroqService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/groq/GroqService.java index 988b0c76f4f82..31fb78740b70f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/groq/GroqService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/groq/GroqService.java @@ -57,7 +57,7 @@ import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; -public class GroqService extends SenderService { +public class GroqService extends SenderService { public static final String NAME = "groq"; private static final String SERVICE_NAME = "Groq"; @@ -81,7 +81,7 @@ public GroqService( } public GroqService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { - super(factory, serviceComponents, clusterService); + super(factory, serviceComponents, clusterService, MODEL_CREATORS); } @Override @@ -134,27 +134,6 @@ public void parseRequestConfig( } } - @Override - public GroqModel parsePersistedConfigWithSecrets( - String inferenceEntityId, - TaskType taskType, - Map config, - Map secrets - ) { - Map serviceSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = ServiceUtils.removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - Map secretSettingsMap = ServiceUtils.removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); - - return createModel( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - secretSettingsMap, - ConfigurationParseContext.PERSISTENT - ); - } - @Override public GroqModel buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSecrets secrets) { return retrieveModelCreatorFromMapOrThrow( @@ -166,14 +145,6 @@ public GroqModel buildModelFromConfigAndSecrets(ModelConfigurations config, Mode ).createFromModelConfigurationsAndSecrets(config, secrets); } - @Override - public GroqModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { - Map serviceSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = ServiceUtils.removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - - return createModel(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, null, ConfigurationParseContext.PERSISTENT); - } - private static GroqModel createModel( String inferenceEntityId, TaskType taskType, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java index 19f5792a8d944..af694c29effed 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java @@ -12,17 +12,16 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkingSettings; -import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; -import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ModelCreator; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.ServiceUtils; @@ -31,12 +30,11 @@ import java.util.Map; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; -public abstract class HuggingFaceBaseService extends SenderService { +public abstract class HuggingFaceBaseService extends SenderService { /** * The optimal batch size depends on the hardware the model is deployed on. @@ -48,13 +46,10 @@ public abstract class HuggingFaceBaseService extends SenderService { public HuggingFaceBaseService( HttpRequestSender.Factory factory, ServiceComponents serviceComponents, - InferenceServiceExtension.InferenceServiceFactoryContext context + ClusterService clusterService, + Map> modelCreators ) { - this(factory, serviceComponents, context.clusterService()); - } - - public HuggingFaceBaseService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { - super(factory, serviceComponents, clusterService); + super(factory, serviceComponents, clusterService, modelCreators); } @Override @@ -97,58 +92,6 @@ public void parseRequestConfig( } } - @Override - public HuggingFaceModel parsePersistedConfigWithSecrets( - String inferenceEntityId, - TaskType taskType, - Map config, - Map secrets - ) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - Map secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); - - ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); - } - - return createModel( - new HuggingFaceModelParameters( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - chunkingSettings, - secretSettingsMap, - ConfigurationParseContext.PERSISTENT - ) - ); - } - - @Override - public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - - ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); - } - - return createModel( - new HuggingFaceModelParameters( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - chunkingSettings, - null, - ConfigurationParseContext.PERSISTENT - ) - ); - } - protected abstract HuggingFaceModel createModel(HuggingFaceModelParameters input); @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index 533b931a625e7..6e5271f838e9b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -100,7 +100,7 @@ public HuggingFaceService( } public HuggingFaceService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { - super(factory, serviceComponents, clusterService); + super(factory, serviceComponents, clusterService, MODEL_CREATORS); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index 6eb0bf56a0985..ad290ced5a6f8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -74,7 +74,7 @@ public HuggingFaceElserService( } public HuggingFaceElserService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { - super(factory, serviceComponents, clusterService); + super(factory, serviceComponents, clusterService, MODEL_CREATORS); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java index 25b948f73652d..4b04230341bd5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java @@ -71,7 +71,7 @@ import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields.EMBEDDING_MAX_BATCH_SIZE; import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields.PROJECT_ID; -public class IbmWatsonxService extends SenderService implements RerankingInferenceService { +public class IbmWatsonxService extends SenderService implements RerankingInferenceService { private static final String SERVICE_NAME = "IBM watsonx"; private static final EnumSet SUPPORTED_TASK_TYPES = EnumSet.of( @@ -113,7 +113,7 @@ public IbmWatsonxService( } public IbmWatsonxService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { - super(factory, serviceComponents, clusterService); + super(factory, serviceComponents, clusterService, MODEL_CREATORS); } @Override @@ -181,32 +181,6 @@ private static IbmWatsonxModel createModel( ); } - @Override - public IbmWatsonxModel parsePersistedConfigWithSecrets( - String inferenceEntityId, - TaskType taskType, - Map config, - Map secrets - ) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); - - ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)); - } - - return createModelFromPersistent( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - chunkingSettings, - secretSettingsMap - ); - } - @Override public Model buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSecrets secrets) { return retrieveModelCreatorFromMapOrThrow( @@ -228,38 +202,6 @@ public EnumSet supportedTaskTypes() { return SUPPORTED_TASK_TYPES; } - private static IbmWatsonxModel createModelFromPersistent( - String inferenceEntityId, - TaskType taskType, - Map serviceSettings, - Map taskSettings, - ChunkingSettings chunkingSettings, - Map secretSettings - ) { - return createModel( - inferenceEntityId, - taskType, - serviceSettings, - taskSettings, - chunkingSettings, - secretSettings, - ConfigurationParseContext.PERSISTENT - ); - } - - @Override - public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - - ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)); - } - - return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null); - } - @Override public TransportVersion getMinimalSupportedVersion() { return TransportVersion.minimumCompatible(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java index f5540816b38d5..4d26b6de679cb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java @@ -65,7 +65,6 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; @@ -73,7 +72,7 @@ import static org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceFields.EMBEDDING_MAX_BATCH_SIZE; import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.BaseJinaAIEmbeddingsServiceSettings.updateEmbeddingDetails; -public class JinaAIService extends SenderService implements RerankingInferenceService { +public class JinaAIService extends SenderService implements RerankingInferenceService { public static final TransportVersion JINA_AI_EMBEDDING_REFACTOR = TransportVersion.fromName("jina_ai_embedding_refactor"); @@ -108,7 +107,7 @@ public JinaAIService( } public JinaAIService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { - super(factory, serviceComponents, clusterService); + super(factory, serviceComponents, clusterService, MODEL_CREATORS); } @Override @@ -153,25 +152,6 @@ public void parseRequestConfig( } } - private static JinaAIModel createModelFromPersistent( - String inferenceEntityId, - TaskType taskType, - Map serviceSettings, - Map taskSettings, - ChunkingSettings chunkingSettings, - @Nullable Map secretSettings - ) { - return createModel( - inferenceEntityId, - taskType, - serviceSettings, - taskSettings, - chunkingSettings, - secretSettings, - ConfigurationParseContext.PERSISTENT - ); - } - private static JinaAIModel createModel( String inferenceEntityId, TaskType taskType, @@ -193,32 +173,6 @@ private static JinaAIModel createModel( ); } - @Override - public JinaAIModel parsePersistedConfigWithSecrets( - String inferenceEntityId, - TaskType taskType, - Map config, - Map secrets - ) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - Map secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); - - ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType) || TaskType.EMBEDDING.equals(taskType)) { - chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); - } - - return createModelFromPersistent( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - chunkingSettings, - secretSettingsMap - ); - } - @Override public Model buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSecrets secrets) { return retrieveModelCreatorFromMapOrThrow( @@ -230,19 +184,6 @@ public Model buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSec ).createFromModelConfigurationsAndSecrets(config, secrets); } - @Override - public JinaAIModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - - ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType) || TaskType.EMBEDDING.equals(taskType)) { - chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); - } - - return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null); - } - @Override public InferenceServiceConfiguration getConfiguration() { return Configuration.get(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java index 6d0aff175ac2d..f6d88fa1ef5bb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java @@ -65,7 +65,6 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; @@ -74,7 +73,7 @@ * LlamaService is an inference service for Llama models, supporting text embedding and chat completion tasks. * It extends SenderService to handle HTTP requests and responses for Llama models. */ -public class LlamaService extends SenderService { +public class LlamaService extends SenderService { public static final String NAME = "llama"; private static final String SERVICE_NAME = "Llama"; private static final TransportVersion ML_INFERENCE_LLAMA_ADDED = TransportVersion.fromName("ml_inference_llama_added"); @@ -115,7 +114,7 @@ public LlamaService( } public LlamaService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { - super(factory, serviceComponents, clusterService); + super(factory, serviceComponents, clusterService, MODEL_CREATORS); } @Override @@ -305,25 +304,6 @@ public void parseRequestConfig( } } - @Override - public Model parsePersistedConfigWithSecrets( - String modelId, - TaskType taskType, - Map config, - Map secrets - ) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); - - ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); - } - - return createModelFromPersistent(modelId, taskType, serviceSettingsMap, chunkingSettings, secretSettingsMap); - } - @Override public Model buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSecrets secrets) { return retrieveModelCreatorFromMapOrThrow( @@ -335,36 +315,6 @@ public Model buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSec ).createFromModelConfigurationsAndSecrets(config, secrets); } - private LlamaModel createModelFromPersistent( - String inferenceEntityId, - TaskType taskType, - Map serviceSettings, - ChunkingSettings chunkingSettings, - Map secretSettings - ) { - return createModel( - inferenceEntityId, - taskType, - serviceSettings, - chunkingSettings, - secretSettings, - ConfigurationParseContext.PERSISTENT - ); - } - - @Override - public Model parsePersistedConfig(String modelId, TaskType taskType, Map config) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - - ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); - } - - return createModelFromPersistent(modelId, taskType, serviceSettingsMap, chunkingSettings, null); - } - @Override public TransportVersion getMinimalSupportedVersion() { return ML_INFERENCE_LLAMA_ADDED; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java index 5696f0cafd92f..35d1445976e92 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java @@ -61,7 +61,6 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; @@ -72,7 +71,7 @@ * using Mistral models. It supports text embedding, completion, and chat completion tasks. * The service uses MistralActionCreator to create actions for executing inference requests. */ -public class MistralService extends SenderService { +public class MistralService extends SenderService { public static final String NAME = "mistral"; private static final String SERVICE_NAME = "Mistral"; @@ -104,7 +103,7 @@ public MistralService( } public MistralService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { - super(factory, serviceComponents, clusterService); + super(factory, serviceComponents, clusterService, MODEL_CREATORS); } @Override @@ -235,25 +234,6 @@ public void parseRequestConfig( } } - @Override - public MistralModel parsePersistedConfigWithSecrets( - String modelId, - TaskType taskType, - Map config, - Map secrets - ) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); - - ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); - } - - return createModelFromPersistent(modelId, taskType, serviceSettingsMap, chunkingSettings, secretSettingsMap); - } - @Override public Model buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSecrets secrets) { return retrieveModelCreatorFromMapOrThrow( @@ -265,19 +245,6 @@ public Model buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSec ).createFromModelConfigurationsAndSecrets(config, secrets); } - @Override - public MistralModel parsePersistedConfig(String modelId, TaskType taskType, Map config) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - - ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); - } - - return createModelFromPersistent(modelId, taskType, serviceSettingsMap, chunkingSettings, null); - } - @Override public TransportVersion getMinimalSupportedVersion() { return TransportVersion.minimumCompatible(); @@ -308,23 +275,6 @@ private static MistralModel createModel( ); } - private MistralModel createModelFromPersistent( - String inferenceEntityId, - TaskType taskType, - Map serviceSettings, - ChunkingSettings chunkingSettings, - Map secretSettings - ) { - return createModel( - inferenceEntityId, - taskType, - serviceSettings, - chunkingSettings, - secretSettings, - ConfigurationParseContext.PERSISTENT - ); - } - @Override public MistralEmbeddingsModel updateModelWithEmbeddingDetails(Model model, int embeddingSize) { if (model instanceof MistralEmbeddingsModel embeddingsModel) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/MixedbreadService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/MixedbreadService.java index d4adfe2acb016..7bdafddcc2a5f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/MixedbreadService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mixedbread/MixedbreadService.java @@ -55,7 +55,7 @@ * Mixedbread inference service for reranking tasks. * This service uses the Mixedbread REST API to perform document reranking. */ -public class MixedbreadService extends SenderService implements RerankingInferenceService { +public class MixedbreadService extends SenderService implements RerankingInferenceService { public static final String NAME = "mixedbread"; public static final String SERVICE_NAME = "Mixedbread"; @@ -101,7 +101,7 @@ public MixedbreadService( } public MixedbreadService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { - super(factory, serviceComponents, clusterService); + super(factory, serviceComponents, clusterService, MODEL_CREATORS); } @Override @@ -140,25 +140,6 @@ public void parseRequestConfig( } } - private MixedbreadModel parsePersistedConfigWithSecrets( - String inferenceEntityId, - TaskType taskType, - Map serviceSettings, - Map taskSettings, - ChunkingSettings chunkingSettings, - @Nullable Map secretSettings - ) { - return createModel( - inferenceEntityId, - taskType, - serviceSettings, - taskSettings, - chunkingSettings, - secretSettings, - ConfigurationParseContext.PERSISTENT - ); - } - /** * Creates an {@link MixedbreadModel} based on the provided parameters. * @@ -192,20 +173,6 @@ protected MixedbreadModel createModel( ); } - @Override - public MixedbreadModel parsePersistedConfigWithSecrets( - String inferenceEntityId, - TaskType taskType, - Map config, - Map secrets - ) { - Map serviceSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = ServiceUtils.removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - Map secretSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); - - return parsePersistedConfigWithSecrets(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, null, secretSettingsMap); - } - @Override public Model buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSecrets secrets) { return retrieveModelCreatorFromMapOrThrow( @@ -217,14 +184,6 @@ public Model buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSec ).createFromModelConfigurationsAndSecrets(config, secrets); } - @Override - public MixedbreadModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { - Map serviceSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = ServiceUtils.removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - - return parsePersistedConfigWithSecrets(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, null, null); - } - @Override public InferenceServiceConfiguration getConfiguration() { return Configuration.get(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/NvidiaService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/NvidiaService.java index a5c00c2e52f91..28b9f4e661dae 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/NvidiaService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/nvidia/NvidiaService.java @@ -66,7 +66,6 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; @@ -75,7 +74,7 @@ * NvidiaService is an inference service for Nvidia models, supporting text embedding and chat completion tasks. * It extends {@link SenderService} to handle HTTP requests and responses for Nvidia models. */ -public class NvidiaService extends SenderService implements RerankingInferenceService { +public class NvidiaService extends SenderService implements RerankingInferenceService { public static final String NAME = "nvidia"; private static final String SERVICE_NAME = "Nvidia"; @@ -128,7 +127,7 @@ public NvidiaService( } public NvidiaService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { - super(factory, serviceComponents, clusterService); + super(factory, serviceComponents, clusterService, MODEL_CREATORS); } @Override @@ -320,63 +319,6 @@ public void parseRequestConfig( } } - private NvidiaModel createModelFromPersistent( - String inferenceEntityId, - TaskType taskType, - Map serviceSettings, - Map taskSettings, - ChunkingSettings chunkingSettings, - Map secretSettings - ) { - return createModel( - inferenceEntityId, - taskType, - serviceSettings, - taskSettings, - chunkingSettings, - secretSettings, - ConfigurationParseContext.PERSISTENT - ); - } - - private NvidiaModel parsePersistedConfigInternal( - String inferenceEntityId, - TaskType taskType, - Map config, - Map secrets - ) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - Map secretSettingsMap = null; - if (secrets != null) { - secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); - } - - ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); - } - - return createModelFromPersistent( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - chunkingSettings, - secretSettingsMap - ); - } - - @Override - public NvidiaModel parsePersistedConfigWithSecrets( - String inferenceEntityId, - TaskType taskType, - Map config, - Map secrets - ) { - return parsePersistedConfigInternal(inferenceEntityId, taskType, config, secrets); - } - @Override public Model buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSecrets secrets) { return retrieveModelCreatorFromMapOrThrow( @@ -388,11 +330,6 @@ public Model buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSec ).createFromModelConfigurationsAndSecrets(config, secrets); } - @Override - public NvidiaModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { - return parsePersistedConfigInternal(inferenceEntityId, taskType, config, null); - } - @Override protected void validateInputType(InputType inputType, Model model, ValidationException validationException) { ServiceUtils.validateInputTypeAgainstAllowlist(inputType, VALID_INPUT_TYPE_VALUES, SERVICE_NAME, validationException); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index 1965d8f33f156..ece84d968f371 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -66,7 +66,6 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUnsupportedTaskTypeStatusException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; @@ -75,7 +74,7 @@ import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.ORGANIZATION; import static org.elasticsearch.xpack.inference.services.openai.action.OpenAiActionCreator.COMPLETION_ERROR_PREFIX; -public class OpenAiService extends SenderService { +public class OpenAiService extends SenderService { public static final String NAME = "openai"; private static final String SERVICE_NAME = "OpenAI"; @@ -112,7 +111,7 @@ public OpenAiService( } public OpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { - super(factory, serviceComponents, clusterService); + super(factory, serviceComponents, clusterService, MODEL_CREATORS); } @Override @@ -201,31 +200,8 @@ private static OpenAiModel createModel( } @Override - public OpenAiModel parsePersistedConfigWithSecrets( - String inferenceEntityId, - TaskType taskType, - Map config, - Map secrets - ) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); - - ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); - } - - moveModelFromTaskToServiceSettings(taskSettingsMap, serviceSettingsMap); - - return createModelFromPersistent( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - chunkingSettings, - secretSettingsMap - ); + protected void migrateBetweenTaskAndServiceSettings(Map serviceSettings, Map taskSettings) { + moveModelFromTaskToServiceSettings(taskSettings, serviceSettings); } @Override @@ -239,21 +215,6 @@ public Model buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSec ).createFromModelConfigurationsAndSecrets(config, secrets); } - @Override - public OpenAiModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - - ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); - } - - moveModelFromTaskToServiceSettings(taskSettingsMap, serviceSettingsMap); - - return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null); - } - @Override public InferenceServiceConfiguration getConfiguration() { return Configuration.get(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java index 7ea6528e58180..f4d6a04fdd034 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiService.java @@ -74,7 +74,7 @@ * using models deployed to OpenShift AI environment. * The service uses {@link OpenShiftAiActionCreator} to create actions for executing inference requests. */ -public class OpenShiftAiService extends SenderService implements RerankingInferenceService { +public class OpenShiftAiService extends SenderService implements RerankingInferenceService { public static final String NAME = "openshift_ai"; /** * The optimal batch size depends on the model deployed in OpenShift AI. @@ -113,7 +113,7 @@ public OpenShiftAiService( } public OpenShiftAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { - super(factory, serviceComponents, clusterService); + super(factory, serviceComponents, clusterService, MODEL_CREATORS); } @Override @@ -244,31 +244,6 @@ public void parseRequestConfig( } } - @Override - public OpenShiftAiModel parsePersistedConfigWithSecrets( - String inferenceEntityId, - TaskType taskType, - Map config, - Map secrets - ) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)); - } - Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); - - return createModelFromPersistent( - inferenceEntityId, - taskType, - serviceSettingsMap, - secretSettingsMap, - taskSettingsMap, - chunkingSettings - ); - } - @Override public Model buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSecrets secrets) { return retrieveModelCreatorFromMapOrThrow( @@ -280,20 +255,6 @@ public Model buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSec ).createFromModelConfigurationsAndSecrets(config, secrets); } - @Override - public OpenShiftAiModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - ChunkingSettings chunkingSettingsMap = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - chunkingSettingsMap = ChunkingSettingsBuilder.fromMap( - removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS) - ); - } - Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - - return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, null, taskSettingsMap, chunkingSettingsMap); - } - @Override public TransportVersion getMinimalSupportedVersion() { return OpenShiftAiUtils.ML_INFERENCE_OPENSHIFT_AI_ADDED; @@ -325,25 +286,6 @@ private static OpenShiftAiModel createModel( ); } - private OpenShiftAiModel createModelFromPersistent( - String inferenceEntityId, - TaskType taskType, - Map serviceSettings, - Map secretSettings, - Map taskSettings, - ChunkingSettings chunkingSettings - ) { - return createModel( - inferenceEntityId, - taskType, - serviceSettings, - secretSettings, - taskSettings, - chunkingSettings, - ConfigurationParseContext.PERSISTENT - ); - } - @Override public int rerankerWindowSize(String modelId) { // OpenShift AI uses Cohere and JinaAI rerank protocols for reranking diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java index e094630d5655a..f1114f9ca62d2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java @@ -31,6 +31,7 @@ import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.chunking.EmbeddingRequestChunker; @@ -123,13 +124,14 @@ public void parseRequestConfig( } @Override - public Model parsePersistedConfigWithSecrets( - String inferenceEntityId, - TaskType taskType, - Map config, - Map secrets - ) { - return modelBuilder.fromStorage(inferenceEntityId, taskType, NAME, config, secrets); + public Model parsePersistedConfig(UnparsedModel unparsedModel) { + return modelBuilder.fromStorage( + unparsedModel.inferenceEntityId(), + unparsedModel.taskType(), + NAME, + unparsedModel.settings(), + unparsedModel.secrets() + ); } @Override @@ -137,11 +139,6 @@ public Model buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSec return modelBuilder.fromStorage(config, secrets); } - @Override - public Model parsePersistedConfig(String modelId, TaskType taskType, Map config) { - return modelBuilder.fromStorage(modelId, taskType, NAME, config, null); - } - @Override public InferenceServiceConfiguration getConfiguration() { return configuration.getOrCompute(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java index 076db4287dc4d..87909b33165d4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java @@ -55,13 +55,12 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; -public class VoyageAIService extends SenderService implements RerankingInferenceService { +public class VoyageAIService extends SenderService implements RerankingInferenceService { public static final String NAME = "voyageai"; private static final String SERVICE_NAME = "Voyage AI"; @@ -125,7 +124,7 @@ public VoyageAIService( } public VoyageAIService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { - super(factory, serviceComponents, clusterService); + super(factory, serviceComponents, clusterService, MODEL_CREATORS); } @Override @@ -170,25 +169,6 @@ public void parseRequestConfig( } } - private static VoyageAIModel createModelFromPersistent( - String inferenceEntityId, - TaskType taskType, - Map serviceSettings, - Map taskSettings, - ChunkingSettings chunkingSettings, - @Nullable Map secretSettings - ) { - return createModel( - inferenceEntityId, - taskType, - serviceSettings, - taskSettings, - chunkingSettings, - secretSettings, - ConfigurationParseContext.PERSISTENT - ); - } - private static VoyageAIModel createModel( String inferenceEntityId, TaskType taskType, @@ -210,32 +190,6 @@ private static VoyageAIModel createModel( ); } - @Override - public VoyageAIModel parsePersistedConfigWithSecrets( - String inferenceEntityId, - TaskType taskType, - Map config, - Map secrets - ) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - Map secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); - - ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); - } - - return createModelFromPersistent( - inferenceEntityId, - taskType, - serviceSettingsMap, - taskSettingsMap, - chunkingSettings, - secretSettingsMap - ); - } - @Override public Model buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSecrets secrets) { return retrieveModelCreatorFromMapOrThrow( @@ -247,19 +201,6 @@ public Model buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSec ).createFromModelConfigurationsAndSecrets(config, secrets); } - @Override - public VoyageAIModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - - ChunkingSettings chunkingSettings = null; - if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); - } - - return createModelFromPersistent(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, null); - } - @Override public InferenceServiceConfiguration getConfiguration() { return Configuration.get(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java index 9a65963053a3e..1720ccaba8399 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java @@ -396,7 +396,7 @@ protected void mockService( ) { InferenceService service = mock(); Model model = mockModel(); - when(service.parsePersistedConfigWithSecrets(any(), any(), any(), any())).thenReturn(model); + when(service.parsePersistedConfig(any())).thenReturn(model); when(service.name()).thenReturn(serviceId); when(service.canStream(any())).thenReturn(stream); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointActionTests.java index d53c9d5eebbc9..4bd0f6855ac47 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointActionTests.java @@ -29,6 +29,7 @@ import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.junit.After; import org.junit.Before; +import org.mockito.ArgumentCaptor; import java.util.Map; import java.util.Optional; @@ -161,10 +162,18 @@ public void testFailsToDeleteUnparsableEndpoint_WhenForceIsFalse() { verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any()); verify(mockInferenceServiceRegistry).getService(eq(serviceName)); verify(mockModelRegistry).containsPreconfiguredInferenceEndpointId(eq(inferenceEndpointId)); - verify(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any()); + verifyServiceParsedPersistedConfig(mockService, inferenceEndpointId, taskType); verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, mockService); } + private void verifyServiceParsedPersistedConfig(InferenceService mockService, String endpointId, TaskType taskType) { + ArgumentCaptor unparsedModelCaptor = ArgumentCaptor.forClass(UnparsedModel.class); + verify(mockService).parsePersistedConfig(unparsedModelCaptor.capture()); + UnparsedModel capturedUnparsedModel = unparsedModelCaptor.getValue(); + assertThat(capturedUnparsedModel.inferenceEntityId(), is(endpointId)); + assertThat(capturedUnparsedModel.taskType(), is(taskType)); + } + public void testDeletesUnparsableEndpoint_WhenForceIsTrue() { var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10); var serviceName = randomAlphanumericOfLength(10); @@ -191,19 +200,20 @@ public void testDeletesUnparsableEndpoint_WhenForceIsTrue() { verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any()); verify(mockInferenceServiceRegistry).getService(eq(serviceName)); - verify(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any()); + verifyServiceParsedPersistedConfig(mockService, inferenceEndpointId, taskType); verify(mockModelRegistry).deleteModel(eq(inferenceEndpointId), any()); verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, mockService); } private void mockUnparsableModel(String inferenceEndpointId, String serviceName, TaskType taskType, InferenceService mockService) { + UnparsedModel unparsedModel = new UnparsedModel(inferenceEndpointId, taskType, serviceName, Map.of(), Map.of()); doAnswer(invocationOnMock -> { ActionListener listener = invocationOnMock.getArgument(1); - listener.onResponse(new UnparsedModel(inferenceEndpointId, taskType, serviceName, Map.of(), Map.of())); + listener.onResponse(unparsedModel); return Void.TYPE; }).when(mockModelRegistry).getModel(eq(inferenceEndpointId), any()); doThrow(new ElasticsearchStatusException(randomAlphanumericOfLength(10), RestStatus.INTERNAL_SERVER_ERROR)).when(mockService) - .parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any()); + .parsePersistedConfig(unparsedModel); when(mockInferenceServiceRegistry.getService(serviceName)).thenReturn(Optional.of(mockService)); } @@ -290,7 +300,7 @@ public void testFailsToDeleteEndpointIfModelDeploymentStopFails_WhenForceIsFalse verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any()); verify(mockInferenceServiceRegistry).getService(eq(serviceName)); verify(mockModelRegistry).containsPreconfiguredInferenceEndpointId(eq(inferenceEndpointId)); - verify(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any()); + verifyServiceParsedPersistedConfig(mockService, inferenceEndpointId, taskType); verify(mockService).stop(eq(mockModel), any()); verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, mockService, mockModel); } @@ -320,7 +330,7 @@ public void testDeletesEndpointIfModelDeploymentStopFails_WhenForceIsTrue() { assertTrue(response.isAcknowledged()); verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any()); verify(mockInferenceServiceRegistry).getService(eq(serviceName)); - verify(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any()); + verifyServiceParsedPersistedConfig(mockService, inferenceEndpointId, taskType); verify(mockService).stop(eq(mockModel), any()); verify(mockModelRegistry).deleteModel(eq(inferenceEndpointId), any()); verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, mockService, mockModel); @@ -333,13 +343,14 @@ private void mockStopDeploymentFails( InferenceService mockService, Model mockModel ) { + UnparsedModel unparsedModel = new UnparsedModel(inferenceEndpointId, taskType, serviceName, Map.of(), Map.of()); doAnswer(invocationOnMock -> { ActionListener listener = invocationOnMock.getArgument(1); - listener.onResponse(new UnparsedModel(inferenceEndpointId, taskType, serviceName, Map.of(), Map.of())); + listener.onResponse(unparsedModel); return Void.TYPE; }).when(mockModelRegistry).getModel(eq(inferenceEndpointId), any()); when(mockInferenceServiceRegistry.getService(serviceName)).thenReturn(Optional.of(mockService)); - doReturn(mockModel).when(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any()); + doReturn(mockModel).when(mockService).parsePersistedConfig(unparsedModel); doAnswer(invocationOnMock -> { ActionListener listener = invocationOnMock.getArgument(1); listener.onFailure(new ElasticsearchStatusException("Failed to stop model deployment", RestStatus.INTERNAL_SERVER_ERROR)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUpdateInferenceModelActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUpdateInferenceModelActionTests.java index 5ba5678ccb363..6c0a71baf7da6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUpdateInferenceModelActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUpdateInferenceModelActionTests.java @@ -52,11 +52,13 @@ import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings; import org.junit.Before; +import org.mockito.stubbing.Answer; import java.util.List; import java.util.Map; import java.util.Optional; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.nullValue; import static org.hamcrest.Matchers.sameInstance; @@ -224,7 +226,7 @@ public void testMasterOperation_UpdatedModelIsEqualToExistingModel_ValidationAnd mockParsePersistedConfigWithSecretsToReturnModel(model); when(service.buildModelFromConfigAndSecrets(any(ModelConfigurations.class), any(ModelSecrets.class))).thenReturn(model); mockModelRegistryGetModelToReturnUnparsedModel(unparsedModel); - mockParsePersistedConfigToReturnModel(model); + when(service.parsePersistedConfig(unparsedModel)).thenReturn(model); var listener = callMasterOperationWithActionFuture(); @@ -341,7 +343,7 @@ public void testMasterOperation_UpdatesModelSettingsSuccessfully() { mockUpdateModelWithEmbeddingDetailsToReturnSameModel(); mockUpdateModelTransactionToReturnBoolean(true, model); mockModelRegistryGetModelToReturnUnparsedModel(unparsedModel); - mockParsePersistedConfigToReturnModel(model); + when(service.parsePersistedConfig(unparsedModel)).thenReturn(model); var listener = callMasterOperationWithActionFuture(); var response = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); @@ -392,8 +394,12 @@ private void mockLicenseStateIsAllowed(boolean value) { } private void mockParsePersistedConfigWithSecretsToReturnModel(GoogleVertexAiEmbeddingsModel model) { - when(service.parsePersistedConfigWithSecrets(eq(INFERENCE_ENTITY_ID_VALUE), eq(TaskType.TEXT_EMBEDDING), anyMap(), anyMap())) - .thenReturn(model); + doAnswer((Answer) invocation -> { + UnparsedModel unparsedModel = invocation.getArgument(0); + assertThat(unparsedModel.inferenceEntityId(), equalTo(INFERENCE_ENTITY_ID_VALUE)); + assertThat(unparsedModel.taskType(), equalTo(model.getTaskType())); + return model; + }).when(service).parsePersistedConfig(any(UnparsedModel.class)); } private void mockServiceRegistryToReturnService(InferenceService service) { @@ -445,10 +451,6 @@ private void mockUpdateModelTransactionToReturnBoolean(boolean result, GoogleVer }).when(mockModelRegistry).updateModelTransaction(any(GoogleVertexAiEmbeddingsModel.class), eq(model), any()); } - private void mockParsePersistedConfigToReturnModel(GoogleVertexAiEmbeddingsModel model) { - when(service.parsePersistedConfig(eq(INFERENCE_ENTITY_ID_VALUE), eq(TaskType.TEXT_EMBEDDING), anyMap())).thenReturn(model); - } - private void verifyNoModelRegistryMutations() { verify(mockModelRegistry, never()).storeModel(any(), any(), any()); verify(mockModelRegistry, never()).storeModels(any(), any(), any()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index e0ecd484006bb..d144f4e9d25c7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -48,7 +48,6 @@ import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.MinimalServiceSettings; -import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnparsedModel; @@ -1167,11 +1166,10 @@ private static ShardBulkInferenceActionFilter createFilter( }; doAnswer(chunkedInferAnswer).when(inferenceService).chunkedInfer(any(), any(), any(), any(), any(), any(), any()); - Answer modelAnswer = invocationOnMock -> { - String inferenceId = (String) invocationOnMock.getArguments()[0]; - return modelMap.get(inferenceId); - }; - doAnswer(modelAnswer).when(inferenceService).parsePersistedConfigWithSecrets(any(), any(), any(), any()); + doAnswer(invocationOnMock -> { + UnparsedModel unparsedModel = invocationOnMock.getArgument(0); + return modelMap.get(unparsedModel.inferenceEntityId()); + }).when(inferenceService).parsePersistedConfig(any(UnparsedModel.class)); InferenceServiceRegistry inferenceServiceRegistry = mock(InferenceServiceRegistry.class); when(inferenceServiceRegistry.getService(any())).thenReturn(Optional.of(inferenceService)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceBaseTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceBaseTests.java index f10c22d8ef1c3..5b736fd6b1aa5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceBaseTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceBaseTests.java @@ -110,7 +110,7 @@ public EnumSet supportedTaskTypes() { return supportedTaskTypes; } - protected abstract SenderService createService(ThreadPool threadPool, HttpClientManager clientManager); + protected abstract SenderService createService(ThreadPool threadPool, HttpClientManager clientManager); protected abstract Map createServiceSettingsMap(TaskType taskType); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceParameterizedModelCreationTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceParameterizedModelCreationTests.java index 9bd00f0eafeef..f855c6e1da462 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceParameterizedModelCreationTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceParameterizedModelCreationTests.java @@ -48,7 +48,7 @@ public record TestCase( ) {} private record ModelCreatorParams( - SenderService service, + SenderService service, Utils.ModelConfigAndSecrets modelConfigAndSecrets, TestConfiguration testConfiguration ) {} @@ -183,13 +183,13 @@ public void testBuildModelFromConfigAndSecrets() { } } - private void assertSuccessfulModelCreation(SenderService service, Utils.ModelConfigAndSecrets persistedConfig) { + private void assertSuccessfulModelCreation(SenderService service, Utils.ModelConfigAndSecrets persistedConfig) { var model = testCase.modelCreator.buildModel(new ModelCreatorParams(service, persistedConfig, testConfiguration)); testConfiguration.commonConfig().assertModel(model, testCase.expectedTaskType, true, ConfigurationParseContext.PERSISTENT); } - private void assertFailedModelCreation(SenderService service, Utils.ModelConfigAndSecrets modelConfigAndSecrets) { + private void assertFailedModelCreation(SenderService service, Utils.ModelConfigAndSecrets modelConfigAndSecrets) { var exception = expectThrows( ElasticsearchStatusException.class, () -> testCase.modelCreator.buildModel(new ModelCreatorParams(service, modelConfigAndSecrets, testConfiguration)) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceParameterizedParsingTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceParameterizedParsingTests.java index ce2d030f7d7d2..8d0a862dcd605 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceParameterizedParsingTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceParameterizedParsingTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.Utils; import org.junit.Assume; @@ -58,7 +59,7 @@ public record TestCase( ) {} private record ServiceParserParams( - SenderService service, + SenderService service, Utils.PersistedConfig persistedConfig, AbstractInferenceServiceBaseTests.TestConfiguration testConfiguration ) {} @@ -184,9 +185,13 @@ public static Iterable parameters() { null ), params -> params.service.parsePersistedConfig( - "id", - params.testConfiguration.commonConfig().unsupportedTaskType(), - params.persistedConfig.config() + new UnparsedModel( + "id", + params.testConfiguration.commonConfig().unsupportedTaskType(), + "test_service", + params.persistedConfig.config(), + params.persistedConfig.secrets() + ) ), // We expect failure, so the expected task type is irrelevant null @@ -208,9 +213,13 @@ public static Iterable parameters() { return persistedConfigMap; }, params -> params.service.parsePersistedConfig( - "id", - params.testConfiguration.commonConfig().targetTaskType(), - params.persistedConfig.config() + new UnparsedModel( + "id", + params.testConfiguration.commonConfig().targetTaskType(), + "test_service", + params.persistedConfig.config(), + params.persistedConfig.secrets() + ) ), // Test expected task type is the target task type null @@ -233,9 +242,13 @@ public static Iterable parameters() { ); }, params -> params.service.parsePersistedConfig( - "id", - params.testConfiguration.commonConfig().targetTaskType(), - params.persistedConfig.config() + new UnparsedModel( + "id", + params.testConfiguration.commonConfig().targetTaskType(), + "test_service", + params.persistedConfig.config(), + params.persistedConfig.secrets() + ) ), // Test expected task type is the target task type null @@ -259,9 +272,13 @@ public static Iterable parameters() { ); }, params -> params.service.parsePersistedConfig( - "id", - params.testConfiguration.commonConfig().targetTaskType(), - params.persistedConfig.config() + new UnparsedModel( + "id", + params.testConfiguration.commonConfig().targetTaskType(), + "test_service", + params.persistedConfig.config(), + params.persistedConfig.secrets() + ) ), // Test expected task type is the target task type null @@ -342,11 +359,14 @@ public static Iterable parameters() { testConfiguration.commonConfig().createTaskSettingsMap(testConfiguration.commonConfig().targetTaskType()), testConfiguration.commonConfig().createSecretSettingsMap() ), - params -> params.service.parsePersistedConfigWithSecrets( - "id", - params.testConfiguration.commonConfig().unsupportedTaskType(), - params.persistedConfig.config(), - params.persistedConfig.secrets() + params -> params.service.parsePersistedConfig( + new UnparsedModel( + "id", + params.testConfiguration.commonConfig().unsupportedTaskType(), + "test_service", + params.persistedConfig.config(), + params.persistedConfig.secrets() + ) ), // We expect failure, so the expected task type is irrelevant null @@ -367,11 +387,14 @@ public static Iterable parameters() { persistedConfigMap.config().put("extra_key", "value"); return persistedConfigMap; }, - params -> params.service.parsePersistedConfigWithSecrets( - "id", - params.testConfiguration.commonConfig().targetTaskType(), - params.persistedConfig.config(), - params.persistedConfig.secrets() + params -> params.service.parsePersistedConfig( + new UnparsedModel( + "id", + params.testConfiguration.commonConfig().targetTaskType(), + "test_service", + params.persistedConfig.config(), + params.persistedConfig.secrets() + ) ), null // Test expected task type is the target task type @@ -393,11 +416,14 @@ public static Iterable parameters() { testConfiguration.commonConfig().createSecretSettingsMap() ); }, - params -> params.service.parsePersistedConfigWithSecrets( - "id", - params.testConfiguration.commonConfig().targetTaskType(), - params.persistedConfig.config(), - params.persistedConfig.secrets() + params -> params.service.parsePersistedConfig( + new UnparsedModel( + "id", + params.testConfiguration.commonConfig().targetTaskType(), + "test_service", + params.persistedConfig.config(), + params.persistedConfig.secrets() + ) ), // Test expected task type is the target task type null @@ -420,11 +446,14 @@ public static Iterable parameters() { testConfiguration.commonConfig().createSecretSettingsMap() ); }, - params -> params.service.parsePersistedConfigWithSecrets( - "id", - params.testConfiguration.commonConfig().targetTaskType(), - params.persistedConfig.config(), - params.persistedConfig.secrets() + params -> params.service.parsePersistedConfig( + new UnparsedModel( + "id", + params.testConfiguration.commonConfig().targetTaskType(), + "test_service", + params.persistedConfig.config(), + params.persistedConfig.secrets() + ) ), // Test expected task type is the target task type null @@ -446,11 +475,14 @@ public static Iterable parameters() { secretSettingsMap ); }, - params -> params.service.parsePersistedConfigWithSecrets( - "id", - params.testConfiguration.commonConfig().targetTaskType(), - params.persistedConfig.config(), - params.persistedConfig.secrets() + params -> params.service.parsePersistedConfig( + new UnparsedModel( + "id", + params.testConfiguration.commonConfig().targetTaskType(), + "test_service", + params.persistedConfig.config(), + params.persistedConfig.secrets() + ) ), // Test expected task type is the target task type null @@ -472,15 +504,14 @@ private static Function persistedConfi } private static ServiceParser getServiceParser(TaskType taskType) { - return params -> params.service.parsePersistedConfig("id", taskType, params.persistedConfig.config()); + return params -> params.service.parsePersistedConfig( + new UnparsedModel("id", taskType, "test_service", params.persistedConfig.config(), params.persistedConfig.secrets()) + ); } private static ServiceParser getServiceParserWithSecrets(TaskType taskType) { - return params -> params.service.parsePersistedConfigWithSecrets( - "id", - taskType, - params.persistedConfig.config(), - params.persistedConfig.secrets() + return params -> params.service.parsePersistedConfig( + new UnparsedModel("id", taskType, "test_service", params.persistedConfig.config(), params.persistedConfig.secrets()) ); } @@ -503,7 +534,7 @@ public void testPersistedConfig() throws Exception { } } - private void assertFailedParse(SenderService service, Utils.PersistedConfig persistedConfig) { + private void assertFailedParse(SenderService service, Utils.PersistedConfig persistedConfig) { var exception = expectThrows( ElasticsearchStatusException.class, () -> testCase.serviceParser.parseConfigs(new ServiceParserParams(service, persistedConfig, testConfiguration)) @@ -517,7 +548,7 @@ private void assertFailedParse(SenderService service, Utils.PersistedConfig pers ); } - private void assertSuccessfulParse(SenderService service, Utils.PersistedConfig persistedConfig) { + private void assertSuccessfulParse(SenderService service, Utils.PersistedConfig persistedConfig) { var model = testCase.serviceParser.parseConfigs(new ServiceParserParams(service, persistedConfig, testConfiguration)); if (persistedConfig.config().containsKey(ModelConfigurations.CHUNKING_SETTINGS)) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java index f9be8e3c40e9c..28c2fd5d509b0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java @@ -26,6 +26,7 @@ import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.inference.completion.ContentObjects; import org.elasticsearch.inference.completion.Message; import org.elasticsearch.test.ESTestCase; @@ -361,9 +362,9 @@ public static Sender createMockSender() { return sender; } - private static class TestSenderService extends SenderService { + private static class TestSenderService extends SenderService { TestSenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { - super(factory, serviceComponents, clusterService); + super(factory, serviceComponents, clusterService, Map.of()); } @Override @@ -415,23 +416,13 @@ public void parseRequestConfig( parsedModelListener.onResponse(null); } - @Override - public Model parsePersistedConfigWithSecrets( - String inferenceEntityId, - TaskType taskType, - Map config, - Map secrets - ) { - return null; - } - @Override public Model buildModelFromConfigAndSecrets(ModelConfigurations config, ModelSecrets secrets) { return null; } @Override - public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { + public Model parsePersistedConfig(UnparsedModel unparsedModel) { return null; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java index 167b73d9047cf..c0aa1494e3daf 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java @@ -46,7 +46,6 @@ import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceTests; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; -import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.ai21.completion.Ai21ChatCompletionModel; import org.elasticsearch.xpack.inference.services.ai21.completion.Ai21ChatCompletionModelTests; @@ -97,7 +96,7 @@ public static AbstractInferenceServiceTests.TestConfiguration createTestConfigur new CommonConfig(TaskType.COMPLETION, TaskType.TEXT_EMBEDDING, EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION)) { @Override - protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + protected Ai21Service createService(ThreadPool threadPool, HttpClientManager clientManager) { return Ai21ServiceTests.createService(threadPool, clientManager); } @@ -195,7 +194,7 @@ private static void assertChatCompletionModel(Model model, boolean modelIncludes assertThat(customModel.getTaskType(), Matchers.is(TaskType.CHAT_COMPLETION)); } - public static SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + public static Ai21Service createService(ThreadPool threadPool, HttpClientManager clientManager) { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); return new Ai21Service(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java index 49089553f258b..7354d2604af57 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java @@ -31,6 +31,7 @@ import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentType; @@ -226,17 +227,21 @@ public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSetting ) ) { var model = service.parsePersistedConfig( - "id", - TaskType.TEXT_EMBEDDING, - getPersistedConfigMap( - AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap( - SERVICE_ID_VALUE, - HOST_VALUE, - WORKSPACE_NAME_VALUE - ), - AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), - createRandomChunkingSettingsMap() - ).config() + new UnparsedModel( + "id", + TaskType.TEXT_EMBEDDING, + AlibabaCloudSearchService.NAME, + getPersistedConfigMap( + AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap( + SERVICE_ID_VALUE, + HOST_VALUE, + WORKSPACE_NAME_VALUE + ), + AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), + createRandomChunkingSettingsMap() + ).config(), + null + ) ); assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class)); @@ -257,16 +262,20 @@ public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSetting ) ) { var model = service.parsePersistedConfig( - "id", - TaskType.TEXT_EMBEDDING, - getPersistedConfigMap( - AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap( - SERVICE_ID_VALUE, - HOST_VALUE, - WORKSPACE_NAME_VALUE - ), - AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null) - ).config() + new UnparsedModel( + "id", + TaskType.TEXT_EMBEDDING, + AlibabaCloudSearchService.NAME, + getPersistedConfigMap( + AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap( + SERVICE_ID_VALUE, + HOST_VALUE, + WORKSPACE_NAME_VALUE + ), + AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null) + ).config(), + null + ) ); assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class)); @@ -278,7 +287,7 @@ public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSetting } } - public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { + public void testParsePersistedConfig_WithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { try ( var service = new AlibabaCloudSearchService( mock(HttpRequestSender.Factory.class), @@ -292,11 +301,14 @@ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChun createRandomChunkingSettingsMap(), getSecretSettingsMap(API_KEY_VALUE) ); - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + "id", + TaskType.TEXT_EMBEDDING, + AlibabaCloudSearchService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class)); @@ -309,7 +321,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChun } } - public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { + public void testParsePersistedConfig_WithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { try ( var service = new AlibabaCloudSearchService( mock(HttpRequestSender.Factory.class), @@ -323,11 +335,14 @@ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChun createRandomChunkingSettingsMap(), getSecretSettingsMap(API_KEY_VALUE) ); - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + "id", + TaskType.TEXT_EMBEDDING, + AlibabaCloudSearchService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java index d102c711401c2..c35846e9c7bbf 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -35,6 +35,7 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.inference.completion.ContentString; import org.elasticsearch.inference.completion.Message; import org.elasticsearch.threadpool.ThreadPool; @@ -557,18 +558,21 @@ public void testCreateModel_ForEmbeddingsTask_DimensionsIsNotAllowed() throws IO } } - public void testParsePersistedConfigWithSecrets_CreatesAnAmazonBedrockEmbeddingsModel() throws IOException { + public void testParsePersistedConfig_WithSecrets_CreatesAnAmazonBedrockEmbeddingsModel() throws IOException { try (var service = createAmazonBedrockService()) { var settingsMap = createEmbeddingsRequestSettingsMap(REGION_VALUE, MODEL_VALUE, "amazontitan", null, false, null, null); var secretSettingsMap = getAmazonBedrockSecretSettingsMap(ACCESS_KEY_VALUE, SECRET_KEY_VALUE); var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ID_VALUE, + TaskType.TEXT_EMBEDDING, + AmazonBedrockService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); @@ -577,13 +581,14 @@ public void testParsePersistedConfigWithSecrets_CreatesAnAmazonBedrockEmbeddings assertThat(settings.region(), is(REGION_VALUE)); assertThat(settings.modelId(), is(MODEL_VALUE)); assertThat(settings.provider(), is(AMAZON_BEDROCK_PROVIDER_VALUE)); - var secretSettings = (AwsSecretSettings) model.getSecretSettings(); + var secretSettings = model.getSecretSettings(); assertThat(secretSettings.accessKey().toString(), is(ACCESS_KEY_VALUE)); assertThat(secretSettings.secretKey().toString(), is(SECRET_KEY_VALUE)); } } - public void testParsePersistedConfigWithSecrets_CreatesAnAmazonBedrockEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { + public void testParsePersistedConfig_WithSecrets_CreatesAnAmazonBedrockEmbeddingsModelWhenChunkingSettingsProvided() + throws IOException { try (var service = createAmazonBedrockService()) { var settingsMap = createEmbeddingsRequestSettingsMap(REGION_VALUE, MODEL_VALUE, "amazontitan", null, false, null, null); var secretSettingsMap = getAmazonBedrockSecretSettingsMap(ACCESS_KEY_VALUE, SECRET_KEY_VALUE); @@ -595,11 +600,14 @@ public void testParsePersistedConfigWithSecrets_CreatesAnAmazonBedrockEmbeddings secretSettingsMap ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ID_VALUE, + TaskType.TEXT_EMBEDDING, + AmazonBedrockService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); @@ -609,13 +617,13 @@ public void testParsePersistedConfigWithSecrets_CreatesAnAmazonBedrockEmbeddings assertThat(settings.modelId(), is(MODEL_VALUE)); assertThat(settings.provider(), is(AMAZON_BEDROCK_PROVIDER_VALUE)); assertThat(model.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - var secretSettings = (AwsSecretSettings) model.getSecretSettings(); + var secretSettings = model.getSecretSettings(); assertThat(secretSettings.accessKey().toString(), is(ACCESS_KEY_VALUE)); assertThat(secretSettings.secretKey().toString(), is(SECRET_KEY_VALUE)); } } - public void testParsePersistedConfigWithSecrets_CreatesAnAmazonBedrockEmbeddingsModelWhenChunkingSettingsNotProvided() + public void testParsePersistedConfig_WithSecrets_CreatesAnAmazonBedrockEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { try (var service = createAmazonBedrockService()) { var settingsMap = createEmbeddingsRequestSettingsMap(REGION_VALUE, MODEL_VALUE, "amazontitan", null, false, null, null); @@ -623,11 +631,14 @@ public void testParsePersistedConfigWithSecrets_CreatesAnAmazonBedrockEmbeddings var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ID_VALUE, + TaskType.TEXT_EMBEDDING, + AmazonBedrockService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); @@ -637,13 +648,13 @@ public void testParsePersistedConfigWithSecrets_CreatesAnAmazonBedrockEmbeddings assertThat(settings.modelId(), is(MODEL_VALUE)); assertThat(settings.provider(), is(AMAZON_BEDROCK_PROVIDER_VALUE)); assertThat(model.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - var secretSettings = (AwsSecretSettings) model.getSecretSettings(); + var secretSettings = model.getSecretSettings(); assertThat(secretSettings.accessKey().toString(), is(ACCESS_KEY_VALUE)); assertThat(secretSettings.secretKey().toString(), is(SECRET_KEY_VALUE)); } } - public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidModel() throws IOException { + public void testParsePersistedConfig_WithSecrets_ThrowsErrorTryingToParseInvalidModel() throws IOException { try (var service = createAmazonBedrockService()) { var settingsMap = createChatCompletionRequestSettingsMap(REGION_VALUE, MODEL_VALUE, "amazontitan"); var secretSettingsMap = getAmazonBedrockSecretSettingsMap(ACCESS_KEY_VALUE, SECRET_KEY_VALUE); @@ -652,11 +663,14 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM var thrownException = expectThrows( ElasticsearchStatusException.class, - () -> service.parsePersistedConfigWithSecrets( - INFERENCE_ID_VALUE, - TaskType.SPARSE_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + () -> service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ID_VALUE, + TaskType.SPARSE_EMBEDDING, + AmazonBedrockService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ) ); @@ -668,7 +682,7 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM } } - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { try (var service = createAmazonBedrockService()) { var settingsMap = createEmbeddingsRequestSettingsMap(REGION_VALUE, MODEL_VALUE, "amazontitan", null, false, null, null); var secretSettingsMap = getAmazonBedrockSecretSettingsMap(ACCESS_KEY_VALUE, SECRET_KEY_VALUE); @@ -676,11 +690,14 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap); persistedConfig.config().put("extra_key", "value"); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ID_VALUE, + TaskType.TEXT_EMBEDDING, + AmazonBedrockService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); @@ -689,13 +706,13 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists assertThat(settings.region(), is(REGION_VALUE)); assertThat(settings.modelId(), is(MODEL_VALUE)); assertThat(settings.provider(), is(AMAZON_BEDROCK_PROVIDER_VALUE)); - var secretSettings = (AwsSecretSettings) model.getSecretSettings(); + var secretSettings = model.getSecretSettings(); assertThat(secretSettings.accessKey().toString(), is(ACCESS_KEY_VALUE)); assertThat(secretSettings.secretKey().toString(), is(SECRET_KEY_VALUE)); } } - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { try (var service = createAmazonBedrockService()) { var settingsMap = createEmbeddingsRequestSettingsMap(REGION_VALUE, MODEL_VALUE, "amazontitan", null, false, null, null); var secretSettingsMap = getAmazonBedrockSecretSettingsMap(ACCESS_KEY_VALUE, SECRET_KEY_VALUE); @@ -703,11 +720,14 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ID_VALUE, + TaskType.TEXT_EMBEDDING, + AmazonBedrockService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); @@ -716,13 +736,13 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists assertThat(settings.region(), is(REGION_VALUE)); assertThat(settings.modelId(), is(MODEL_VALUE)); assertThat(settings.provider(), is(AMAZON_BEDROCK_PROVIDER_VALUE)); - var secretSettings = (AwsSecretSettings) model.getSecretSettings(); + var secretSettings = model.getSecretSettings(); assertThat(secretSettings.accessKey().toString(), is(ACCESS_KEY_VALUE)); assertThat(secretSettings.secretKey().toString(), is(SECRET_KEY_VALUE)); } } - public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSecrets() throws IOException { + public void testParsePersistedConfig_WithSecrets_NotThrowWhenAnExtraKeyExistsInSecrets() throws IOException { try (var service = createAmazonBedrockService()) { var settingsMap = createEmbeddingsRequestSettingsMap(REGION_VALUE, MODEL_VALUE, "amazontitan", null, false, null, null); var secretSettingsMap = getAmazonBedrockSecretSettingsMap(ACCESS_KEY_VALUE, SECRET_KEY_VALUE); @@ -730,11 +750,14 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSe var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap); persistedConfig.secrets().put("extra_key", "value"); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ID_VALUE, + TaskType.TEXT_EMBEDDING, + AmazonBedrockService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); @@ -743,13 +766,13 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSe assertThat(settings.region(), is(REGION_VALUE)); assertThat(settings.modelId(), is(MODEL_VALUE)); assertThat(settings.provider(), is(AMAZON_BEDROCK_PROVIDER_VALUE)); - var secretSettings = (AwsSecretSettings) model.getSecretSettings(); + var secretSettings = model.getSecretSettings(); assertThat(secretSettings.accessKey().toString(), is(ACCESS_KEY_VALUE)); assertThat(secretSettings.secretKey().toString(), is(SECRET_KEY_VALUE)); } } - public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + public void testParsePersistedConfig_WithSecrets_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { try (var service = createAmazonBedrockService()) { var settingsMap = createEmbeddingsRequestSettingsMap(REGION_VALUE, MODEL_VALUE, "amazontitan", null, false, null, null); settingsMap.put("extra_key", "value"); @@ -757,11 +780,14 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSe var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ID_VALUE, + TaskType.TEXT_EMBEDDING, + AmazonBedrockService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); @@ -770,7 +796,7 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSe assertThat(settings.region(), is(REGION_VALUE)); assertThat(settings.modelId(), is(MODEL_VALUE)); assertThat(settings.provider(), is(AMAZON_BEDROCK_PROVIDER_VALUE)); - var secretSettings = (AwsSecretSettings) model.getSecretSettings(); + var secretSettings = model.getSecretSettings(); assertThat(secretSettings.accessKey().toString(), is(ACCESS_KEY_VALUE)); assertThat(secretSettings.secretKey().toString(), is(SECRET_KEY_VALUE)); } @@ -798,7 +824,9 @@ private void assertNotThrowWhenAnExtraKeyExistsInTaskSettings_WithSecrets(TaskTy var persistedConfig = getPersistedConfigMap(settingsMap, taskSettingsMap, secretSettingsMap); - var model = service.parsePersistedConfigWithSecrets("id", taskType, persistedConfig.config(), persistedConfig.secrets()); + var model = service.parsePersistedConfig( + new UnparsedModel("id", taskType, AmazonBedrockService.NAME, persistedConfig.config(), persistedConfig.secrets()) + ); assertThat(model, instanceOf(AmazonBedrockChatCompletionModel.class)); @@ -811,7 +839,7 @@ private void assertNotThrowWhenAnExtraKeyExistsInTaskSettings_WithSecrets(TaskTy assertThat(taskSettings.topP(), is(expectedTopP)); assertThat(taskSettings.topK(), is(expectedTopK)); assertThat(taskSettings.maxNewTokens(), is(expectedMaxNewTokens)); - var secretSettings = (AwsSecretSettings) model.getSecretSettings(); + var secretSettings = model.getSecretSettings(); assertThat(secretSettings.accessKey().toString(), is(ACCESS_KEY_VALUE)); assertThat(secretSettings.secretKey().toString(), is(SECRET_KEY_VALUE)); } @@ -823,7 +851,9 @@ public void testParsePersistedConfig_CreatesAnAmazonBedrockEmbeddingsModel() thr var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); - var model = service.parsePersistedConfig(INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, AmazonBedrockService.NAME, persistedConfig.config(), null) + ); assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); @@ -847,7 +877,9 @@ public void testParsePersistedConfig_CreatesAnAmazonBedrockEmbeddingsModelWhenCh secretSettingsMap ); - var model = service.parsePersistedConfig(INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, AmazonBedrockService.NAME, persistedConfig.config(), null) + ); assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); @@ -867,7 +899,9 @@ public void testParsePersistedConfig_CreatesAnAmazonBedrockEmbeddingsModelWhenCh var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap); - var model = service.parsePersistedConfig(INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, AmazonBedrockService.NAME, persistedConfig.config(), null) + ); assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); @@ -900,7 +934,9 @@ private void assertCreatesAnAmazonBedrockChatCompletionModel(TaskType taskType, var secretSettingsMap = getAmazonBedrockSecretSettingsMap(ACCESS_KEY_VALUE, SECRET_KEY_VALUE); var persistedConfig = getPersistedConfigMap(settingsMap, taskSettingsMap, secretSettingsMap); - var model = service.parsePersistedConfig("id", taskType, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel("id", taskType, AmazonBedrockService.NAME, persistedConfig.config(), null) + ); assertThat(model, instanceOf(AmazonBedrockChatCompletionModel.class)); @@ -925,7 +961,15 @@ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() thro var thrownException = expectThrows( ElasticsearchStatusException.class, - () -> service.parsePersistedConfig(INFERENCE_ID_VALUE, TaskType.SPARSE_EMBEDDING, persistedConfig.config()) + () -> service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ID_VALUE, + TaskType.SPARSE_EMBEDDING, + AmazonBedrockService.NAME, + persistedConfig.config(), + new HashMap<>() + ) + ) ); assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [amazonbedrock] service")); @@ -944,7 +988,9 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap); persistedConfig.config().put("extra_key", "value"); - var model = service.parsePersistedConfig(INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, AmazonBedrockService.NAME, persistedConfig.config(), null) + ); assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); @@ -965,7 +1011,9 @@ public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInServiceSettin var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap); persistedConfig.config().put("extra_key", "value"); - var model = service.parsePersistedConfig(INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, AmazonBedrockService.NAME, persistedConfig.config(), null) + ); assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); @@ -998,7 +1046,9 @@ private void assertNotThrowWhenAnExtraKeyExistsInTaskSettings(TaskType taskType, var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); var persistedConfig = getPersistedConfigMap(settingsMap, taskSettingsMap, secretSettingsMap); - var model = service.parsePersistedConfig(INFERENCE_ID_VALUE, taskType, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, taskType, AmazonBedrockService.NAME, persistedConfig.config(), null) + ); assertThat(model, instanceOf(AmazonBedrockChatCompletionModel.class)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java index 1f08f9c073b23..e7521aa197dbb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java @@ -26,6 +26,7 @@ import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; @@ -223,7 +224,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap } } - public void testParsePersistedConfigWithSecrets_CreatesACompletionModel() throws IOException { + public void testParsePersistedConfig_WithSecrets_CreatesACompletionModel() throws IOException { try (var service = createServiceWithMockSender()) { var persistedConfig = getPersistedConfigMap( new HashMap<>(Map.of(ServiceFields.MODEL_ID, MODEL_NAME_VALUE)), @@ -231,11 +232,14 @@ public void testParsePersistedConfigWithSecrets_CreatesACompletionModel() throws getSecretSettingsMap(API_KEY_VALUE) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.COMPLETION, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.COMPLETION, + AnthropicService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(AnthropicChatCompletionModel.class)); @@ -247,7 +251,7 @@ public void testParsePersistedConfigWithSecrets_CreatesACompletionModel() throws } } - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { try (var service = createServiceWithMockSender()) { var persistedConfig = getPersistedConfigMap( new HashMap<>(Map.of(ServiceFields.MODEL_ID, MODEL_NAME_VALUE)), @@ -256,11 +260,14 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists ); persistedConfig.config().put("extra_key", "value"); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.COMPLETION, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.COMPLETION, + AnthropicService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(AnthropicChatCompletionModel.class)); @@ -272,7 +279,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists } } - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { try (var service = createServiceWithMockSender()) { var secretSettingsMap = getSecretSettingsMap(API_KEY_VALUE); secretSettingsMap.put("extra_key", "value"); @@ -283,11 +290,14 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists secretSettingsMap ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.COMPLETION, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.COMPLETION, + AnthropicService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(AnthropicChatCompletionModel.class)); @@ -299,7 +309,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists } } - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { try (var service = createServiceWithMockSender()) { Map serviceSettingsMap = new HashMap<>(Map.of(ServiceFields.MODEL_ID, MODEL_NAME_VALUE)); serviceSettingsMap.put("extra_key", "value"); @@ -310,11 +320,14 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists getSecretSettingsMap(API_KEY_VALUE) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.COMPLETION, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.COMPLETION, + AnthropicService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(AnthropicChatCompletionModel.class)); @@ -326,7 +339,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists } } - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { try (var service = createServiceWithMockSender()) { Map taskSettings = AnthropicChatCompletionTaskSettingsTests.getChatCompletionTaskSettingsMap(1, 1.0, 2.1, 3); taskSettings.put("extra_key", "value"); @@ -337,11 +350,14 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists getSecretSettingsMap(API_KEY_VALUE) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.COMPLETION, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.COMPLETION, + AnthropicService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(AnthropicChatCompletionModel.class)); @@ -360,7 +376,9 @@ public void testParsePersistedConfig_CreatesACompletionModel() throws IOExceptio AnthropicChatCompletionTaskSettingsTests.getChatCompletionTaskSettingsMap(1, 1.0, 2.1, 3) ); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, TaskType.COMPLETION, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ENTITY_ID_VALUE, TaskType.COMPLETION, AnthropicService.NAME, persistedConfig.config(), null) + ); assertThat(model, instanceOf(AnthropicChatCompletionModel.class)); @@ -379,7 +397,9 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() ); persistedConfig.config().put("extra_key", "value"); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, TaskType.COMPLETION, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ENTITY_ID_VALUE, TaskType.COMPLETION, AnthropicService.NAME, persistedConfig.config(), null) + ); assertThat(model, instanceOf(AnthropicChatCompletionModel.class)); @@ -400,7 +420,9 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInServiceSe AnthropicChatCompletionTaskSettingsTests.getChatCompletionTaskSettingsMap(1, 1.0, 2.1, 3) ); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, TaskType.COMPLETION, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ENTITY_ID_VALUE, TaskType.COMPLETION, AnthropicService.NAME, persistedConfig.config(), null) + ); assertThat(model, instanceOf(AnthropicChatCompletionModel.class)); @@ -418,7 +440,9 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInTaskSetti var persistedConfig = getPersistedConfigMap(new HashMap<>(Map.of(ServiceFields.MODEL_ID, MODEL_NAME_VALUE)), taskSettings); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, TaskType.COMPLETION, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ENTITY_ID_VALUE, TaskType.COMPLETION, AnthropicService.NAME, persistedConfig.config(), null) + ); assertThat(model, instanceOf(AnthropicChatCompletionModel.class)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index ae16acc5c2b33..4ebc7bddce89f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -34,6 +34,7 @@ import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; @@ -668,11 +669,8 @@ public void testParsePersistedConfig_CreatesAnAzureAiStudioEmbeddingsModel() thr getSecretSettingsMap(API_KEY_VALUE) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ID_VALUE, - TaskType.TEXT_EMBEDDING, - config.config(), - config.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, AzureAiStudioService.NAME, config.config(), config.secrets()) ); assertThat(model, instanceOf(AzureAiStudioEmbeddingsModel.class)); @@ -689,7 +687,7 @@ public void testParsePersistedConfig_CreatesAnAzureAiStudioEmbeddingsModel() thr } } - public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { + public void testParsePersistedConfig_WithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { try (var service = createService()) { var config = getPersistedConfigMap( getEmbeddingsServiceSettingsMap(URL_VALUE, "openai", "token", 1024, true, 512, null), @@ -698,11 +696,8 @@ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChun getSecretSettingsMap(API_KEY_VALUE) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ID_VALUE, - TaskType.TEXT_EMBEDDING, - config.config(), - config.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, AzureAiStudioService.NAME, config.config(), config.secrets()) ); assertThat(model, instanceOf(AzureAiStudioEmbeddingsModel.class)); @@ -720,7 +715,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChun } } - public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { + public void testParsePersistedConfig_WithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { try (var service = createService()) { var config = getPersistedConfigMap( getEmbeddingsServiceSettingsMap(URL_VALUE, "openai", "token", 1024, true, 512, null), @@ -728,11 +723,8 @@ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChun getSecretSettingsMap(API_KEY_VALUE) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ID_VALUE, - TaskType.TEXT_EMBEDDING, - config.config(), - config.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, AzureAiStudioService.NAME, config.config(), config.secrets()) ); assertThat(model, instanceOf(AzureAiStudioEmbeddingsModel.class)); @@ -750,7 +742,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChun } } - public void testParsePersistedConfig_CreatesAnAzureAiStudioChatCompletionModel() throws IOException { + public void testParsePersistedConfig_WithSecrets_CreatesAnAzureAiStudioChatCompletionModel() throws IOException { try (var service = createService()) { var config = getPersistedConfigMap( getChatCompletionServiceSettingsMap(URL_VALUE, "openai", "token"), @@ -758,7 +750,9 @@ public void testParsePersistedConfig_CreatesAnAzureAiStudioChatCompletionModel() getSecretSettingsMap(API_KEY_VALUE) ); - var model = service.parsePersistedConfigWithSecrets(INFERENCE_ID_VALUE, TaskType.COMPLETION, config.config(), config.secrets()); + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, TaskType.COMPLETION, AzureAiStudioService.NAME, config.config(), config.secrets()) + ); assertThat(model, instanceOf(AzureAiStudioChatCompletionModel.class)); @@ -773,7 +767,7 @@ public void testParsePersistedConfig_CreatesAnAzureAiStudioChatCompletionModel() } } - public void testParsePersistedConfig_CreatesAnAzureAiStudioRerankModel() throws IOException { + public void testParsePersistedConfig_WithSecrets_CreatesAnAzureAiStudioRerankModel() throws IOException { try (var service = createService()) { var config = getPersistedConfigMap( getRerankServiceSettingsMap(URL_VALUE, "cohere", "token"), @@ -781,7 +775,9 @@ public void testParsePersistedConfig_CreatesAnAzureAiStudioRerankModel() throws getSecretSettingsMap(API_KEY_VALUE) ); - var model = service.parsePersistedConfigWithSecrets(INFERENCE_ID_VALUE, TaskType.RERANK, config.config(), config.secrets()); + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, TaskType.RERANK, AzureAiStudioService.NAME, config.config(), config.secrets()) + ); assertThat(model, instanceOf(AzureAiStudioRerankModel.class)); @@ -817,7 +813,7 @@ public void testParsePersistedConfig_ThrowsUnsupportedModelType() throws IOExcep } } - public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidModel() throws IOException { + public void testParsePersistedConfig_WithSecrets_ThrowsErrorTryingToParseInvalidModel() throws IOException { try (var service = createService()) { var config = getPersistedConfigMap( getChatCompletionServiceSettingsMap(URL_VALUE, "openai", "token"), @@ -827,11 +823,14 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM var thrownException = expectThrows( ElasticsearchStatusException.class, - () -> service.parsePersistedConfigWithSecrets( - INFERENCE_ID_VALUE, - TaskType.SPARSE_EMBEDDING, - config.config(), - config.secrets() + () -> service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ID_VALUE, + TaskType.SPARSE_EMBEDDING, + AzureAiStudioService.NAME, + config.config(), + config.secrets() + ) ) ); @@ -843,7 +842,7 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM } } - public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { try (var service = createService()) { var serviceSettings = getEmbeddingsServiceSettingsMap(URL_VALUE, "openai", "token", 1024, true, 512, null); var taskSettings = getEmbeddingsTaskSettingsMap("user"); @@ -851,18 +850,15 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings); config.config().put("extra_key", "value"); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ID_VALUE, - TaskType.TEXT_EMBEDDING, - config.config(), - config.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, AzureAiStudioService.NAME, config.config(), config.secrets()) ); assertThat(model, instanceOf(AzureAiStudioEmbeddingsModel.class)); } } - public void testParsePersistedConfig_DoesNotThrowWhenExtraKeyExistsInEmbeddingServiceSettingsMap() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenExtraKeyExistsInEmbeddingServiceSettingsMap() throws IOException { try (var service = createService()) { var serviceSettings = getEmbeddingsServiceSettingsMap(URL_VALUE, "openai", "token", 1024, true, 512, null); serviceSettings.put("extra_key", "value"); @@ -871,18 +867,15 @@ public void testParsePersistedConfig_DoesNotThrowWhenExtraKeyExistsInEmbeddingSe var secretSettings = getSecretSettingsMap(API_KEY_VALUE); var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ID_VALUE, - TaskType.TEXT_EMBEDDING, - config.config(), - config.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, AzureAiStudioService.NAME, config.config(), config.secrets()) ); assertThat(model, instanceOf(AzureAiStudioEmbeddingsModel.class)); } } - public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInEmbeddingTaskSettingsMap() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInEmbeddingTaskSettingsMap() throws IOException { try (var service = createService()) { var serviceSettings = getEmbeddingsServiceSettingsMap(URL_VALUE, "openai", "token", 1024, true, 512, null); var taskSettings = getEmbeddingsTaskSettingsMap("user"); @@ -891,18 +884,15 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInEmbedding var secretSettings = getSecretSettingsMap(API_KEY_VALUE); var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ID_VALUE, - TaskType.TEXT_EMBEDDING, - config.config(), - config.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, AzureAiStudioService.NAME, config.config(), config.secrets()) ); assertThat(model, instanceOf(AzureAiStudioEmbeddingsModel.class)); } } - public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInEmbeddingSecretSettingsMap() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInEmbeddingSecretSettingsMap() throws IOException { try (var service = createService()) { var serviceSettings = getEmbeddingsServiceSettingsMap(URL_VALUE, "openai", "token", 1024, true, 512, null); var taskSettings = getEmbeddingsTaskSettingsMap("user"); @@ -911,18 +901,16 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInEmbedding var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ID_VALUE, - TaskType.TEXT_EMBEDDING, - config.config(), - config.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, AzureAiStudioService.NAME, config.config(), config.secrets()) ); assertThat(model, instanceOf(AzureAiStudioEmbeddingsModel.class)); } } - public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInChatCompletionServiceSettingsMap() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInChatCompletionServiceSettingsMap() + throws IOException { try (var service = createService()) { var serviceSettings = getChatCompletionServiceSettingsMap(URL_VALUE, "openai", "token"); serviceSettings.put("extra_key", "value"); @@ -930,13 +918,15 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInChatCompl var secretSettings = getSecretSettingsMap(API_KEY_VALUE); var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings); - var model = service.parsePersistedConfigWithSecrets(INFERENCE_ID_VALUE, TaskType.COMPLETION, config.config(), config.secrets()); + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, TaskType.COMPLETION, AzureAiStudioService.NAME, config.config(), config.secrets()) + ); assertThat(model, instanceOf(AzureAiStudioChatCompletionModel.class)); } } - public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInChatCompletionTaskSettingsMap() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInChatCompletionTaskSettingsMap() throws IOException { try (var service = createService()) { var serviceSettings = getChatCompletionServiceSettingsMap(URL_VALUE, "openai", "token"); var taskSettings = getChatCompletionTaskSettingsMap(1.0, 2.0, true, 512); @@ -944,13 +934,16 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInChatCompl var secretSettings = getSecretSettingsMap(API_KEY_VALUE); var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings); - var model = service.parsePersistedConfigWithSecrets(INFERENCE_ID_VALUE, TaskType.COMPLETION, config.config(), config.secrets()); + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, TaskType.COMPLETION, AzureAiStudioService.NAME, config.config(), config.secrets()) + ); assertThat(model, instanceOf(AzureAiStudioChatCompletionModel.class)); } } - public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInChatCompletionSecretSettingsMap() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInChatCompletionSecretSettingsMap() + throws IOException { try (var service = createService()) { var serviceSettings = getChatCompletionServiceSettingsMap(URL_VALUE, "openai", "token"); var taskSettings = getChatCompletionTaskSettingsMap(1.0, 2.0, true, 512); @@ -958,13 +951,15 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInChatCompl secretSettings.put("extra_key", "value"); var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings); - var model = service.parsePersistedConfigWithSecrets(INFERENCE_ID_VALUE, TaskType.COMPLETION, config.config(), config.secrets()); + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, TaskType.COMPLETION, AzureAiStudioService.NAME, config.config(), config.secrets()) + ); assertThat(model, instanceOf(AzureAiStudioChatCompletionModel.class)); } } - public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInRerankServiceSettingsMap() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInRerankServiceSettingsMap() throws IOException { try (var service = createService()) { var serviceSettings = getRerankServiceSettingsMap(URL_VALUE, "cohere", "token"); serviceSettings.put("extra_key", "value"); @@ -972,13 +967,15 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInRerankSer var secretSettings = getSecretSettingsMap(API_KEY_VALUE); var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings); - var model = service.parsePersistedConfigWithSecrets(INFERENCE_ID_VALUE, TaskType.RERANK, config.config(), config.secrets()); + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, TaskType.RERANK, AzureAiStudioService.NAME, config.config(), config.secrets()) + ); assertThat(model, instanceOf(AzureAiStudioRerankModel.class)); } } - public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInRerankTaskSettingsMap() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInRerankTaskSettingsMap() throws IOException { try (var service = createService()) { var serviceSettings = getRerankServiceSettingsMap(URL_VALUE, "cohere", "token"); var taskSettings = getRerankTaskSettingsMap(true, 2); @@ -986,13 +983,15 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInRerankTas var secretSettings = getSecretSettingsMap(API_KEY_VALUE); var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings); - var model = service.parsePersistedConfigWithSecrets(INFERENCE_ID_VALUE, TaskType.RERANK, config.config(), config.secrets()); + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, TaskType.RERANK, AzureAiStudioService.NAME, config.config(), config.secrets()) + ); assertThat(model, instanceOf(AzureAiStudioRerankModel.class)); } } - public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInRerankSecretSettingsMap() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInRerankSecretSettingsMap() throws IOException { try (var service = createService()) { var serviceSettings = getRerankServiceSettingsMap(URL_VALUE, "cohere", "token"); var taskSettings = getRerankTaskSettingsMap(true, 2); @@ -1000,7 +999,9 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInRerankSec secretSettings.put("extra_key", "value"); var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings); - var model = service.parsePersistedConfigWithSecrets(INFERENCE_ID_VALUE, TaskType.RERANK, config.config(), config.secrets()); + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, TaskType.RERANK, AzureAiStudioService.NAME, config.config(), config.secrets()) + ); assertThat(model, instanceOf(AzureAiStudioRerankModel.class)); } @@ -1014,7 +1015,9 @@ public void testParsePersistedConfig_WithoutSecretsCreatesEmbeddingsModel() thro Map.of() ); - var model = service.parsePersistedConfig(INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, config.config()); + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, AzureAiStudioService.NAME, config.config(), null) + ); assertThat(model, instanceOf(AzureAiStudioEmbeddingsModel.class)); @@ -1038,7 +1041,9 @@ public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSetting Map.of() ); - var model = service.parsePersistedConfig(INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, config.config()); + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, AzureAiStudioService.NAME, config.config(), null) + ); assertThat(model, instanceOf(AzureAiStudioEmbeddingsModel.class)); @@ -1062,7 +1067,9 @@ public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSetting Map.of() ); - var model = service.parsePersistedConfig(INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, config.config()); + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, AzureAiStudioService.NAME, config.config(), null) + ); assertThat(model, instanceOf(AzureAiStudioEmbeddingsModel.class)); @@ -1086,7 +1093,9 @@ public void testParsePersistedConfig_WithoutSecretsCreatesChatCompletionModel() Map.of() ); - var model = service.parsePersistedConfig(INFERENCE_ID_VALUE, TaskType.COMPLETION, config.config()); + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, TaskType.COMPLETION, AzureAiStudioService.NAME, config.config(), null) + ); assertThat(model, instanceOf(AzureAiStudioChatCompletionModel.class)); @@ -1109,7 +1118,9 @@ public void testParsePersistedConfig_WithoutSecretsCreatesRerankModel() throws I Map.of() ); - var model = service.parsePersistedConfig(INFERENCE_ID_VALUE, TaskType.RERANK, config.config()); + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, TaskType.RERANK, AzureAiStudioService.NAME, config.config(), null) + ); assertThat(model, instanceOf(AzureAiStudioRerankModel.class)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index cb28da6430d39..ee7b66ecf7ce9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -35,6 +35,7 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.inference.completion.ContentString; import org.elasticsearch.inference.completion.Message; import org.elasticsearch.test.ESTestCase; @@ -362,7 +363,7 @@ public void testParseRequestConfig_MovesModel() throws IOException { } } - public void testParsePersistedConfigWithSecrets_CreatesAnAzureOpenAiEmbeddingsModel() throws IOException { + public void testParsePersistedConfig_WithSecrets_CreatesAnAzureOpenAiEmbeddingsModel() throws IOException { try (var service = createAzureOpenAiService()) { var persistedConfig = getPersistedConfigMap( getPersistentAzureOpenAiServiceSettingsMap(RESOURCE_NAME_VALUE, DEPLOYMENT_ID_VALUE, API_VERSION_VALUE, 100, 512), @@ -370,11 +371,14 @@ public void testParsePersistedConfigWithSecrets_CreatesAnAzureOpenAiEmbeddingsMo getAzureOpenAiSecretSettingsMap(API_KEY_VALUE, null) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + AzureOpenAiService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class)); @@ -390,7 +394,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAnAzureOpenAiEmbeddingsMo } } - public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { + public void testParsePersistedConfig_WithSecrets_CreatesAnOpenAiEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { try (var service = createAzureOpenAiService()) { var persistedConfig = getPersistedConfigMap( getPersistentAzureOpenAiServiceSettingsMap(RESOURCE_NAME_VALUE, DEPLOYMENT_ID_VALUE, API_VERSION_VALUE, 100, 512), @@ -399,11 +403,14 @@ public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModelWh getAzureOpenAiSecretSettingsMap(API_KEY_VALUE, null) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + AzureOpenAiService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class)); @@ -420,7 +427,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModelWh } } - public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { + public void testParsePersistedConfig_WithSecrets_CreatesAnOpenAiEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { try (var service = createAzureOpenAiService()) { var persistedConfig = getPersistedConfigMap( getPersistentAzureOpenAiServiceSettingsMap(RESOURCE_NAME_VALUE, DEPLOYMENT_ID_VALUE, API_VERSION_VALUE, 100, 512), @@ -428,11 +435,14 @@ public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModelWh getAzureOpenAiSecretSettingsMap(API_KEY_VALUE, null) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + AzureOpenAiService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class)); @@ -449,7 +459,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModelWh } } - public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidModel() throws IOException { + public void testParsePersistedConfig_WithSecrets_ThrowsErrorTryingToParseInvalidModel() throws IOException { try (var service = createAzureOpenAiService()) { var persistedConfig = getPersistedConfigMap( getPersistentAzureOpenAiServiceSettingsMap(RESOURCE_NAME_VALUE, DEPLOYMENT_ID_VALUE, API_VERSION_VALUE, null, null), @@ -459,11 +469,14 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM var thrownException = expectThrows( ElasticsearchStatusException.class, - () -> service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.SPARSE_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + () -> service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.SPARSE_EMBEDDING, + AzureOpenAiService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ) ); @@ -478,7 +491,7 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM } } - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { try (var service = createAzureOpenAiService()) { var persistedConfig = getPersistedConfigMap( getPersistentAzureOpenAiServiceSettingsMap(RESOURCE_NAME_VALUE, DEPLOYMENT_ID_VALUE, API_VERSION_VALUE, 100, 512), @@ -487,11 +500,14 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists ); persistedConfig.config().put("extra_key", "value"); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + AzureOpenAiService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class)); @@ -507,7 +523,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists } } - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { try (var service = createAzureOpenAiService()) { var secretSettingsMap = getAzureOpenAiSecretSettingsMap(API_KEY_VALUE, null); secretSettingsMap.put("extra_key", "value"); @@ -518,11 +534,14 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists secretSettingsMap ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + AzureOpenAiService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class)); @@ -538,7 +557,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists } } - public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSecrets() throws IOException { + public void testParsePersistedConfig_WithSecrets_NotThrowWhenAnExtraKeyExistsInSecrets() throws IOException { try (var service = createAzureOpenAiService()) { var persistedConfig = getPersistedConfigMap( getPersistentAzureOpenAiServiceSettingsMap(RESOURCE_NAME_VALUE, DEPLOYMENT_ID_VALUE, API_VERSION_VALUE, 100, 512), @@ -547,11 +566,14 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSe ); persistedConfig.secrets().put("extra_key", "value"); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + AzureOpenAiService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class)); @@ -567,7 +589,7 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSe } } - public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + public void testParsePersistedConfig_WithSecrets_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { try (var service = createAzureOpenAiService()) { var serviceSettingsMap = getPersistentAzureOpenAiServiceSettingsMap( RESOURCE_NAME_VALUE, @@ -584,11 +606,14 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSe getAzureOpenAiSecretSettingsMap(API_KEY_VALUE, null) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + AzureOpenAiService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class)); @@ -604,7 +629,7 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSe } } - public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { + public void testParsePersistedConfig_WithSecrets_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { try (var service = createAzureOpenAiService()) { var taskSettingsMap = getAzureOpenAiRequestTaskSettingsMap(ROLE_VALUE); taskSettingsMap.put("extra_key", "value"); @@ -615,11 +640,14 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTa getAzureOpenAiSecretSettingsMap(API_KEY_VALUE, null) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + AzureOpenAiService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class)); @@ -642,7 +670,15 @@ public void testParsePersistedConfig_CreatesAnAzureOpenAiEmbeddingsModel() throw getAzureOpenAiRequestTaskSettingsMap(ROLE_VALUE) ); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + AzureOpenAiService.NAME, + persistedConfig.config(), + null + ) + ); assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class)); @@ -663,7 +699,15 @@ public void testParsePersistedConfig_CreatesAnAzureOpenAiEmbeddingsModelWhenChun createRandomChunkingSettingsMap() ); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + AzureOpenAiService.NAME, + persistedConfig.config(), + null + ) + ); assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class)); @@ -684,7 +728,15 @@ public void testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModelWhenChunkingS getAzureOpenAiRequestTaskSettingsMap(ROLE_VALUE) ); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + AzureOpenAiService.NAME, + persistedConfig.config(), + null + ) + ); assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class)); @@ -707,7 +759,15 @@ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() thro var thrownException = expectThrows( ElasticsearchStatusException.class, - () -> service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, TaskType.SPARSE_EMBEDDING, persistedConfig.config()) + () -> service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.SPARSE_EMBEDDING, + AzureOpenAiService.NAME, + persistedConfig.config(), + new HashMap<>() + ) + ) ); assertThat( @@ -729,7 +789,15 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() ); persistedConfig.config().put("extra_key", "value"); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + AzureOpenAiService.NAME, + persistedConfig.config(), + null + ) + ); assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class)); @@ -755,7 +823,15 @@ public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInServiceSettin var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getAzureOpenAiRequestTaskSettingsMap(ROLE_VALUE)); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + AzureOpenAiService.NAME, + persistedConfig.config(), + null + ) + ); assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class)); @@ -778,7 +854,15 @@ public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInTaskSettings( taskSettingsMap ); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + AzureOpenAiService.NAME, + persistedConfig.config(), + null + ) + ); assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index a120c8447ae98..8cea856242361 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -37,6 +37,7 @@ import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; @@ -362,7 +363,7 @@ public void testParseRequestConfig_CreatesACohereEmbeddingsModelWithoutUrl() thr } } - public void testParsePersistedConfigWithSecrets_CreatesACohereEmbeddingsModel() throws IOException { + public void testParsePersistedConfig_WithSecrets_CreatesACohereEmbeddingsModel() throws IOException { try (var service = createCohereService()) { var persistedConfig = getPersistedConfigMap( CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap(URL_VALUE, MODEL_VALUE, null), @@ -370,11 +371,14 @@ public void testParsePersistedConfigWithSecrets_CreatesACohereEmbeddingsModel() getSecretSettingsMap(API_KEY_VALUE) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID, + TaskType.TEXT_EMBEDDING, + CohereService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(CohereEmbeddingsModel.class)); @@ -387,7 +391,7 @@ public void testParsePersistedConfigWithSecrets_CreatesACohereEmbeddingsModel() } } - public void testParsePersistedConfigWithSecrets_CreatesACohereEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { + public void testParsePersistedConfig_WithSecrets_CreatesACohereEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { try (var service = createCohereService()) { var persistedConfig = getPersistedConfigMap( CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap(URL_VALUE, MODEL_VALUE, null), @@ -396,11 +400,14 @@ public void testParsePersistedConfigWithSecrets_CreatesACohereEmbeddingsModelWhe getSecretSettingsMap(API_KEY_VALUE) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID, + TaskType.TEXT_EMBEDDING, + CohereService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(CohereEmbeddingsModel.class)); @@ -414,7 +421,7 @@ public void testParsePersistedConfigWithSecrets_CreatesACohereEmbeddingsModelWhe } } - public void testParsePersistedConfigWithSecrets_CreatesACohereEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { + public void testParsePersistedConfig_WithSecrets_CreatesACohereEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { try (var service = createCohereService()) { var persistedConfig = getPersistedConfigMap( CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap(URL_VALUE, MODEL_VALUE, null), @@ -422,11 +429,14 @@ public void testParsePersistedConfigWithSecrets_CreatesACohereEmbeddingsModelWhe getSecretSettingsMap(API_KEY_VALUE) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID, + TaskType.TEXT_EMBEDDING, + CohereService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(CohereEmbeddingsModel.class)); @@ -440,7 +450,7 @@ public void testParsePersistedConfigWithSecrets_CreatesACohereEmbeddingsModelWhe } } - public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidModel() throws IOException { + public void testParsePersistedConfig_WithSecrets_ThrowsErrorTryingToParseInvalidModel() throws IOException { try (var service = createCohereService()) { var persistedConfig = getPersistedConfigMap( CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap(URL_VALUE, null, null), @@ -450,11 +460,14 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM var thrownException = expectThrows( ElasticsearchStatusException.class, - () -> service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID, - TaskType.SPARSE_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + () -> service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID, + TaskType.SPARSE_EMBEDDING, + CohereService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ) ); @@ -463,7 +476,7 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM } } - public void testParsePersistedConfigWithSecrets_CreatesACohereEmbeddingsModelWithoutUrl() throws IOException { + public void testParsePersistedConfig_WithSecrets_CreatesACohereEmbeddingsModelWithoutUrl() throws IOException { try (var service = createCohereService()) { var persistedConfig = getPersistedConfigMap( CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, null, null), @@ -471,11 +484,14 @@ public void testParsePersistedConfigWithSecrets_CreatesACohereEmbeddingsModelWit getSecretSettingsMap(API_KEY_VALUE) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID, + TaskType.TEXT_EMBEDDING, + CohereService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(CohereEmbeddingsModel.class)); @@ -487,7 +503,7 @@ public void testParsePersistedConfigWithSecrets_CreatesACohereEmbeddingsModelWit } } - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { try (var service = createCohereService()) { var persistedConfig = getPersistedConfigMap( CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap(URL_VALUE, MODEL_VALUE, DenseVectorFieldMapper.ElementType.BYTE), @@ -496,11 +512,14 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists ); persistedConfig.config().put("extra_key", "value"); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID, + TaskType.TEXT_EMBEDDING, + CohereService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(CohereEmbeddingsModel.class)); @@ -514,7 +533,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists } } - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { try (var service = createCohereService()) { var secretSettingsMap = getSecretSettingsMap(API_KEY_VALUE); secretSettingsMap.put("extra_key", "value"); @@ -525,11 +544,14 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists secretSettingsMap ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID, + TaskType.TEXT_EMBEDDING, + CohereService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(CohereEmbeddingsModel.class)); @@ -541,7 +563,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists } } - public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSecrets() throws IOException { + public void testParsePersistedConfig_WithSecrets_NotThrowWhenAnExtraKeyExistsInSecrets() throws IOException { try (var service = createCohereService()) { var persistedConfig = getPersistedConfigMap( CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap(URL_VALUE, MODEL_VALUE, null), @@ -550,11 +572,14 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSe ); persistedConfig.secrets().put("extra_key", "value"); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID, + TaskType.TEXT_EMBEDDING, + CohereService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(CohereEmbeddingsModel.class)); @@ -567,18 +592,21 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSe } } - public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + public void testParsePersistedConfig_WithSecrets_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { try (var service = createCohereService()) { var serviceSettingsMap = CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap(URL_VALUE, null, null); serviceSettingsMap.put("extra_key", "value"); var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getTaskSettingsMapEmpty(), getSecretSettingsMap(API_KEY_VALUE)); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID, + TaskType.TEXT_EMBEDDING, + CohereService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(CohereEmbeddingsModel.class)); @@ -590,7 +618,7 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSe } } - public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { + public void testParsePersistedConfig_WithSecrets_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { try (var service = createCohereService()) { var taskSettingsMap = getTaskSettingsMap(InputType.SEARCH, null); taskSettingsMap.put("extra_key", "value"); @@ -601,11 +629,14 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTa getSecretSettingsMap(API_KEY_VALUE) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID, + TaskType.TEXT_EMBEDDING, + CohereService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(CohereEmbeddingsModel.class)); @@ -625,7 +656,9 @@ public void testParsePersistedConfig_CreatesACohereEmbeddingsModel() throws IOEx getTaskSettingsMap(null, Truncation.NONE) ); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID, TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ENTITY_ID, TaskType.TEXT_EMBEDDING, CohereService.NAME, persistedConfig.config(), null) + ); assertThat(model, instanceOf(CohereEmbeddingsModel.class)); @@ -645,7 +678,9 @@ public void testParsePersistedConfig_CreatesACohereEmbeddingsModelWhenChunkingSe createRandomChunkingSettingsMap() ); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID, TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ENTITY_ID, TaskType.TEXT_EMBEDDING, CohereService.NAME, persistedConfig.config(), null) + ); assertThat(model, instanceOf(CohereEmbeddingsModel.class)); @@ -665,7 +700,9 @@ public void testParsePersistedConfig_CreatesACohereEmbeddingsModelWhenChunkingSe getTaskSettingsMap(null, Truncation.NONE) ); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID, TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ENTITY_ID, TaskType.TEXT_EMBEDDING, CohereService.NAME, persistedConfig.config(), null) + ); assertThat(model, instanceOf(CohereEmbeddingsModel.class)); @@ -687,7 +724,15 @@ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() thro var thrownException = expectThrows( ElasticsearchStatusException.class, - () -> service.parsePersistedConfig(INFERENCE_ENTITY_ID, TaskType.SPARSE_EMBEDDING, persistedConfig.config()) + () -> service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID, + TaskType.SPARSE_EMBEDDING, + CohereService.NAME, + persistedConfig.config(), + new HashMap<>() + ) + ) ); assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [cohere] service")); @@ -702,7 +747,9 @@ public void testParsePersistedConfig_CreatesACohereEmbeddingsModelWithoutUrl() t getTaskSettingsMap(null, null) ); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID, TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ENTITY_ID, TaskType.TEXT_EMBEDDING, CohereService.NAME, persistedConfig.config(), null) + ); assertThat(model, instanceOf(CohereEmbeddingsModel.class)); @@ -723,7 +770,9 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() ); persistedConfig.config().put("extra_key", "value"); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID, TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ENTITY_ID, TaskType.TEXT_EMBEDDING, CohereService.NAME, persistedConfig.config(), null) + ); assertThat(model, instanceOf(CohereEmbeddingsModel.class)); @@ -741,7 +790,9 @@ public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInServiceSettin var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getTaskSettingsMap(InputType.SEARCH, null)); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID, TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ENTITY_ID, TaskType.TEXT_EMBEDDING, CohereService.NAME, persistedConfig.config(), null) + ); assertThat(model, instanceOf(CohereEmbeddingsModel.class)); @@ -762,7 +813,9 @@ public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInTaskSettings( taskSettingsMap ); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID, TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ENTITY_ID, TaskType.TEXT_EMBEDDING, CohereService.NAME, persistedConfig.config(), null) + ); assertThat(model, instanceOf(CohereEmbeddingsModel.class)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java index 944c9ba1d15b7..f0d7bf9f84ae1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java @@ -42,7 +42,6 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceTests; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; -import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser; @@ -93,7 +92,7 @@ public static TestConfiguration createTestConfiguration() { EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK, TaskType.COMPLETION) ) { @Override - protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + protected CustomService createService(ThreadPool threadPool, HttpClientManager clientManager) { return CustomServiceTests.createService(threadPool, clientManager); } @@ -226,7 +225,7 @@ private static void assertRerankModel(Model model, boolean modelIncludesSecrets) assertThat(customModel.getServiceSettings().getResponseJsonParser(), instanceOf(RerankResponseParser.class)); } - public static SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + public static CustomService createService(ThreadPool threadPool, HttpClientManager clientManager) { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); return new CustomService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java index 079ecb8e8d85b..b1a6628f8c691 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java @@ -28,6 +28,7 @@ import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.inference.completion.ContentString; import org.elasticsearch.inference.completion.Message; import org.elasticsearch.test.http.MockResponse; @@ -52,6 +53,7 @@ import java.net.URI; import java.net.URISyntaxException; import java.nio.charset.StandardCharsets; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -406,10 +408,22 @@ private Map map(String json) throws IOException { } private DeepSeekChatCompletionModel parsePersistedConfig(String json) throws IOException { + Map asMap = map(json); + Map serviceSettings = new HashMap<>(); + if (asMap.containsKey(ModelConfigurations.SERVICE_SETTINGS)) { + serviceSettings.put(ModelConfigurations.SERVICE_SETTINGS, asMap.get(ModelConfigurations.SERVICE_SETTINGS)); + } + Map secretSettings = null; + if (asMap.containsKey(ModelSecrets.SECRET_SETTINGS)) { + secretSettings = new HashMap<>(); + secretSettings.put(ModelSecrets.SECRET_SETTINGS, asMap.get(ModelSecrets.SECRET_SETTINGS)); + } try (var service = createService()) { - var model = service.parsePersistedConfig("inference-id", TaskType.CHAT_COMPLETION, map(json)); + var model = service.parsePersistedConfig( + new UnparsedModel("inference-id", TaskType.CHAT_COMPLETION, DeepSeekService.NAME, serviceSettings, secretSettings) + ); assertThat(model, isA(DeepSeekChatCompletionModel.class)); - return (DeepSeekChatCompletionModel) model; + return model; } } @@ -428,7 +442,7 @@ private InferenceEventsAssertion doUnifiedCompletionInfer() throws Exception { } private DeepSeekChatCompletionModel createModel(DeepSeekService service, TaskType taskType) throws URISyntaxException, IOException { - var model = service.parsePersistedConfigWithSecrets("inference-id", taskType, map(Strings.format(""" + var model = service.parsePersistedConfig(new UnparsedModel("inference-id", taskType, DeepSeekService.NAME, map(Strings.format(""" { "service_settings": { "model_id": "some-cool-model", @@ -441,9 +455,9 @@ private DeepSeekChatCompletionModel createModel(DeepSeekService service, TaskTyp "api_key": "12345" } } - """)); + """))); assertThat(model, isA(DeepSeekChatCompletionModel.class)); - return (DeepSeekChatCompletionModel) model; + return model; } public void testBuildModelFromConfigAndSecrets_ChatCompletion() throws IOException { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 816d46b0d9956..4682a8b8554b4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -328,43 +328,20 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap public void testParseStoredConfig_CreatesASparseEmbeddingModel() throws IOException { try (var service = createServiceWithMockSender()) { - { - var mockedPersistedConfig = getBaseSparseEmbeddingConfig(); - - assertSparseEmbeddingModelFromPersistedConfig( - service.parsePersistedConfigWithSecrets( - new UnparsedModel( - INFERENCE_ENTITY_ID, - TaskType.SPARSE_EMBEDDING, - ElasticInferenceService.NAME, - mockedPersistedConfig.config(), - mockedPersistedConfig.secrets() - ) - ), - ElserModels.ELSER_V2_MODEL - ); - } + var mockedPersistedConfig = getBaseSparseEmbeddingConfig(); - { - var mockedPersistedConfig = getBaseSparseEmbeddingConfig(); - assertSparseEmbeddingModelFromPersistedConfig( - service.parsePersistedConfigWithSecrets( + assertSparseEmbeddingModelFromPersistedConfig( + service.parsePersistedConfig( + new UnparsedModel( INFERENCE_ENTITY_ID, TaskType.SPARSE_EMBEDDING, + ElasticInferenceService.NAME, mockedPersistedConfig.config(), mockedPersistedConfig.secrets() - ), - ElserModels.ELSER_V2_MODEL - ); - } - - { - var mockedPersistedConfig = getBaseSparseEmbeddingConfig(); - assertSparseEmbeddingModelFromPersistedConfig( - service.parsePersistedConfig(INFERENCE_ENTITY_ID, TaskType.SPARSE_EMBEDDING, mockedPersistedConfig.config()), - ElserModels.ELSER_V2_MODEL - ); - } + ) + ), + ElserModels.ELSER_V2_MODEL + ); } } @@ -407,7 +384,9 @@ private void testParseStoredConfig_CreatesADenseEmbeddingsModel(TaskType taskTyp chunkingSettingsMap, Map.of() ); - var model = service.parsePersistedConfigWithSecrets("id", taskType, config.config(), config.secrets()); + var model = service.parsePersistedConfig( + new UnparsedModel("id", taskType, ElasticInferenceService.NAME, config.config(), config.secrets()) + ); assertThat(model, instanceOf(ElasticInferenceServiceDenseEmbeddingsModel.class)); assertThat(model.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); @@ -423,45 +402,22 @@ private void testParseStoredConfig_CreatesADenseEmbeddingsModel(TaskType taskTyp } } - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { try (var service = createServiceWithMockSender()) { - { - var mockedPersistedConfig = getExtraKeyInConfig(); - - assertSparseEmbeddingModelFromPersistedConfig( - service.parsePersistedConfigWithSecrets( - new UnparsedModel( - INFERENCE_ENTITY_ID, - TaskType.SPARSE_EMBEDDING, - ElasticInferenceService.NAME, - mockedPersistedConfig.config(), - mockedPersistedConfig.secrets() - ) - ), - ElserModels.ELSER_V2_MODEL - ); - } + var mockedPersistedConfig = getExtraKeyInConfig(); - { - var mockedPersistedConfig = getExtraKeyInConfig(); - assertSparseEmbeddingModelFromPersistedConfig( - service.parsePersistedConfigWithSecrets( + assertSparseEmbeddingModelFromPersistedConfig( + service.parsePersistedConfig( + new UnparsedModel( INFERENCE_ENTITY_ID, TaskType.SPARSE_EMBEDDING, + ElasticInferenceService.NAME, mockedPersistedConfig.config(), mockedPersistedConfig.secrets() - ), - ElserModels.ELSER_V2_MODEL - ); - } - - { - var mockedPersistedConfig = getExtraKeyInConfig(); - assertSparseEmbeddingModelFromPersistedConfig( - service.parsePersistedConfig(INFERENCE_ENTITY_ID, TaskType.SPARSE_EMBEDDING, mockedPersistedConfig.config()), - ElserModels.ELSER_V2_MODEL - ); - } + ) + ), + ElserModels.ELSER_V2_MODEL + ); } } @@ -477,43 +433,20 @@ private static Utils.PersistedConfig getExtraKeyInConfig() { public void testParseStoredConfig_DoesNotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { try (var service = createServiceWithMockSender()) { - { - var mockedPersistedConfig = getExtraKeyInServiceSettings(); - - assertSparseEmbeddingModelFromPersistedConfig( - service.parsePersistedConfigWithSecrets( - new UnparsedModel( - INFERENCE_ENTITY_ID, - TaskType.SPARSE_EMBEDDING, - ElasticInferenceService.NAME, - mockedPersistedConfig.config(), - mockedPersistedConfig.secrets() - ) - ), - ElserModels.ELSER_V2_MODEL - ); - } + var mockedPersistedConfig = getExtraKeyInServiceSettings(); - { - var mockedPersistedConfig = getExtraKeyInServiceSettings(); - assertSparseEmbeddingModelFromPersistedConfig( - service.parsePersistedConfigWithSecrets( + assertSparseEmbeddingModelFromPersistedConfig( + service.parsePersistedConfig( + new UnparsedModel( INFERENCE_ENTITY_ID, TaskType.SPARSE_EMBEDDING, + ElasticInferenceService.NAME, mockedPersistedConfig.config(), mockedPersistedConfig.secrets() - ), - ElserModels.ELSER_V2_MODEL - ); - } - - { - var mockedPersistedConfig = getExtraKeyInServiceSettings(); - assertSparseEmbeddingModelFromPersistedConfig( - service.parsePersistedConfig(INFERENCE_ENTITY_ID, TaskType.SPARSE_EMBEDDING, mockedPersistedConfig.config()), - ElserModels.ELSER_V2_MODEL - ); - } + ) + ), + ElserModels.ELSER_V2_MODEL + ); } } @@ -525,43 +458,20 @@ private static Utils.PersistedConfig getExtraKeyInServiceSettings() { public void testParseStoredConfig_DoesNotThrowWhenRateLimitFieldExistsInServiceSettings() throws IOException { try (var service = createServiceWithMockSender()) { - { - var mockedPersistedConfig = getRateLimitInServiceSettings(); - - assertSparseEmbeddingModelFromPersistedConfig( - service.parsePersistedConfigWithSecrets( - new UnparsedModel( - INFERENCE_ENTITY_ID, - TaskType.SPARSE_EMBEDDING, - ElasticInferenceService.NAME, - mockedPersistedConfig.config(), - mockedPersistedConfig.secrets() - ) - ), - ElserModels.ELSER_V2_MODEL - ); - } + var mockedPersistedConfig = getRateLimitInServiceSettings(); - { - var mockedPersistedConfig = getRateLimitInServiceSettings(); - assertSparseEmbeddingModelFromPersistedConfig( - service.parsePersistedConfigWithSecrets( + assertSparseEmbeddingModelFromPersistedConfig( + service.parsePersistedConfig( + new UnparsedModel( INFERENCE_ENTITY_ID, TaskType.SPARSE_EMBEDDING, + ElasticInferenceService.NAME, mockedPersistedConfig.config(), mockedPersistedConfig.secrets() - ), - ElserModels.ELSER_V2_MODEL - ); - } - - { - var mockedPersistedConfig = getRateLimitInServiceSettings(); - assertSparseEmbeddingModelFromPersistedConfig( - service.parsePersistedConfig(INFERENCE_ENTITY_ID, TaskType.SPARSE_EMBEDDING, mockedPersistedConfig.config()), - ElserModels.ELSER_V2_MODEL - ); - } + ) + ), + ElserModels.ELSER_V2_MODEL + ); } } @@ -579,43 +489,20 @@ private static Utils.PersistedConfig getRateLimitInServiceSettings() { public void testParseStoredConfig_DoesNotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { try (var service = createServiceWithMockSender()) { - { - var mockedPersistedConfig = getExtraKeyInTaskSettings(); - - assertSparseEmbeddingModelFromPersistedConfig( - service.parsePersistedConfigWithSecrets( - new UnparsedModel( - INFERENCE_ENTITY_ID, - TaskType.SPARSE_EMBEDDING, - ElasticInferenceService.NAME, - mockedPersistedConfig.config(), - mockedPersistedConfig.secrets() - ) - ), - ElserModels.ELSER_V2_MODEL - ); - } + var mockedPersistedConfig = getExtraKeyInTaskSettings(); - { - var mockedPersistedConfig = getExtraKeyInTaskSettings(); - assertSparseEmbeddingModelFromPersistedConfig( - service.parsePersistedConfigWithSecrets( + assertSparseEmbeddingModelFromPersistedConfig( + service.parsePersistedConfig( + new UnparsedModel( INFERENCE_ENTITY_ID, TaskType.SPARSE_EMBEDDING, + ElasticInferenceService.NAME, mockedPersistedConfig.config(), mockedPersistedConfig.secrets() - ), - ElserModels.ELSER_V2_MODEL - ); - } - - { - var mockedPersistedConfig = getExtraKeyInTaskSettings(); - assertSparseEmbeddingModelFromPersistedConfig( - service.parsePersistedConfig(INFERENCE_ENTITY_ID, TaskType.SPARSE_EMBEDDING, mockedPersistedConfig.config()), - ElserModels.ELSER_V2_MODEL - ); - } + ) + ), + ElserModels.ELSER_V2_MODEL + ); } } @@ -626,43 +513,20 @@ private static Utils.PersistedConfig getExtraKeyInTaskSettings() { public void testParseStoredConfig_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { try (var service = createServiceWithMockSender()) { - { - var mockedPersistedConfig = getExtraKeyInSecretsSettings(); - - assertSparseEmbeddingModelFromPersistedConfig( - service.parsePersistedConfigWithSecrets( - new UnparsedModel( - INFERENCE_ENTITY_ID, - TaskType.SPARSE_EMBEDDING, - ElasticInferenceService.NAME, - mockedPersistedConfig.config(), - mockedPersistedConfig.secrets() - ) - ), - ElserModels.ELSER_V2_MODEL - ); - } + var mockedPersistedConfig = getExtraKeyInSecretsSettings(); - { - var mockedPersistedConfig = getExtraKeyInSecretsSettings(); - assertSparseEmbeddingModelFromPersistedConfig( - service.parsePersistedConfigWithSecrets( + assertSparseEmbeddingModelFromPersistedConfig( + service.parsePersistedConfig( + new UnparsedModel( INFERENCE_ENTITY_ID, TaskType.SPARSE_EMBEDDING, + ElasticInferenceService.NAME, mockedPersistedConfig.config(), mockedPersistedConfig.secrets() - ), - ElserModels.ELSER_V2_MODEL - ); - } - - { - var mockedPersistedConfig = getExtraKeyInSecretsSettings(); - assertSparseEmbeddingModelFromPersistedConfig( - service.parsePersistedConfig(INFERENCE_ENTITY_ID, TaskType.SPARSE_EMBEDDING, mockedPersistedConfig.config()), - ElserModels.ELSER_V2_MODEL - ); - } + ) + ), + ElserModels.ELSER_V2_MODEL + ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index 40035c0fda7e7..356f8e483d727 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -45,6 +45,7 @@ import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.inference.telemetry.InferenceStats; import org.elasticsearch.inference.telemetry.InferenceStatsTests; import org.elasticsearch.rest.RestStatus; @@ -780,8 +781,15 @@ public void testParsePersistedConfig() { ) ) ); + UnparsedModel unparsedModel = new UnparsedModel( + randomInferenceEntityId, + TaskType.TEXT_EMBEDDING, + ElasticsearchInternalService.NAME, + settings, + Collections.emptyMap() + ); - var model = service.parsePersistedConfig(randomInferenceEntityId, TaskType.TEXT_EMBEDDING, settings); + var model = service.parsePersistedConfig(unparsedModel); assertThat(model, instanceOf(ElserInternalModel.class)); ElserInternalModel elserInternalModel = (ElserInternalModel) model; assertThat(elserInternalModel.getServiceSettings().modelId(), is(".elser_model_2")); @@ -804,12 +812,16 @@ public void testParsePersistedConfig() { ) ) ); - - var exception = expectThrows( - IllegalArgumentException.class, - () -> service.parsePersistedConfig(randomInferenceEntityId, TaskType.TEXT_EMBEDDING, settings) + UnparsedModel unparsedModel = new UnparsedModel( + randomInferenceEntityId, + TaskType.TEXT_EMBEDDING, + ElasticsearchInternalService.NAME, + settings, + Collections.emptyMap() ); + var exception = expectThrows(IllegalArgumentException.class, () -> service.parsePersistedConfig(unparsedModel)); + assertThat(exception.getMessage(), containsString(randomInferenceEntityId)); } @@ -832,12 +844,15 @@ public void testParsePersistedConfig() { ) ) ); - - CustomElandEmbeddingModel parsedModel = (CustomElandEmbeddingModel) service.parsePersistedConfig( + UnparsedModel unparsedModel = new UnparsedModel( randomInferenceEntityId, TaskType.TEXT_EMBEDDING, - settings + ElasticsearchInternalService.NAME, + settings, + Collections.emptyMap() ); + + CustomElandEmbeddingModel parsedModel = (CustomElandEmbeddingModel) service.parsePersistedConfig(unparsedModel); var elandServiceSettings = new CustomElandInternalTextEmbeddingServiceSettings( 1, 4, @@ -880,6 +895,13 @@ public void testParsePersistedConfig() { ) ) ); + UnparsedModel unparsedModel = new UnparsedModel( + randomInferenceEntityId, + TaskType.TEXT_EMBEDDING, + ElasticsearchInternalService.NAME, + settings, + Collections.emptyMap() + ); var e5ServiceSettings = new MultilingualE5SmallInternalServiceSettings( 1, @@ -888,11 +910,7 @@ public void testParsePersistedConfig() { null ); - MultilingualE5SmallModel parsedModel = (MultilingualE5SmallModel) service.parsePersistedConfig( - randomInferenceEntityId, - TaskType.TEXT_EMBEDDING, - settings - ); + MultilingualE5SmallModel parsedModel = (MultilingualE5SmallModel) service.parsePersistedConfig(unparsedModel); assertEquals( new MultilingualE5SmallModel( randomInferenceEntityId, @@ -919,7 +937,14 @@ public void testParsePersistedConfig() { settings.put("not_a_valid_config_setting", randomAlphaOfLength(10)); var taskType = randomFrom(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.SPARSE_EMBEDDING); - expectThrows(IllegalArgumentException.class, () -> service.parsePersistedConfig(randomInferenceEntityId, taskType, settings)); + UnparsedModel unparsedModel = new UnparsedModel( + randomInferenceEntityId, + taskType, + ElasticsearchInternalService.NAME, + settings, + Collections.emptyMap() + ); + expectThrows(IllegalArgumentException.class, () -> service.parsePersistedConfig(unparsedModel)); } // Invalid service settings @@ -940,7 +965,14 @@ public void testParsePersistedConfig() { ) ); var taskType = randomFrom(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.SPARSE_EMBEDDING); - expectThrows(IllegalArgumentException.class, () -> service.parsePersistedConfig(randomInferenceEntityId, taskType, settings)); + UnparsedModel unparsedModel = new UnparsedModel( + randomInferenceEntityId, + taskType, + ElasticsearchInternalService.NAME, + settings, + Collections.emptyMap() + ); + expectThrows(IllegalArgumentException.class, () -> service.parsePersistedConfig(unparsedModel)); } } @@ -1565,8 +1597,15 @@ public void testParsePersistedConfig_Rerank() { settings.put(ElasticsearchInternalServiceSettings.MODEL_ID, "foo"); var returnDocs = randomBoolean(); settings.put(ModelConfigurations.TASK_SETTINGS, new HashMap<>(Map.of(RerankTaskSettings.RETURN_DOCUMENTS, returnDocs))); + UnparsedModel unparsedModel = new UnparsedModel( + randomInferenceEntityId, + TaskType.RERANK, + ElasticsearchInternalService.NAME, + settings, + Collections.emptyMap() + ); - var model = service.parsePersistedConfig(randomInferenceEntityId, TaskType.RERANK, settings); + var model = service.parsePersistedConfig(unparsedModel); assertThat(model.getTaskSettings(), instanceOf(RerankTaskSettings.class)); assertEquals(returnDocs, ((RerankTaskSettings) model.getTaskSettings()).returnDocuments()); } @@ -1589,8 +1628,15 @@ public void testParsePersistedConfig_Rerank() { ) ); settings.put(ElasticsearchInternalServiceSettings.MODEL_ID, "foo"); + UnparsedModel unparsedModel = new UnparsedModel( + randomInferenceEntityId, + TaskType.RERANK, + ElasticsearchInternalService.NAME, + settings, + Collections.emptyMap() + ); - var model = service.parsePersistedConfig(randomInferenceEntityId, TaskType.RERANK, settings); + var model = service.parsePersistedConfig(unparsedModel); assertThat(model.getTaskSettings(), instanceOf(RerankTaskSettings.class)); assertTrue(((RerankTaskSettings) model.getTaskSettings()).returnDocuments()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/fireworksai/FireworksAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/fireworksai/FireworksAiServiceTests.java index 5186bce78e8f8..b2fae79606111 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/fireworksai/FireworksAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/fireworksai/FireworksAiServiceTests.java @@ -29,7 +29,6 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceTests; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; -import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.fireworksai.completion.FireworksAiChatCompletionModel; import org.elasticsearch.xpack.inference.services.fireworksai.completion.FireworksAiChatCompletionServiceSettings; @@ -79,7 +78,7 @@ public static TestConfiguration createTestConfiguration() { EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION, TaskType.CHAT_COMPLETION) ) { @Override - protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + protected FireworksAiService createService(ThreadPool threadPool, HttpClientManager clientManager) { return FireworksAiServiceTests.createService(threadPool, clientManager); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java index 6b82002fe48f4..78f014bf0aeee 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java @@ -33,6 +33,7 @@ import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; @@ -307,7 +308,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap } } - public void testParsePersistedConfigWithSecrets_CreatesAGoogleAiStudioCompletionModel() throws IOException { + public void testParsePersistedConfig_WithSecrets_CreatesAGoogleAiStudioCompletionModel() throws IOException { try (var service = createGoogleAiStudioService()) { var persistedConfig = getPersistedConfigMap( new HashMap<>(Map.of(ServiceFields.MODEL_ID, MODEL_ID_VALUE)), @@ -315,11 +316,14 @@ public void testParsePersistedConfigWithSecrets_CreatesAGoogleAiStudioCompletion getSecretSettingsMap(API_KEY_VALUE) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.COMPLETION, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.COMPLETION, + GoogleAiStudioService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(GoogleAiStudioCompletionModel.class)); @@ -331,7 +335,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAGoogleAiStudioCompletion } } - public void testParsePersistedConfigWithSecrets_CreatesAGoogleAiStudioEmbeddingsModel() throws IOException { + public void testParsePersistedConfig_WithSecrets_CreatesAGoogleAiStudioEmbeddingsModel() throws IOException { try (var service = createGoogleAiStudioService()) { var persistedConfig = getPersistedConfigMap( new HashMap<>(Map.of(ServiceFields.MODEL_ID, MODEL_ID_VALUE)), @@ -339,11 +343,14 @@ public void testParsePersistedConfigWithSecrets_CreatesAGoogleAiStudioEmbeddings getSecretSettingsMap(API_KEY_VALUE) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + GoogleAiStudioService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(GoogleAiStudioEmbeddingsModel.class)); @@ -355,7 +362,8 @@ public void testParsePersistedConfigWithSecrets_CreatesAGoogleAiStudioEmbeddings } } - public void testParsePersistedConfigWithSecrets_CreatesAGoogleAiStudioEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { + public void testParsePersistedConfig_WithSecrets_CreatesAGoogleAiStudioEmbeddingsModelWhenChunkingSettingsProvided() + throws IOException { try (var service = createGoogleAiStudioService()) { var persistedConfig = getPersistedConfigMap( new HashMap<>(Map.of(ServiceFields.MODEL_ID, MODEL_ID_VALUE)), @@ -364,11 +372,14 @@ public void testParsePersistedConfigWithSecrets_CreatesAGoogleAiStudioEmbeddings getSecretSettingsMap(API_KEY_VALUE) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + GoogleAiStudioService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(GoogleAiStudioEmbeddingsModel.class)); @@ -381,7 +392,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAGoogleAiStudioEmbeddings } } - public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { + public void testParsePersistedConfig_WithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { try (var service = createGoogleAiStudioService()) { var persistedConfig = getPersistedConfigMap( new HashMap<>(Map.of(ServiceFields.MODEL_ID, MODEL_ID_VALUE)), @@ -389,11 +400,14 @@ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChun getSecretSettingsMap(API_KEY_VALUE) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + GoogleAiStudioService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(GoogleAiStudioEmbeddingsModel.class)); @@ -406,7 +420,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChun } } - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { try (var service = createGoogleAiStudioService()) { var persistedConfig = getPersistedConfigMap( new HashMap<>(Map.of(ServiceFields.MODEL_ID, MODEL_ID_VALUE)), @@ -415,11 +429,14 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists ); persistedConfig.config().put("extra_key", "value"); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.COMPLETION, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.COMPLETION, + GoogleAiStudioService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(GoogleAiStudioCompletionModel.class)); @@ -431,7 +448,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists } } - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { try (var service = createGoogleAiStudioService()) { var secretSettingsMap = getSecretSettingsMap(API_KEY_VALUE); secretSettingsMap.put("extra_key", "value"); @@ -442,11 +459,14 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists secretSettingsMap ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.COMPLETION, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.COMPLETION, + GoogleAiStudioService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(GoogleAiStudioCompletionModel.class)); @@ -458,18 +478,21 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists } } - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { try (var service = createGoogleAiStudioService()) { Map serviceSettingsMap = new HashMap<>(Map.of(ServiceFields.MODEL_ID, MODEL_ID_VALUE)); serviceSettingsMap.put("extra_key", "value"); var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getTaskSettingsMapEmpty(), getSecretSettingsMap(API_KEY_VALUE)); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.COMPLETION, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.COMPLETION, + GoogleAiStudioService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(GoogleAiStudioCompletionModel.class)); @@ -481,7 +504,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists } } - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { try (var service = createGoogleAiStudioService()) { Map taskSettings = getTaskSettingsMapEmpty(); taskSettings.put("extra_key", "value"); @@ -492,11 +515,14 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists getSecretSettingsMap(API_KEY_VALUE) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.COMPLETION, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.COMPLETION, + GoogleAiStudioService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(GoogleAiStudioCompletionModel.class)); @@ -515,7 +541,15 @@ public void testParsePersistedConfig_CreatesAGoogleAiStudioCompletionModel() thr getTaskSettingsMapEmpty() ); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, TaskType.COMPLETION, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.COMPLETION, + GoogleAiStudioService.NAME, + persistedConfig.config(), + null + ) + ); assertThat(model, instanceOf(GoogleAiStudioCompletionModel.class)); @@ -534,7 +568,15 @@ public void testParsePersistedConfig_CreatesAGoogleAiStudioEmbeddingsModelWhenCh createRandomChunkingSettingsMap() ); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + GoogleAiStudioService.NAME, + persistedConfig.config(), + null + ) + ); assertThat(model, instanceOf(GoogleAiStudioEmbeddingsModel.class)); @@ -553,7 +595,15 @@ public void testParsePersistedConfig_CreatesAGoogleAiStudioEmbeddingsModelWhenCh getTaskSettingsMapEmpty() ); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + GoogleAiStudioService.NAME, + persistedConfig.config(), + null + ) + ); assertThat(model, instanceOf(GoogleAiStudioEmbeddingsModel.class)); @@ -573,7 +623,15 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() ); persistedConfig.config().put("extra_key", "value"); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, TaskType.COMPLETION, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.COMPLETION, + GoogleAiStudioService.NAME, + persistedConfig.config(), + null + ) + ); assertThat(model, instanceOf(GoogleAiStudioCompletionModel.class)); @@ -591,7 +649,15 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInServiceSe var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getTaskSettingsMapEmpty()); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, TaskType.COMPLETION, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.COMPLETION, + GoogleAiStudioService.NAME, + persistedConfig.config(), + null + ) + ); assertThat(model, instanceOf(GoogleAiStudioCompletionModel.class)); @@ -609,7 +675,15 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInTaskSetti var persistedConfig = getPersistedConfigMap(new HashMap<>(Map.of(ServiceFields.MODEL_ID, MODEL_ID_VALUE)), taskSettings); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, TaskType.COMPLETION, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.COMPLETION, + GoogleAiStudioService.NAME, + persistedConfig.config(), + null + ) + ); assertThat(model, instanceOf(GoogleAiStudioCompletionModel.class)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java index 2a0b1b3375fa7..2d55fd39d202f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java @@ -29,6 +29,7 @@ import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.ToXContent; @@ -464,7 +465,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap } } - public void testParsePersistedConfigWithSecrets_CreatesGoogleVertexAiEmbeddingsModel() throws IOException { + public void testParsePersistedConfig_WithSecrets_CreatesGoogleVertexAiEmbeddingsModel() throws IOException { var projectId = "project"; var location = "location"; var modelId = "model"; @@ -493,11 +494,14 @@ public void testParsePersistedConfigWithSecrets_CreatesGoogleVertexAiEmbeddingsM getSecretSettingsMap(serviceAccountJson) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + GoogleVertexAiService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(GoogleVertexAiEmbeddingsModel.class)); @@ -512,7 +516,7 @@ public void testParsePersistedConfigWithSecrets_CreatesGoogleVertexAiEmbeddingsM } } - public void testParsePersistedConfigWithSecrets_CreatesGoogleVertexAiChatCompletionModel() throws IOException, URISyntaxException { + public void testParsePersistedConfig_WithSecrets_CreatesGoogleVertexAiChatCompletionModel() throws IOException, URISyntaxException { var projectId = "project"; var location = "location"; var modelId = "model"; @@ -548,11 +552,14 @@ public void testParsePersistedConfigWithSecrets_CreatesGoogleVertexAiChatComplet getSecretSettingsMap(serviceAccountJson) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.CHAT_COMPLETION, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + CHAT_COMPLETION, + GoogleVertexAiService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(GoogleVertexAiChatCompletionModel.class)); @@ -572,7 +579,8 @@ public void testParsePersistedConfigWithSecrets_CreatesGoogleVertexAiChatComplet } } - public void testParsePersistedConfigWithSecrets_CreatesAGoogleVertexAiEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { + public void testParsePersistedConfig_WithSecrets_CreatesAGoogleVertexAiEmbeddingsModelWhenChunkingSettingsProvided() + throws IOException { var projectId = "project"; var location = "location"; var modelId = "model"; @@ -602,11 +610,14 @@ public void testParsePersistedConfigWithSecrets_CreatesAGoogleVertexAiEmbeddings getSecretSettingsMap(serviceAccountJson) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + GoogleVertexAiService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(GoogleVertexAiEmbeddingsModel.class)); @@ -622,7 +633,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAGoogleVertexAiEmbeddings } } - public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { + public void testParsePersistedConfig_WithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { var projectId = "project"; var location = "location"; var modelId = "model"; @@ -651,11 +662,14 @@ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChun getSecretSettingsMap(serviceAccountJson) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + GoogleVertexAiService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(GoogleVertexAiEmbeddingsModel.class)); @@ -671,7 +685,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChun } } - public void testParsePersistedConfigWithSecrets_CreatesGoogleVertexAiRerankModel() throws IOException { + public void testParsePersistedConfig_WithSecrets_CreatesGoogleVertexAiRerankModel() throws IOException { var projectId = "project"; var topN = 1; var serviceAccountJson = """ @@ -687,11 +701,14 @@ public void testParsePersistedConfigWithSecrets_CreatesGoogleVertexAiRerankModel getSecretSettingsMap(serviceAccountJson) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.RERANK, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.RERANK, + GoogleVertexAiService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(GoogleVertexAiRerankModel.class)); @@ -703,7 +720,7 @@ public void testParsePersistedConfigWithSecrets_CreatesGoogleVertexAiRerankModel } } - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { var projectId = "project"; var location = "location"; var modelId = "model"; @@ -733,11 +750,14 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists ); persistedConfig.config().put("extra_key", "value"); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + GoogleVertexAiService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(GoogleVertexAiEmbeddingsModel.class)); @@ -752,7 +772,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists } } - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { var projectId = "project"; var location = "location"; var modelId = "model"; @@ -784,11 +804,14 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists secretSettingsMap ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + GoogleVertexAiService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(GoogleVertexAiEmbeddingsModel.class)); @@ -803,7 +826,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists } } - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { var projectId = "project"; var location = "location"; var modelId = "model"; @@ -835,11 +858,14 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists getSecretSettingsMap(serviceAccountJson) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + GoogleVertexAiService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(GoogleVertexAiEmbeddingsModel.class)); @@ -854,7 +880,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists } } - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { var projectId = "project"; var location = "location"; var modelId = "model"; @@ -886,11 +912,14 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists getSecretSettingsMap(serviceAccountJson) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + GoogleVertexAiService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(GoogleVertexAiEmbeddingsModel.class)); @@ -929,7 +958,15 @@ public void testParsePersistedConfig_CreatesAGoogleVertexAiEmbeddingsModelWhenCh createRandomChunkingSettingsMap() ); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + GoogleVertexAiService.NAME, + persistedConfig.config(), + null + ) + ); assertThat(model, instanceOf(GoogleVertexAiEmbeddingsModel.class)); @@ -966,7 +1003,15 @@ public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSetting getTaskSettingsMap(autoTruncate, null) ); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + GoogleVertexAiService.NAME, + persistedConfig.config(), + null + ) + ); assertThat(model, instanceOf(GoogleVertexAiEmbeddingsModel.class)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/groq/GroqServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/groq/GroqServiceTests.java index c1fb159aaa49d..1e28ab47c608b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/groq/GroqServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/groq/GroqServiceTests.java @@ -21,6 +21,7 @@ import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.TestThreadPool; @@ -120,11 +121,8 @@ public void testParsePersistedConfigWithSecretsUsesSecretSettings() throws Excep Map secrets = new HashMap<>(); secrets.put(ModelSecrets.SECRET_SETTINGS, new HashMap<>(Map.of(DefaultSecretSettings.API_KEY, "persisted-secret"))); - GroqChatCompletionModel model = (GroqChatCompletionModel) service.parsePersistedConfigWithSecrets( - "groq-test", - TaskType.CHAT_COMPLETION, - config, - secrets + GroqChatCompletionModel model = (GroqChatCompletionModel) service.parsePersistedConfig( + new UnparsedModel("groq-test", TaskType.CHAT_COMPLETION, GroqService.NAME, config, secrets) ); assertTrue(model.getSecretSettings().apiKey().equals("persisted-secret")); assertThat(model.getServiceSettings().modelId(), equalTo("persisted-model")); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index 952e831d0f2dd..b8d1bd9f65443 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -36,6 +36,7 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.inference.completion.ContentString; import org.elasticsearch.inference.completion.Message; import org.elasticsearch.rest.RestStatus; @@ -697,7 +698,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap } } - public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModel() throws IOException { + public void testParsePersistedConfig_WithSecrets_CreatesAnEmbeddingsModel() throws IOException { try (var service = createHuggingFaceService()) { var persistedConfig = getPersistedConfigMap( getServiceSettingsMap(URL_VALUE), @@ -705,11 +706,14 @@ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModel() throw getSecretSettingsMap(API_KEY_VALUE) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + HuggingFaceService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(HuggingFaceEmbeddingsModel.class)); @@ -720,7 +724,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModel() throw } } - public void testParsePersistedConfigWithSecrets_CreatesACompletionModel() throws IOException { + public void testParsePersistedConfig_WithSecrets_CreatesACompletionModel() throws IOException { try (var service = createHuggingFaceService()) { var persistedConfig = getPersistedConfigMap( getServiceSettingsMap(URL_VALUE), @@ -728,11 +732,14 @@ public void testParsePersistedConfigWithSecrets_CreatesACompletionModel() throws getSecretSettingsMap(API_KEY_VALUE) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.COMPLETION, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.COMPLETION, + HuggingFaceService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(HuggingFaceChatCompletionModel.class)); @@ -743,7 +750,7 @@ public void testParsePersistedConfigWithSecrets_CreatesACompletionModel() throws } } - public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { + public void testParsePersistedConfig_WithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { try (var service = createHuggingFaceService()) { var persistedConfig = getPersistedConfigMap( getServiceSettingsMap(URL_VALUE), @@ -752,11 +759,14 @@ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChun getSecretSettingsMap(API_KEY_VALUE) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + HuggingFaceService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(HuggingFaceEmbeddingsModel.class)); @@ -768,7 +778,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChun } } - public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { + public void testParsePersistedConfig_WithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { try (var service = createHuggingFaceService()) { var persistedConfig = getPersistedConfigMap( getServiceSettingsMap(URL_VALUE), @@ -776,11 +786,14 @@ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChun getSecretSettingsMap(API_KEY_VALUE) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + HuggingFaceService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(HuggingFaceEmbeddingsModel.class)); @@ -792,7 +805,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChun } } - public void testParsePersistedConfigWithSecrets_CreatesAnElserModel() throws IOException { + public void testParsePersistedConfig_WithSecrets_CreatesAnElserModel() throws IOException { try (var service = createHuggingFaceService()) { var persistedConfig = getPersistedConfigMap( getServiceSettingsMap(URL_VALUE), @@ -800,11 +813,14 @@ public void testParsePersistedConfigWithSecrets_CreatesAnElserModel() throws IOE getSecretSettingsMap(API_KEY_VALUE) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.SPARSE_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.SPARSE_EMBEDDING, + HuggingFaceService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(HuggingFaceElserModel.class)); @@ -815,7 +831,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAnElserModel() throws IOE } } - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { try (var service = createHuggingFaceService()) { var persistedConfig = getPersistedConfigMap( getServiceSettingsMap(URL_VALUE), @@ -824,11 +840,14 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists ); persistedConfig.config().put("extra_key", "value"); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + HuggingFaceService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(HuggingFaceEmbeddingsModel.class)); @@ -839,18 +858,21 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists } } - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { try (var service = createHuggingFaceService()) { var secretSettingsMap = getSecretSettingsMap(API_KEY_VALUE); secretSettingsMap.put("extra_key", "value"); var persistedConfig = getPersistedConfigMap(getServiceSettingsMap(URL_VALUE), new HashMap<>(), secretSettingsMap); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + HuggingFaceService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(HuggingFaceEmbeddingsModel.class)); @@ -861,7 +883,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists } } - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecrets() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecrets() throws IOException { try (var service = createHuggingFaceService()) { var persistedConfig = getPersistedConfigMap( getServiceSettingsMap(URL_VALUE), @@ -870,11 +892,14 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists ); persistedConfig.secrets().put("extra_key", "value"); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + HuggingFaceService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(HuggingFaceEmbeddingsModel.class)); @@ -885,18 +910,21 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists } } - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { try (var service = createHuggingFaceService()) { var serviceSettingsMap = getServiceSettingsMap(URL_VALUE); serviceSettingsMap.put("extra_key", "value"); var persistedConfig = getPersistedConfigMap(serviceSettingsMap, new HashMap<>(), getSecretSettingsMap(API_KEY_VALUE)); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + HuggingFaceService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(HuggingFaceEmbeddingsModel.class)); @@ -907,7 +935,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists } } - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { try (var service = createHuggingFaceService()) { var taskSettingsMap = new HashMap(); taskSettingsMap.put("extra_key", "value"); @@ -918,11 +946,14 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists getSecretSettingsMap(API_KEY_VALUE) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + HuggingFaceService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(HuggingFaceEmbeddingsModel.class)); @@ -937,7 +968,15 @@ public void testParsePersistedConfig_CreatesAnEmbeddingsModel() throws IOExcepti try (var service = createHuggingFaceService()) { var persistedConfig = getPersistedConfigMap(getServiceSettingsMap(URL_VALUE)); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + HuggingFaceService.NAME, + persistedConfig.config(), + null + ) + ); assertThat(model, instanceOf(HuggingFaceEmbeddingsModel.class)); @@ -951,7 +990,15 @@ public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSetting try (var service = createHuggingFaceService()) { var persistedConfig = getPersistedConfigMap(getServiceSettingsMap(URL_VALUE), createRandomChunkingSettingsMap()); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + HuggingFaceService.NAME, + persistedConfig.config(), + null + ) + ); assertThat(model, instanceOf(HuggingFaceEmbeddingsModel.class)); @@ -966,7 +1013,15 @@ public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSetting try (var service = createHuggingFaceService()) { var persistedConfig = getPersistedConfigMap(getServiceSettingsMap(URL_VALUE)); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + HuggingFaceService.NAME, + persistedConfig.config(), + null + ) + ); assertThat(model, instanceOf(HuggingFaceEmbeddingsModel.class)); @@ -981,7 +1036,15 @@ public void testParsePersistedConfig_CreatesAnElserModel() throws IOException { try (var service = createHuggingFaceService()) { var persistedConfig = getPersistedConfigMap(getServiceSettingsMap(URL_VALUE), new HashMap<>()); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, TaskType.SPARSE_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.SPARSE_EMBEDDING, + HuggingFaceService.NAME, + persistedConfig.config(), + null + ) + ); assertThat(model, instanceOf(HuggingFaceElserModel.class)); @@ -996,7 +1059,15 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() var persistedConfig = getPersistedConfigMap(getServiceSettingsMap(URL_VALUE)); persistedConfig.config().put("extra_key", "value"); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + HuggingFaceService.NAME, + persistedConfig.config(), + null + ) + ); assertThat(model, instanceOf(HuggingFaceEmbeddingsModel.class)); @@ -1013,7 +1084,15 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInServiceSe var persistedConfig = getPersistedConfigMap(serviceSettingsMap); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + HuggingFaceService.NAME, + persistedConfig.config(), + null + ) + ); assertThat(model, instanceOf(HuggingFaceEmbeddingsModel.class)); @@ -1030,7 +1109,15 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInTaskSetti var persistedConfig = getPersistedConfigMap(getServiceSettingsMap(URL_VALUE), taskSettingsMap, null); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + HuggingFaceService.NAME, + persistedConfig.config(), + null + ) + ); assertThat(model, instanceOf(HuggingFaceEmbeddingsModel.class)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java index 0defd481e322f..a811e86bdf602 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java @@ -35,6 +35,7 @@ import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; @@ -294,7 +295,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap } } - public void testParsePersistedConfigWithSecrets_CreatesAIbmWatsonxEmbeddingsModel() throws IOException { + public void testParsePersistedConfig_WithSecrets_CreatesAIbmWatsonxEmbeddingsModel() throws IOException { try (var service = createIbmWatsonxService()) { var persistedConfig = getPersistedConfigMap( new HashMap<>( @@ -313,11 +314,14 @@ public void testParsePersistedConfigWithSecrets_CreatesAIbmWatsonxEmbeddingsMode getSecretSettingsMap(API_KEY_VALUE) ); - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + "id", + TaskType.TEXT_EMBEDDING, + IbmWatsonxService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(IbmWatsonxEmbeddingsModel.class)); @@ -333,7 +337,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAIbmWatsonxEmbeddingsMode } } - public void testParsePersistedConfigWithSecrets_CreatesAIbmWatsonxEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { + public void testParsePersistedConfig_WithSecrets_CreatesAIbmWatsonxEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { try (var service = createIbmWatsonxService()) { var persistedConfig = getPersistedConfigMap( new HashMap<>( @@ -353,11 +357,14 @@ public void testParsePersistedConfigWithSecrets_CreatesAIbmWatsonxEmbeddingsMode getSecretSettingsMap(API_KEY_VALUE) ); - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + "id", + TaskType.TEXT_EMBEDDING, + IbmWatsonxService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(IbmWatsonxEmbeddingsModel.class)); @@ -373,7 +380,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAIbmWatsonxEmbeddingsMode } } - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { try (var service = createIbmWatsonxService()) { var persistedConfig = getPersistedConfigMap( new HashMap<>( @@ -393,11 +400,14 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists ); persistedConfig.config().put("extra_key", "value"); - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + "id", + TaskType.TEXT_EMBEDDING, + IbmWatsonxService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(IbmWatsonxEmbeddingsModel.class)); @@ -412,7 +422,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists } } - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { try (var service = createIbmWatsonxService()) { var secretSettingsMap = getSecretSettingsMap(API_KEY_VALUE); secretSettingsMap.put("extra_key", "value"); @@ -434,11 +444,14 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists secretSettingsMap ); - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + "id", + TaskType.TEXT_EMBEDDING, + IbmWatsonxService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(IbmWatsonxEmbeddingsModel.class)); @@ -453,7 +466,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists } } - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { try (var service = createIbmWatsonxService()) { Map serviceSettingsMap = new HashMap<>( Map.of( @@ -471,11 +484,14 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getTaskSettingsMapEmpty(), getSecretSettingsMap(API_KEY_VALUE)); - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + "id", + TaskType.TEXT_EMBEDDING, + IbmWatsonxService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(IbmWatsonxEmbeddingsModel.class)); @@ -490,7 +506,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists } } - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { var modelId = "model"; var apiKey = "apiKey"; @@ -515,11 +531,14 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists getSecretSettingsMap(apiKey) ); - var model = service.parsePersistedConfigWithSecrets( - "id", - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + "id", + TaskType.TEXT_EMBEDDING, + IbmWatsonxService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model, instanceOf(IbmWatsonxEmbeddingsModel.class)); @@ -553,7 +572,9 @@ public void testParsePersistedConfig_CreatesAIbmWatsonxEmbeddingsModelWhenChunki null ); - var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel("id", TaskType.TEXT_EMBEDDING, IbmWatsonxService.NAME, persistedConfig.config(), null) + ); assertThat(model, instanceOf(IbmWatsonxEmbeddingsModel.class)); @@ -587,7 +608,9 @@ public void testParsePersistedConfig_CreatesAIbmWatsonxEmbeddingsModelWhenChunki null ); - var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel("id", TaskType.TEXT_EMBEDDING, IbmWatsonxService.NAME, persistedConfig.config(), null) + ); assertThat(model, instanceOf(IbmWatsonxEmbeddingsModel.class)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java index a555adfc55484..6f17b54893455 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java @@ -39,6 +39,7 @@ import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.http.MockRequest; import org.elasticsearch.test.http.MockResponse; @@ -434,14 +435,14 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap } public void testParsePersistedConfigWithSecrets_createsEmbeddingsModel_textEmbedding() throws IOException { - testParsePersistedConfigWithSecrets_createsEmbeddingModel(TEXT_EMBEDDING); + testParsePersistedConfig_WithSecrets_createsEmbeddingModel(TEXT_EMBEDDING); } public void testParsePersistedConfigWithSecrets_createsEmbeddingsModel_embedding() throws IOException { - testParsePersistedConfigWithSecrets_createsEmbeddingModel(TaskType.EMBEDDING); + testParsePersistedConfig_WithSecrets_createsEmbeddingModel(TaskType.EMBEDDING); } - private void testParsePersistedConfigWithSecrets_createsEmbeddingModel(TaskType taskType) throws IOException { + private void testParsePersistedConfig_WithSecrets_createsEmbeddingModel(TaskType taskType) throws IOException { try (var service = createJinaAIService()) { var modelName = randomAlphanumericOfLength(8); var requestsPerMinute = randomNonNegativeInt(); @@ -477,11 +478,14 @@ private void testParsePersistedConfigWithSecrets_createsEmbeddingModel(TaskType getSecretSettingsMap(apiKey) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - taskType, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + taskType, + JinaAIService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertEmbeddingModelSettings( @@ -501,7 +505,7 @@ private void testParsePersistedConfigWithSecrets_createsEmbeddingModel(TaskType } } - public void testParsePersistedConfigWithSecrets_createsRerankModel() throws IOException { + public void testParsePersistedConfig_WithSecrets_createsRerankModel() throws IOException { try (var service = createJinaAIService()) { var modelName = randomAlphanumericOfLength(8); var requestsPerMinute = randomNonNegativeInt(); @@ -515,11 +519,14 @@ public void testParsePersistedConfigWithSecrets_createsRerankModel() throws IOEx getSecretSettingsMap(apiKey) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.RERANK, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.RERANK, + JinaAIService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertRerankModelSettings( @@ -533,14 +540,14 @@ public void testParsePersistedConfigWithSecrets_createsRerankModel() throws IOEx } public void testParsePersistedConfigWithSecrets_onlyRequiredSettings_createsEmbeddingsModel_textEmbedding() throws IOException { - testParsePersistedConfigWithSecrets_onlyRequiredSettings_createsEmbeddingModel(TEXT_EMBEDDING); + testParsePersistedConfig_WithSecrets_onlyRequiredSettings_createsEmbeddingModel(TEXT_EMBEDDING); } public void testParsePersistedConfigWithSecrets_onlyRequiredSettings_createsEmbeddingsModel_embedding() throws IOException { - testParsePersistedConfigWithSecrets_onlyRequiredSettings_createsEmbeddingModel(TaskType.EMBEDDING); + testParsePersistedConfig_WithSecrets_onlyRequiredSettings_createsEmbeddingModel(TaskType.EMBEDDING); } - private void testParsePersistedConfigWithSecrets_onlyRequiredSettings_createsEmbeddingModel(TaskType taskType) throws IOException { + private void testParsePersistedConfig_WithSecrets_onlyRequiredSettings_createsEmbeddingModel(TaskType taskType) throws IOException { try (var service = createJinaAIService()) { var modelName = randomAlphanumericOfLength(8); var apiKey = randomAlphanumericOfLength(8); @@ -553,11 +560,14 @@ private void testParsePersistedConfigWithSecrets_onlyRequiredSettings_createsEmb getSecretSettingsMap(apiKey) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - taskType, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + taskType, + JinaAIService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertEmbeddingModelSettings( @@ -577,25 +587,28 @@ private void testParsePersistedConfigWithSecrets_onlyRequiredSettings_createsEmb } } - public void testParsePersistedConfigWithSecrets_onlyRequiredSettings_createsRerankModel() throws IOException { + public void testParsePersistedConfig_WithSecrets_onlyRequiredSettings_createsRerankModel() throws IOException { try (var service = createJinaAIService()) { var modelName = randomAlphanumericOfLength(8); var apiKey = randomAlphanumericOfLength(8); var persistedConfig = getPersistedConfigMap(getServiceSettingsMap(modelName, null), Map.of(), getSecretSettingsMap(apiKey)); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.RERANK, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.RERANK, + JinaAIService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertRerankModelSettings(model, modelName, DEFAULT_RATE_LIMIT_SETTINGS, apiKey, JinaAIRerankTaskSettings.EMPTY_SETTINGS); } } - public void testParsePersistedConfigWithSecrets_ThrowsErrorWithUnsupportedTaskType() throws IOException { + public void testParsePersistedConfig_WithSecrets_ThrowsErrorWithUnsupportedTaskType() throws IOException { try (var service = createJinaAIService()) { var unsupportedTaskType = randomValueOtherThanMany( t -> service.supportedTaskTypes().contains(t), @@ -605,11 +618,14 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorWithUnsupportedTaskTy var thrownException = expectThrows( ElasticsearchStatusException.class, - () -> service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - unsupportedTaskType, - persistedConfig.config(), - persistedConfig.secrets() + () -> service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + unsupportedTaskType, + JinaAIService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ) ); @@ -621,7 +637,7 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorWithUnsupportedTaskTy } } - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { try (var service = createJinaAIService()) { String modelName = MODEL_NAME_VALUE; String apiKey = "secret"; @@ -636,7 +652,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists } } - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { try (var service = createJinaAIService()) { String modelName = MODEL_NAME_VALUE; String apiKey = "secret"; @@ -677,7 +693,7 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTa } } - public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInChunkingSettings() throws IOException { + public void testParsePersistedConfig_WithSecrets_NotThrowWhenAnExtraKeyExistsInChunkingSettings() throws IOException { try (var service = createJinaAIService()) { String modelName = MODEL_NAME_VALUE; String apiKey = "secret"; @@ -689,11 +705,14 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInCh getSecretSettingsMap(apiKey) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - randomEmbeddingTaskType(), - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + randomEmbeddingTaskType(), + JinaAIService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model.getServiceSettings().modelId(), is(modelName)); @@ -744,7 +763,15 @@ private void testParsePersistedConfig_createsEmbeddingModel(TaskType taskType) t null ); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, taskType, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + taskType, + JinaAIService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) + ); assertEmbeddingModelSettings( model, @@ -776,7 +803,15 @@ public void testParsePersistedConfig_createsRerankModel() throws IOException { null ); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, TaskType.RERANK, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.RERANK, + JinaAIService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) + ); assertRerankModelSettings( model, @@ -798,7 +833,15 @@ public void testParsePersistedConfig_ThrowsErrorWithUnsupportedTaskType() throws var thrownException = expectThrows( ElasticsearchStatusException.class, - () -> service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, unsupportedTaskType, persistedConfig.config()) + () -> service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + unsupportedTaskType, + JinaAIService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) + ) ); assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [jinaai] service")); @@ -853,7 +896,15 @@ public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInChunkingSetti null ); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, randomEmbeddingTaskType(), persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + randomEmbeddingTaskType(), + JinaAIService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) + ); assertThat(model.getServiceSettings().modelId(), is(modelName)); assertThat(model.apiKey().toString(), is("")); @@ -2215,11 +2266,14 @@ private static void assertParsePersistedConfigWithSecretsMinimalSettings( String modelName, String apiKey ) { - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - randomFrom(service.supportedTaskTypes()), - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + randomFrom(service.supportedTaskTypes()), + JinaAIService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model.getServiceSettings().modelId(), is(modelName)); @@ -2232,9 +2286,13 @@ private static void assertParsePersistedConfigMinimalSettings( String modelName ) { var model = service.parsePersistedConfig( - INFERENCE_ENTITY_ID_VALUE, - randomFrom(service.supportedTaskTypes()), - persistedConfig.config() + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + randomFrom(service.supportedTaskTypes()), + JinaAIService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model.getServiceSettings().modelId(), is(modelName)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java index 4278354525cc7..7d7688bcb6038 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java @@ -55,7 +55,6 @@ import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceTests; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; -import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModel; import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModelTests; @@ -120,7 +119,7 @@ public static TestConfiguration createTestConfiguration() { new CommonConfig(TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, EnumSet.of(TEXT_EMBEDDING, COMPLETION, CHAT_COMPLETION)) { @Override - protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + protected LlamaService createService(ThreadPool threadPool, HttpClientManager clientManager) { return LlamaServiceTests.createService(threadPool, clientManager); } @@ -240,7 +239,7 @@ private static void assertChatCompletionModel(Model model, boolean modelIncludes assertThat(llamaModel.getTaskType(), Matchers.is(CHAT_COMPLETION)); } - public static SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + public static LlamaService createService(ThreadPool threadPool, HttpClientManager clientManager) { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); return new LlamaService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java index b865ea41603c5..950fc74007269 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java @@ -35,6 +35,7 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.inference.completion.ContentString; import org.elasticsearch.inference.completion.Message; import org.elasticsearch.rest.RestStatus; @@ -639,11 +640,8 @@ public void testParsePersistedConfig_CreatesAMistralEmbeddingsModel() throws IOE getSecretSettingsMap(API_KEY_VALUE) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ID_VALUE, - TaskType.TEXT_EMBEDDING, - config.config(), - config.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, MistralService.NAME, config.config(), config.secrets()) ); assertThat(model, instanceOf(MistralEmbeddingsModel.class)); @@ -668,7 +666,9 @@ private void testParsePersistedConfig_CreatesAMistralModel(String modelId, TaskT try (var service = createService()) { var config = getPersistedConfigMap(getServiceSettingsMap(modelId), getTaskSettingsMap(), getSecretSettingsMap(API_KEY_VALUE)); - var model = service.parsePersistedConfigWithSecrets(INFERENCE_ID_VALUE, chatCompletion, config.config(), config.secrets()); + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, chatCompletion, MistralService.NAME, config.config(), config.secrets()) + ); assertThat(model, instanceOf(MistralChatCompletionModel.class)); @@ -687,11 +687,8 @@ public void testParsePersistedConfig_CreatesAMistralEmbeddingsModelWhenChunkingS getSecretSettingsMap(API_KEY_VALUE) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ID_VALUE, - TaskType.TEXT_EMBEDDING, - config.config(), - config.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, MistralService.NAME, config.config(), config.secrets()) ); assertThat(model, instanceOf(MistralEmbeddingsModel.class)); @@ -713,11 +710,8 @@ public void testParsePersistedConfig_CreatesAMistralEmbeddingsModelWhenChunkingS getSecretSettingsMap(API_KEY_VALUE) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ID_VALUE, - TaskType.TEXT_EMBEDDING, - config.config(), - config.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, MistralService.NAME, config.config(), config.secrets()) ); assertThat(model, instanceOf(MistralEmbeddingsModel.class)); @@ -750,7 +744,7 @@ public void testParsePersistedConfig_ThrowsUnsupportedModelType() throws IOExcep } } - public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidModel() throws IOException { + public void testParsePersistedConfig_WithSecrets_ThrowsErrorTryingToParseInvalidModel() throws IOException { try (var service = createService()) { var config = getPersistedConfigMap( getEmbeddingsServiceSettingsMap(null, null), @@ -760,11 +754,8 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM var thrownException = expectThrows( ElasticsearchStatusException.class, - () -> service.parsePersistedConfigWithSecrets( - INFERENCE_ID_VALUE, - TaskType.SPARSE_EMBEDDING, - config.config(), - config.secrets() + () -> service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, TaskType.SPARSE_EMBEDDING, MistralService.NAME, config.config(), config.secrets()) ) ); @@ -808,7 +799,9 @@ private void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig( var config = getPersistedConfigMap(serviceSettingsMap, taskSettings, secretSettings); config.config().put("extra_key", "value"); - var model = service.parsePersistedConfigWithSecrets(INFERENCE_ID_VALUE, chatCompletion, config.config(), config.secrets()); + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, chatCompletion, MistralService.NAME, config.config(), config.secrets()) + ); assertThat(model, matcher); } @@ -850,7 +843,9 @@ private void testParsePersistedConfig_DoesNotThrowWhenExtraKeyExistsInServiceSet var secretSettings = getSecretSettingsMap(API_KEY_VALUE); var config = getPersistedConfigMap(serviceSettingsMap, taskSettings, secretSettings); - var model = service.parsePersistedConfigWithSecrets(INFERENCE_ID_VALUE, chatCompletion, config.config(), config.secrets()); + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, chatCompletion, MistralService.NAME, config.config(), config.secrets()) + ); assertThat(model, matcher); } @@ -865,11 +860,8 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInEmbedding var secretSettings = getSecretSettingsMap(API_KEY_VALUE); var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ID_VALUE, - TaskType.TEXT_EMBEDDING, - config.config(), - config.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, MistralService.NAME, config.config(), config.secrets()) ); assertThat(model, instanceOf(MistralEmbeddingsModel.class)); @@ -912,7 +904,9 @@ private void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsSecretSett var config = getPersistedConfigMap(serviceSettingsMap, taskSettings, secretSettings); - var model = service.parsePersistedConfigWithSecrets(INFERENCE_ID_VALUE, chatCompletion, config.config(), config.secrets()); + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, chatCompletion, MistralService.NAME, config.config(), config.secrets()) + ); assertThat(model, matcher); } @@ -920,9 +914,11 @@ private void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsSecretSett public void testParsePersistedConfig_WithoutSecretsCreatesEmbeddingsModel() throws IOException { try (var service = createService()) { - var config = getPersistedConfigMap(getEmbeddingsServiceSettingsMap(1024, 512), getTaskSettingsMap(), Map.of()); + var config = getPersistedConfigMap(getEmbeddingsServiceSettingsMap(1024, 512), getTaskSettingsMap(), null); - var model = service.parsePersistedConfig(INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, config.config()); + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, MistralService.NAME, config.config(), config.secrets()) + ); assertThat(model, instanceOf(MistralEmbeddingsModel.class)); @@ -939,10 +935,12 @@ public void testParsePersistedConfig_WithoutSecretsCreatesAnEmbeddingsModelWhenC getEmbeddingsServiceSettingsMap(1024, 512), getTaskSettingsMap(), createRandomChunkingSettingsMap(), - Map.of() + null ); - var model = service.parsePersistedConfig(INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, config.config()); + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, MistralService.NAME, config.config(), config.secrets()) + ); assertThat(model, instanceOf(MistralEmbeddingsModel.class)); @@ -956,9 +954,11 @@ public void testParsePersistedConfig_WithoutSecretsCreatesAnEmbeddingsModelWhenC public void testParsePersistedConfig_WithoutSecretsCreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { try (var service = createService()) { - var config = getPersistedConfigMap(getEmbeddingsServiceSettingsMap(1024, 512), getTaskSettingsMap(), Map.of()); + var config = getPersistedConfigMap(getEmbeddingsServiceSettingsMap(1024, 512), getTaskSettingsMap(), null); - var model = service.parsePersistedConfig(INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, config.config()); + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, TaskType.TEXT_EMBEDDING, MistralService.NAME, config.config(), config.secrets()) + ); assertThat(model, instanceOf(MistralEmbeddingsModel.class)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/MixedbreadServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/MixedbreadServiceTests.java index a780e77af11e6..4b468c746ccb8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/MixedbreadServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mixedbread/MixedbreadServiceTests.java @@ -28,6 +28,7 @@ import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; @@ -40,7 +41,6 @@ import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceTests; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; -import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.mixedbread.rerank.MixedbreadRerankModel; import org.elasticsearch.xpack.inference.services.mixedbread.rerank.MixedbreadRerankModelTests; import org.elasticsearch.xpack.inference.services.mixedbread.rerank.MixedbreadRerankServiceSettings; @@ -110,7 +110,7 @@ public static TestConfiguration createTestConfiguration() { return new TestConfiguration.Builder(new CommonConfig(RERANK, COMPLETION, EnumSet.of(RERANK)) { @Override - protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + protected MixedbreadService createService(ThreadPool threadPool, HttpClientManager clientManager) { return MixedbreadServiceTests.createService(threadPool, clientManager); } @@ -214,7 +214,7 @@ private static void assertRerankModel(Model model, boolean modelIncludesSecrets) assertThat(mixedbreadModel.getTaskType(), Matchers.is(RERANK)); } - public static SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + public static MixedbreadService createService(ThreadPool threadPool, HttpClientManager clientManager) { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); return new MixedbreadService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } @@ -332,7 +332,7 @@ public void testParseRequestConfig_onlyRequiredSettings_createsRerankModel() thr } } - public void testParsePersistedConfigWithSecrets_createsRerankModel() throws IOException { + public void testParsePersistedConfig_WithSecrets_createsRerankModel() throws IOException { try (var service = createMixedbreadService()) { var modelName = randomAlphanumericOfLength(8); var requestsPerMinute = randomNonNegativeInt(); @@ -346,11 +346,14 @@ public void testParsePersistedConfigWithSecrets_createsRerankModel() throws IOEx getSecretSettingsMap(apiKey) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ID_VALUE, - TaskType.RERANK, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ID_VALUE, + TaskType.RERANK, + MixedbreadService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model.getSecretSettings().apiKey().toString(), is(apiKey)); @@ -364,18 +367,21 @@ public void testParsePersistedConfigWithSecrets_createsRerankModel() throws IOEx } } - public void testParsePersistedConfigWithSecrets_onlyRequiredSettings_createsRerankModel() throws IOException { + public void testParsePersistedConfig_WithSecrets_onlyRequiredSettings_createsRerankModel() throws IOException { try (var service = createMixedbreadService()) { var modelName = randomAlphanumericOfLength(8); var apiKey = randomAlphanumericOfLength(8); var persistedConfig = getPersistedConfigMap(getServiceSettingsMap(modelName, null), Map.of(), getSecretSettingsMap(apiKey)); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ID_VALUE, - TaskType.RERANK, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ID_VALUE, + TaskType.RERANK, + MixedbreadService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat(model.getSecretSettings().apiKey().toString(), is(apiKey)); @@ -402,7 +408,15 @@ public void testParsePersistedConfig_createsRerankModel() throws IOException { null ); - var model = service.parsePersistedConfig(INFERENCE_ID_VALUE, TaskType.RERANK, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ID_VALUE, + TaskType.RERANK, + MixedbreadService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) + ); assertRerankModelSettings( model, @@ -507,7 +521,15 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInSecretsSe var persistedConfig = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings); - var model = service.parsePersistedConfig(INFERENCE_ID_VALUE, RERANK, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ID_VALUE, + TaskType.RERANK, + MixedbreadService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) + ); assertThat(model, CoreMatchers.instanceOf(MixedbreadRerankModel.class)); @@ -527,7 +549,15 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInServiceSe var persistedConfig = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings); - var model = service.parsePersistedConfig(INFERENCE_ID_VALUE, RERANK, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ID_VALUE, + TaskType.RERANK, + MixedbreadService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) + ); assertThat(model, CoreMatchers.instanceOf(MixedbreadRerankModel.class)); @@ -547,7 +577,15 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInTaskSetti var persistedConfig = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings); - var model = service.parsePersistedConfig(INFERENCE_ID_VALUE, RERANK, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ID_VALUE, + TaskType.RERANK, + MixedbreadService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) + ); assertThat(model, CoreMatchers.instanceOf(MixedbreadRerankModel.class)); @@ -557,7 +595,7 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInTaskSetti } } - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { var serviceSettings = getServiceSettingsMap(MODEL_NAME_VALUE, REQUESTS_PER_MINUTE); var taskSettings = getTaskSettingsMap(TOP_N, RETURN_DOCUMENTS_TRUE); var secretSettings = getSecretSettingsMap(API_KEY); @@ -567,11 +605,8 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists var persistedConfig = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ID_VALUE, - RERANK, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, RERANK, MixedbreadService.NAME, persistedConfig.config(), persistedConfig.secrets()) ); assertThat(model, CoreMatchers.instanceOf(MixedbreadRerankModel.class)); @@ -583,7 +618,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists } } - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { var serviceSettings = getServiceSettingsMap(MODEL_NAME_VALUE, REQUESTS_PER_MINUTE); var taskSettings = getTaskSettingsMap(TOP_N, RETURN_DOCUMENTS_TRUE); var secretSettings = getSecretSettingsMap(API_KEY); @@ -593,11 +628,8 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists var persistedConfig = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ID_VALUE, - RERANK, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, RERANK, MixedbreadService.NAME, persistedConfig.config(), persistedConfig.secrets()) ); assertThat(model, CoreMatchers.instanceOf(MixedbreadRerankModel.class)); @@ -609,7 +641,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists } } - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { var serviceSettings = getServiceSettingsMap(MODEL_NAME_VALUE, REQUESTS_PER_MINUTE); var taskSettings = getTaskSettingsMap(TOP_N, RETURN_DOCUMENTS_TRUE); var secretSettings = getSecretSettingsMap(API_KEY); @@ -619,11 +651,8 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists var persistedConfig = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ID_VALUE, - RERANK, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel(INFERENCE_ID_VALUE, RERANK, MixedbreadService.NAME, persistedConfig.config(), persistedConfig.secrets()) ); assertThat(model, CoreMatchers.instanceOf(MixedbreadRerankModel.class)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/nvidia/NvidiaServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/nvidia/NvidiaServiceTests.java index 3a57386b0271e..be25c3a5f9175 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/nvidia/NvidiaServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/nvidia/NvidiaServiceTests.java @@ -59,7 +59,6 @@ import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceTests; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; -import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.nvidia.completion.NvidiaChatCompletionModel; import org.elasticsearch.xpack.inference.services.nvidia.completion.NvidiaChatCompletionModelTests; import org.elasticsearch.xpack.inference.services.nvidia.completion.NvidiaChatCompletionServiceSettings; @@ -147,7 +146,7 @@ public static TestConfiguration createTestConfiguration() { new CommonConfig(TEXT_EMBEDDING, SPARSE_EMBEDDING, EnumSet.of(TEXT_EMBEDDING, COMPLETION, CHAT_COMPLETION, RERANK)) { @Override - protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + protected NvidiaService createService(ThreadPool threadPool, HttpClientManager clientManager) { return NvidiaServiceTests.createService(threadPool, clientManager); } @@ -319,7 +318,7 @@ private static void assertRerankModel(Model model, boolean modelIncludesSecrets) assertThat(nvidiaModel.getTaskType(), is(RERANK)); } - public static SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + public static NvidiaService createService(ThreadPool threadPool, HttpClientManager clientManager) { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); return new NvidiaService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index a245e25fd4e70..6826068c0586e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -60,7 +60,6 @@ import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceTests; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; -import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests; @@ -166,7 +165,7 @@ public static TestConfiguration createTestConfiguration() { EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION, TaskType.CHAT_COMPLETION) ) { @Override - protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + protected OpenAiService createService(ThreadPool threadPool, HttpClientManager clientManager) { return OpenAiServiceTests.createService(threadPool, clientManager); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java index f97451bc92e4a..c323fca794ced 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java @@ -59,7 +59,6 @@ import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceTests; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; -import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionModel; import org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionModelTests; @@ -140,7 +139,7 @@ public static TestConfiguration createTestConfiguration() { new CommonConfig(TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, EnumSet.of(TEXT_EMBEDDING, COMPLETION, CHAT_COMPLETION, RERANK)) { @Override - protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + protected OpenShiftAiService createService(ThreadPool threadPool, HttpClientManager clientManager) { return OpenShiftAiServiceTests.createService(threadPool, clientManager); } @@ -291,7 +290,7 @@ private static void assertChatCompletionModel(Model model, boolean modelIncludes assertThat(openShiftAiModel.getTaskType(), is(TaskType.CHAT_COMPLETION)); } - public static SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + public static OpenShiftAiService createService(ThreadPool threadPool, HttpClientManager clientManager) { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); return new OpenShiftAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java index 5471627cb6cb8..024c1d592dd58 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java @@ -29,6 +29,7 @@ import org.elasticsearch.inference.RerankingInferenceService; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.inference.completion.ContentObjects; import org.elasticsearch.inference.completion.Message; import org.elasticsearch.rest.RestStatus; @@ -128,13 +129,21 @@ public void testParseRequestConfig() { })); } - public void testParsePersistedConfigWithSecrets() { - sageMakerService.parsePersistedConfigWithSecrets("modelId", TaskType.ANY, Map.of(), Map.of()); - verify(modelBuilder, only()).fromStorage(eq("modelId"), eq(TaskType.ANY), eq(SageMakerService.NAME), eq(Map.of()), eq(Map.of())); + public void testParsePersistedConfig_WithSecrets() { + sageMakerService.parsePersistedConfig( + new UnparsedModel("modelId", TaskType.ANY, SageMakerService.NAME, Map.of(), Map.of("key", "value")) + ); + verify(modelBuilder, only()).fromStorage( + eq("modelId"), + eq(TaskType.ANY), + eq(SageMakerService.NAME), + eq(Map.of()), + eq(Map.of("key", "value")) + ); } - public void testParsePersistedConfig() { - sageMakerService.parsePersistedConfig("modelId", TaskType.ANY, Map.of()); + public void testParsePersistedConfig_WithoutSecrets() { + sageMakerService.parsePersistedConfig(new UnparsedModel("modelId", TaskType.ANY, SageMakerService.NAME, Map.of(), null)); verify(modelBuilder, only()).fromStorage(eq("modelId"), eq(TaskType.ANY), eq(SageMakerService.NAME), eq(Map.of()), eq(null)); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java index de6cef166fbd6..6f3a3e023e752 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java @@ -33,6 +33,7 @@ import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; @@ -325,7 +326,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap } } - public void testParsePersistedConfigWithSecrets_CreatesAVoyageAIEmbeddingsModel() throws IOException { + public void testParsePersistedConfig_WithSecrets_CreatesAVoyageAIEmbeddingsModel() throws IOException { try (var service = createVoyageAIService()) { var persistedConfig = getPersistedConfigMap( VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), @@ -333,11 +334,14 @@ public void testParsePersistedConfigWithSecrets_CreatesAVoyageAIEmbeddingsModel( getSecretSettingsMap(API_KEY_VALUE) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + VoyageAIService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); @@ -350,7 +354,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAVoyageAIEmbeddingsModel( } } - public void testParsePersistedConfigWithSecrets_CreatesAVoyageAIEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { + public void testParsePersistedConfig_WithSecrets_CreatesAVoyageAIEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { try (var service = createVoyageAIService()) { var persistedConfig = getPersistedConfigMap( VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), @@ -359,11 +363,14 @@ public void testParsePersistedConfigWithSecrets_CreatesAVoyageAIEmbeddingsModelW getSecretSettingsMap(API_KEY_VALUE) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + VoyageAIService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); @@ -377,7 +384,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAVoyageAIEmbeddingsModelW } } - public void testParsePersistedConfigWithSecrets_CreatesAVoyageAIEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { + public void testParsePersistedConfig_WithSecrets_CreatesAVoyageAIEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { try (var service = createVoyageAIService()) { var persistedConfig = getPersistedConfigMap( VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), @@ -385,11 +392,14 @@ public void testParsePersistedConfigWithSecrets_CreatesAVoyageAIEmbeddingsModelW getSecretSettingsMap(API_KEY_VALUE) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + VoyageAIService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); @@ -403,7 +413,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAVoyageAIEmbeddingsModelW } } - public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidModel() throws IOException { + public void testParsePersistedConfig_WithSecrets_ThrowsErrorTryingToParseInvalidModel() throws IOException { try (var service = createVoyageAIService()) { var persistedConfig = getPersistedConfigMap( VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("oldmodel"), @@ -413,11 +423,14 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM var thrownException = expectThrows( ElasticsearchStatusException.class, - () -> service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.SPARSE_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + () -> service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.SPARSE_EMBEDDING, + VoyageAIService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ) ); @@ -429,7 +442,7 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM } } - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { try (var service = createVoyageAIService()) { var persistedConfig = getPersistedConfigMap( VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), @@ -438,11 +451,14 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists ); persistedConfig.config().put("extra_key", "value"); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + VoyageAIService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); @@ -455,7 +471,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists } } - public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { + public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { try (var service = createVoyageAIService()) { var secretSettingsMap = getSecretSettingsMap(API_KEY_VALUE); secretSettingsMap.put("extra_key", "value"); @@ -466,11 +482,14 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists secretSettingsMap ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + VoyageAIService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); @@ -483,7 +502,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists } } - public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSecrets() throws IOException { + public void testParsePersistedConfig_WithSecrets_NotThrowWhenAnExtraKeyExistsInSecrets() throws IOException { try (var service = createVoyageAIService()) { var persistedConfig = getPersistedConfigMap( VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), @@ -492,11 +511,14 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSe ); persistedConfig.secrets().put("extra_key", "value"); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + VoyageAIService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); @@ -509,7 +531,7 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSe } } - public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + public void testParsePersistedConfig_WithSecrets_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { try (var service = createVoyageAIService()) { var serviceSettingsMap = VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"); serviceSettingsMap.put("extra_key", "value"); @@ -520,13 +542,15 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSe getSecretSettingsMap(API_KEY_VALUE) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + VoyageAIService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); - MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); var embeddingsModel = (VoyageAIEmbeddingsModel) model; @@ -537,7 +561,7 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSe } } - public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { + public void testParsePersistedConfig_WithSecrets_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { try (var service = createVoyageAIService()) { var taskSettingsMap = VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH); taskSettingsMap.put("extra_key", "value"); @@ -548,11 +572,14 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTa getSecretSettingsMap(API_KEY_VALUE) ); - var model = service.parsePersistedConfigWithSecrets( - INFERENCE_ENTITY_ID_VALUE, - TaskType.TEXT_EMBEDDING, - persistedConfig.config(), - persistedConfig.secrets() + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + VoyageAIService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) ); MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); @@ -572,7 +599,15 @@ public void testParsePersistedConfig_CreatesAVoyageAIEmbeddingsModel() throws IO VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null) ); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + VoyageAIService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) + ); MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); @@ -589,10 +624,19 @@ public void testParsePersistedConfig_CreatesAVoyageAIEmbeddingsModelWhenChunking var persistedConfig = getPersistedConfigMap( VoyageAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("model"), VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), - createRandomChunkingSettingsMap() + createRandomChunkingSettingsMap(), + null ); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + VoyageAIService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) + ); MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); @@ -612,7 +656,15 @@ public void testParsePersistedConfig_CreatesAVoyageAIEmbeddingsModelWhenChunking VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null) ); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + VoyageAIService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) + ); MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); @@ -634,7 +686,15 @@ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() thro var thrownException = expectThrows( ElasticsearchStatusException.class, - () -> service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, TaskType.SPARSE_EMBEDDING, persistedConfig.config()) + () -> service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.SPARSE_EMBEDDING, + VoyageAIService.NAME, + persistedConfig.config(), + new HashMap<>() + ) + ) ); assertThat(thrownException.getMessage(), containsString("Failed to parse stored model [id] for [voyageai] service")); @@ -652,7 +712,15 @@ public void testParsePersistedConfig_CreatesAVoyageAIEmbeddingsModelWithoutUrl() VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null) ); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + VoyageAIService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) + ); MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); @@ -671,7 +739,15 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() ); persistedConfig.config().put("extra_key", "value"); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + VoyageAIService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) + ); MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); @@ -693,7 +769,15 @@ public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInServiceSettin VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH) ); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + VoyageAIService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) + ); MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); @@ -715,7 +799,16 @@ public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInTaskSettings( taskSettingsMap ); - var model = service.parsePersistedConfig(INFERENCE_ENTITY_ID_VALUE, TaskType.TEXT_EMBEDDING, persistedConfig.config()); + var model = service.parsePersistedConfig( + new UnparsedModel( + INFERENCE_ENTITY_ID_VALUE, + TaskType.TEXT_EMBEDDING, + VoyageAIService.NAME, + persistedConfig.config(), + persistedConfig.secrets() + ) + ); + MatcherAssert.assertThat(model, instanceOf(VoyageAIEmbeddingsModel.class)); var embeddingsModel = (VoyageAIEmbeddingsModel) model;