Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ private static Map<String, DefaultModelConfig> initDefaultEndpoints(
new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(
DEFAULT_MULTILINGUAL_EMBED_MODEL_ID,
defaultDenseTextEmbeddingsSimilarity(),
null,
DENSE_TEXT_EMBEDDINGS_DIMENSIONS,
null,
ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.DEFAULT_RATE_LIMIT_SETTINGS
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,20 @@ public int rateLimitGroupingHash() {
return Objects.hash(this.getServiceSettings().modelId());
}

@Override
public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) return false;
if (super.equals(o) == false) return false;
ElasticInferenceServiceModel that = (ElasticInferenceServiceModel) o;
return Objects.equals(rateLimitServiceSettings, that.rateLimitServiceSettings)
&& Objects.equals(elasticInferenceServiceComponents, that.elasticInferenceServiceComponents);
}

@Override
public int hashCode() {
return Objects.hash(super.hashCode(), rateLimitServiceSettings, elasticInferenceServiceComponents);
}

public RateLimitSettings rateLimitSettings() {
return rateLimitServiceSettings.rateLimitSettings();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.xcontent.XContentBuilder;
Expand Down Expand Up @@ -60,7 +61,7 @@ public static ElasticInferenceServiceCompletionServiceSettings fromMap(Map<Strin
private final String modelId;
private final RateLimitSettings rateLimitSettings;

public ElasticInferenceServiceCompletionServiceSettings(String modelId, RateLimitSettings rateLimitSettings) {
public ElasticInferenceServiceCompletionServiceSettings(String modelId, @Nullable RateLimitSettings rateLimitSettings) {
this.modelId = Objects.requireNonNull(modelId);
this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ public ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(
@Nullable SimilarityMeasure similarity,
@Nullable Integer dimensions,
@Nullable Integer maxInputTokens,
RateLimitSettings rateLimitSettings
@Nullable RateLimitSettings rateLimitSettings
) {
this.modelId = modelId;
this.similarity = similarity;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.TestPlainActionFuture;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
Expand Down Expand Up @@ -48,6 +49,7 @@
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
Expand All @@ -61,9 +63,12 @@
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModelTests;
import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModelTests;
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
Expand Down Expand Up @@ -93,8 +98,19 @@
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1;
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.DEFAULT_CHAT_COMPLETION_MODEL_ID_V1;
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.DEFAULT_ELSER_ENDPOINT_ID_V2;
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.DEFAULT_ELSER_MODEL_ID_V2;
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID;
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.DEFAULT_MULTILINGUAL_EMBED_MODEL_ID;
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.DEFAULT_RERANK_ENDPOINT_ID_V1;
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.DEFAULT_RERANK_MODEL_ID_V1;
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS;
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
Expand Down Expand Up @@ -1312,8 +1328,8 @@ public void testDefaultConfigs_Returns_DefaultEndpoints_WhenTaskTypeIsCorrect()
".multilingual-embed-v1-elastic",
MinimalServiceSettings.textEmbedding(
ElasticInferenceService.NAME,
ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS,
ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(),
DENSE_TEXT_EMBEDDINGS_DIMENSIONS,
defaultDenseTextEmbeddingsSimilarity(),
DenseVectorFieldMapper.ElementType.FLOAT
),
service
Expand Down Expand Up @@ -1348,6 +1364,94 @@ public void testDefaultConfigs_Returns_DefaultEndpoints_WhenTaskTypeIsCorrect()
}
}

public void testDefaultConfigs_Returns_DefaultEndpointsModels() throws Exception {
String responseJson = """
{
"models": [
{
"model_name": "rainbow-sprinkles",
"task_types": ["chat"]
},
{
"model_name": "elser-v2",
"task_types": ["embed/text/sparse"]
},
{
"model_name": "multilingual-embed-v1",
"task_types": ["embed/text/dense"]
},
{
"model_name": "rerank-v1",
"task_types": ["rerank/text/text-similarity"]
}
]
}
""";

webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));

var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
ensureAuthorizationCallFinished(service);
var listener = new TestPlainActionFuture<List<Model>>();

service.defaultConfigs(listener);
var models = listener.actionGet(TIMEOUT);

var elasticInferenceServiceComponents = new ElasticInferenceServiceComponents(getUrl(webServer));

assertThat(
models,
containsInAnyOrder(
new ElasticInferenceServiceCompletionModel(
DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1,
TaskType.CHAT_COMPLETION,
ElasticInferenceService.NAME,
new ElasticInferenceServiceCompletionServiceSettings(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1, null),
EmptyTaskSettings.INSTANCE,
EmptySecretSettings.INSTANCE,
elasticInferenceServiceComponents
),
new ElasticInferenceServiceSparseEmbeddingsModel(
DEFAULT_ELSER_ENDPOINT_ID_V2,
TaskType.SPARSE_EMBEDDING,
ElasticInferenceService.NAME,
new ElasticInferenceServiceSparseEmbeddingsServiceSettings(DEFAULT_ELSER_MODEL_ID_V2, null, null),
EmptyTaskSettings.INSTANCE,
EmptySecretSettings.INSTANCE,
elasticInferenceServiceComponents,
ChunkingSettingsBuilder.DEFAULT_SETTINGS
),
new ElasticInferenceServiceDenseTextEmbeddingsModel(
DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID,
TaskType.TEXT_EMBEDDING,
ElasticInferenceService.NAME,
new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(
DEFAULT_MULTILINGUAL_EMBED_MODEL_ID,
defaultDenseTextEmbeddingsSimilarity(),
DENSE_TEXT_EMBEDDINGS_DIMENSIONS,
null,
null
),
EmptyTaskSettings.INSTANCE,
EmptySecretSettings.INSTANCE,
elasticInferenceServiceComponents,
ChunkingSettingsBuilder.DEFAULT_SETTINGS
),
new ElasticInferenceServiceRerankModel(
DEFAULT_RERANK_ENDPOINT_ID_V1,
TaskType.RERANK,
ElasticInferenceService.NAME,
new ElasticInferenceServiceRerankServiceSettings(DEFAULT_RERANK_MODEL_ID_V1, null),
EmptyTaskSettings.INSTANCE,
EmptySecretSettings.INSTANCE,
elasticInferenceServiceComponents
)
)
);
}
}

public void testUnifiedCompletionError() {
var e = assertThrows(UnifiedChatCompletionException.class, () -> testUnifiedStream(404, """
{
Expand Down