diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 7400890d66f08..2eae63e417c4a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -201,7 +201,7 @@ private static Map initDefaultEndpoints( new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( DEFAULT_MULTILINGUAL_EMBED_MODEL_ID, defaultDenseTextEmbeddingsSimilarity(), - null, + DENSE_TEXT_EMBEDDINGS_DIMENSIONS, null, ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.DEFAULT_RATE_LIMIT_SETTINGS ), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java index 34a8086119150..a68b23cf8cb5a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java @@ -46,6 +46,20 @@ public int rateLimitGroupingHash() { return Objects.hash(this.getServiceSettings().modelId()); } + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + if (super.equals(o) == false) return false; + ElasticInferenceServiceModel that = (ElasticInferenceServiceModel) o; + return Objects.equals(rateLimitServiceSettings, that.rateLimitServiceSettings) + && Objects.equals(elasticInferenceServiceComponents, that.elasticInferenceServiceComponents); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), rateLimitServiceSettings, elasticInferenceServiceComponents); + } + public RateLimitSettings rateLimitSettings() { return rateLimitServiceSettings.rateLimitSettings(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java index 40ff34b28e04a..851a4d3bd4dd1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.xcontent.XContentBuilder; @@ -60,7 +61,7 @@ public static ElasticInferenceServiceCompletionServiceSettings fromMap(Map>(); + + service.defaultConfigs(listener); + var models = listener.actionGet(TIMEOUT); + + var elasticInferenceServiceComponents = new ElasticInferenceServiceComponents(getUrl(webServer)); + + assertThat( + models, + containsInAnyOrder( + new ElasticInferenceServiceCompletionModel( + DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, + TaskType.CHAT_COMPLETION, + ElasticInferenceService.NAME, + new ElasticInferenceServiceCompletionServiceSettings(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1, null), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + elasticInferenceServiceComponents + ), + new ElasticInferenceServiceSparseEmbeddingsModel( + DEFAULT_ELSER_ENDPOINT_ID_V2, + TaskType.SPARSE_EMBEDDING, + ElasticInferenceService.NAME, + new ElasticInferenceServiceSparseEmbeddingsServiceSettings(DEFAULT_ELSER_MODEL_ID_V2, null, null), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + elasticInferenceServiceComponents, + ChunkingSettingsBuilder.DEFAULT_SETTINGS + ), + new ElasticInferenceServiceDenseTextEmbeddingsModel( + DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID, + TaskType.TEXT_EMBEDDING, + ElasticInferenceService.NAME, + new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( + DEFAULT_MULTILINGUAL_EMBED_MODEL_ID, + defaultDenseTextEmbeddingsSimilarity(), + DENSE_TEXT_EMBEDDINGS_DIMENSIONS, + null, + null + ), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + elasticInferenceServiceComponents, + ChunkingSettingsBuilder.DEFAULT_SETTINGS + ), + new ElasticInferenceServiceRerankModel( + DEFAULT_RERANK_ENDPOINT_ID_V1, + TaskType.RERANK, + ElasticInferenceService.NAME, + new ElasticInferenceServiceRerankServiceSettings(DEFAULT_RERANK_MODEL_ID_V1, null), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + elasticInferenceServiceComponents + ) + ) + ); + } + } + public void testUnifiedCompletionError() { var e = assertThrows(UnifiedChatCompletionException.class, () -> testUnifiedStream(404, """ {